Files
mal/internal/playback/handler/playlist_rewrite_test.go

123 lines
3.2 KiB
Go

package handler
import (
"context"
"fmt"
"mal/internal/domain"
"net/http"
"strings"
"testing"
)
// rewritePlaybackService is a test double for the service held by
// PlaybackHandler. Every method except SignProxyToken is a no-op returning
// zero values; SignProxyToken records each target URL it is asked to sign and
// returns a sequential token ("token-1", "token-2", ...) so tests can assert
// both which URLs were proxied and in what order.
type rewritePlaybackService struct {
	// targets holds the targetURL arguments passed to SignProxyToken, in call order.
	targets []string
}

// BuildWatchData is a no-op returning an empty domain.WatchPageData and no error.
func (f *rewritePlaybackService) BuildWatchData(context.Context, int, []string, string, string, string) (domain.WatchPageData, error) {
	return domain.WatchPageData{}, nil
}

// SaveProgress is a no-op that always returns nil.
func (f *rewritePlaybackService) SaveProgress(context.Context, string, int64, int, float64) error {
	return nil
}

// CompleteAnime is a no-op that always returns nil.
func (f *rewritePlaybackService) CompleteAnime(context.Context, string, int64) error {
	return nil
}

// SignProxyToken appends targetURL to f.targets and returns "token-N", where N
// is the number of URLs recorded so far (1-based). The other arguments are ignored.
func (f *rewritePlaybackService) SignProxyToken(targetURL, _, _ string) (string, error) {
	f.targets = append(f.targets, targetURL)
	n := len(f.targets)
	return fmt.Sprintf("token-%d", n), nil
}

// ResolveProxyToken is a no-op returning two empty strings and no error.
func (f *rewritePlaybackService) ResolveProxyToken(string, string) (string, string, error) {
	return "", "", nil
}

// UpsertSkipSegmentOverride is a no-op that always returns nil.
func (f *rewritePlaybackService) UpsertSkipSegmentOverride(context.Context, string, int64, int, string, float64, float64) error {
	return nil
}
// TestRewriteHLSPlaylistProxiesSegmentAndKeyURIs checks that rewriteHLSPlaylist
// replaces every upstream URI in the playlist with a
// /watch/proxy/stream?token=... link. The URIs covered are the EXT-X-KEY URI
// attribute, a relative segment path and an absolute segment URL. It also
// checks that relative URIs are resolved against the playlist URL before
// being passed to SignProxyToken, in document order.
func TestRewriteHLSPlaylistProxiesSegmentAndKeyURIs(t *testing.T) {
	fake := &rewritePlaybackService{}
	ph := &PlaybackHandler{svc: fake}

	playlist := strings.Join([]string{
		"#EXTM3U",
		`#EXT-X-KEY:METHOD=AES-128,URI="keys/key.bin"`,
		"#EXTINF:4.0,",
		"segments/seg-1.ts",
		"#EXTINF:4.0,",
		"https://cdn.example.test/video/seg-2.ts",
		"",
	}, "\n")
	const (
		playlistURL = "https://origin.example.test/hls/master/index.m3u8"
		refererURL  = "https://referer.example.test"
	)

	rewritten, err := ph.rewriteHLSPlaylist(playlist, playlistURL, refererURL)
	if err != nil {
		t.Fatalf("rewriteHLSPlaylist returned error: %v", err)
	}

	// Neither upstream host nor any original relative path may survive the rewrite.
	leaks := []string{"origin.example.test", "cdn.example.test", "keys/key.bin", "segments/seg-1.ts"}
	for _, leak := range leaks {
		if strings.Contains(rewritten, leak) {
			t.Fatalf("rewritten playlist leaked upstream data:\n%s", rewritten)
		}
	}

	// One proxy link per signed URI: the key plus two segments.
	for _, token := range []string{"token-1", "token-2", "token-3"} {
		if link := "/watch/proxy/stream?token=" + token; !strings.Contains(rewritten, link) {
			t.Fatalf("rewritten playlist missing %s:\n%s", token, rewritten)
		}
	}

	// Relative URIs are resolved against the playlist's directory; absolute ones pass through.
	wantTargets := []string{
		"https://origin.example.test/hls/master/keys/key.bin",
		"https://origin.example.test/hls/master/segments/seg-1.ts",
		"https://cdn.example.test/video/seg-2.ts",
	}
	gotJoined, wantJoined := strings.Join(fake.targets, "\n"), strings.Join(wantTargets, "\n")
	if gotJoined != wantJoined {
		t.Fatalf("targets = %#v, want %#v", fake.targets, wantTargets)
	}
}
// TestIsHLSPlaylistResponse checks that isHLSPlaylistResponse recognises a
// playlist either by its URL (".m3u8" path, even with a query string) or by
// an HLS Content-Type header, and rejects a response that has neither.
func TestIsHLSPlaylistResponse(t *testing.T) {
	cases := []struct {
		name   string
		target string
		header http.Header // nil means the response carries no headers
		want   bool
	}{
		{
			name:   "playlist url",
			target: "https://example.test/master.m3u8?token=abc",
			want:   true,
		},
		{
			name:   "playlist content type",
			target: "https://example.test/video.bin",
			header: http.Header{
				"Content-Type": []string{"application/x-mpegurl; charset=utf-8"},
			},
			want: true,
		},
		{
			name:   "non playlist",
			target: "https://example.test/video.bin",
			want:   false,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got := isHLSPlaylistResponse(tc.target, tc.header)
			if got != tc.want {
				t.Fatalf("isHLSPlaylistResponse() = %v, want %v", got, tc.want)
			}
		})
	}
}