From be5824ab5fb5bb8ac2110272331163fbf751424e Mon Sep 17 00:00:00 2001 From: mkelvers Date: Mon, 20 Apr 2026 01:26:16 +0200 Subject: [PATCH] refactor: consolidate proxy handlers and simplify ranking/watchlist --- internal/features/playback/handler.go | 65 +++------ internal/features/playback/service_ranking.go | 131 ++++++++---------- internal/features/watchlist/service.go | 48 +++---- 3 files changed, 93 insertions(+), 151 deletions(-) diff --git a/internal/features/playback/handler.go b/internal/features/playback/handler.go index 2c50836..0b349e7 100644 --- a/internal/features/playback/handler.go +++ b/internal/features/playback/handler.go @@ -5,6 +5,7 @@ import ( "database/sql" "encoding/json" "errors" + "fmt" "io" "log" "net/http" @@ -184,72 +185,48 @@ func convertSegments(segments []SkipSegment) []templates.SkipSegment { return result } -func (h *Handler) HandleProxyStream(w http.ResponseWriter, r *http.Request) { +func (h *Handler) HandleProxy(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodGet { http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) return } - mode := normalizeMode(r.URL.Query().Get("mode")) - if mode == "" { - mode = "dub" - } - token := strings.TrimSpace(r.URL.Query().Get("token")) if token == "" { http.Error(w, "missing playback token", http.StatusBadRequest) return } - targetURL, referer, err := h.svc.resolveProxyToken(r.Context(), token, proxyScopeStream) + scope := proxyScope(strings.TrimPrefix(r.URL.Path, "/watch/proxy/")) + scopeLabel := map[proxyScope]string{ + proxyScopeStream: "stream", + proxyScopeSegment: "segment", + proxyScopeSubtitle: "subtitle", + }[scope] + if scopeLabel == "" { + http.Error(w, "invalid proxy scope", http.StatusBadRequest) + return + } + + targetURL, referer, err := h.svc.resolveProxyToken(r.Context(), token, scope) if err != nil { - http.Error(w, "invalid stream token", http.StatusBadRequest) + http.Error(w, fmt.Sprintf("invalid %s token", scopeLabel), http.StatusBadRequest) return } h.proxyUpstream(w, r, targetURL, referer) } +func (h *Handler) HandleProxyStream(w http.ResponseWriter, r *http.Request) { + h.HandleProxy(w, r) +} + func (h *Handler) HandleProxySegment(w http.ResponseWriter, r *http.Request) { - if r.Method != http.MethodGet { - http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) - return - } - - token := strings.TrimSpace(r.URL.Query().Get("token")) - if token == "" { - http.Error(w, "missing segment token", http.StatusBadRequest) - return - } - - targetURL, referer, err := h.svc.resolveProxyToken(r.Context(), token, proxyScopeSegment) - if err != nil { - http.Error(w, "invalid segment token", http.StatusBadRequest) - return - } - - h.proxyUpstream(w, r, targetURL, referer) + h.HandleProxy(w, r) } func (h *Handler) HandleProxySubtitle(w http.ResponseWriter, r *http.Request) { - if r.Method != http.MethodGet { - http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) - return - } - - token := strings.TrimSpace(r.URL.Query().Get("token")) - if token == "" { - http.Error(w, "missing subtitle token", http.StatusBadRequest) - return - } - - targetURL, referer, err := h.svc.resolveProxyToken(r.Context(), token, proxyScopeSubtitle) - if err != nil { - http.Error(w, "invalid subtitle token", http.StatusBadRequest) - return - } - - h.proxyUpstream(w, r, targetURL, referer) + h.HandleProxy(w, r) } func (h *Handler) HandleSaveProgress(w http.ResponseWriter, r *http.Request) { diff --git a/internal/features/playback/service_ranking.go b/internal/features/playback/service_ranking.go index 3e7b589..12b5477 100644 --- a/internal/features/playback/service_ranking.go +++ b/internal/features/playback/service_ranking.go @@ -30,8 +30,8 @@ func rankSources(sources []StreamSource, quality string) ([]sourceScore, error) targetQuality := normalizeQuality(quality) scored := make([]sourceScore, 0, len(filtered)) for _, source := range filtered { - typeScore := sourceTypePriority(source.Type) - providerScore := providerPriority(source.Provider) + typeScore := sourceTypePriorityFn(source.Type) + providerScore := providerPriorityFn(source.Provider) qualityScore := sourceQualityPriority(source.Quality, targetQuality) refererScore := 0 if source.Referer != "" { @@ -65,48 +65,43 @@ func normalizeQuality(quality string) string { return lower } -func sourceTypePriority(sourceType string) int { - switch strings.ToLower(sourceType) { - case "mp4": - return 500 - case "m3u8": - return 450 - case "unknown": - return 300 - case "embed": - return 100 - default: - return 200 - } +var sourceTypePriority = map[string]int{ + "mp4": 500, + "m3u8": 450, + "unknown": 300, + "embed": 100, } -func providerPriority(provider string) int { - switch strings.ToLower(provider) { - case "s-mp4": - return 120 - case "default": - return 115 - case "luf-mp4": - return 110 - case "vid-mp4": - return 105 - case "yt-mp4": - return 100 - case "mp4": - return 95 - case "uv-mp4": - return 90 - case "hls": - return 80 - case "sw": - return 40 - case "ok": - return 35 - case "ss-hls": - return 30 - default: - return 60 +var providerPriority = map[string]int{ + "s-mp4": 120, + "default": 115, + "luf-mp4": 110, + "vid-mp4": 105, + "yt-mp4": 100, + "mp4": 95, + "uv-mp4": 90, + "hls": 80, + "sw": 40, + "ok": 35, + "ss-hls": 30, +} + +var sourceQualityDefaults = map[string]int{ + "auto": 240, +} + +func sourceTypePriorityFn(sourceType string) int { + if p, ok := sourceTypePriority[strings.ToLower(sourceType)]; ok { + return p } + return 200 +} + +func providerPriorityFn(provider string) int { + if p, ok := providerPriority[strings.ToLower(provider)]; ok { + return p + } + return 60 } func sourceQualityPriority(sourceQuality string, targetQuality string) int { @@ -126,34 +121,6 @@ func sourceQualityPriority(sourceQuality string, targetQuality string) int { } } -func parseQualityValue(rawQuality string) int { - lower := strings.ToLower(rawQuality) - var digits strings.Builder - - for _, char := range lower { - if char >= '0' && char <= '9' { - digits.WriteRune(char) - continue - } - if digits.Len() > 0 { - break - } - } - - if digits.Len() > 0 { - value, err := strconv.Atoi(digits.String()) - if err == nil { - return value - } - } - - if lower == "auto" { - return 240 - } - - return 0 -} - func qualityMatches(sourceQuality string, targetQuality string) bool { sourceLower := strings.ToLower(sourceQuality) targetLower := strings.ToLower(targetQuality) @@ -166,10 +133,25 @@ func qualityMatches(sourceQuality string, targetQuality string) bool { return true } - sourceDigits := extractDigits(sourceLower) - targetDigits := extractDigits(targetLower) + return extractDigits(sourceLower) == extractDigits(targetLower) +} - return sourceDigits != "" && sourceDigits == targetDigits +func parseQualityValue(rawQuality string) int { + lower := strings.ToLower(rawQuality) + if lower == "auto" { + return 240 + } + + digits := extractDigits(lower) + if digits == "" { + return 0 + } + + value, err := strconv.Atoi(digits) + if err != nil { + return 0 + } + return value } func extractDigits(value string) string { @@ -177,13 +159,10 @@ func extractDigits(value string) string { for _, char := range value { if char >= '0' && char <= '9' { digits.WriteRune(char) - continue - } - if digits.Len() > 0 { + } else if digits.Len() > 0 { break } } - return digits.String() } diff --git a/internal/features/watchlist/service.go b/internal/features/watchlist/service.go index c4b35c6..66e977e 100644 --- a/internal/features/watchlist/service.go +++ b/internal/features/watchlist/service.go @@ -133,24 +133,23 @@ func (s *Service) DeleteContinueWatching(ctx context.Context, userID string, ani return ErrInvalidAnimeID } + params := database.DeleteContinueWatchingEntryParams{ + UserID: userID, + AnimeID: animeID, + } + + clearProgress := database.SaveWatchProgressParams{ + CurrentEpisode: sql.NullInt64{Valid: false}, + CurrentTimeSeconds: 0, + UserID: userID, + AnimeID: animeID, + } + if s.sqlDB == nil { - if err := s.db.DeleteContinueWatchingEntry(ctx, database.DeleteContinueWatchingEntryParams{ - UserID: userID, - AnimeID: animeID, - }); err != nil { + if err := s.db.DeleteContinueWatchingEntry(ctx, params); err != nil { return fmt.Errorf("failed to delete continue watching entry: %w", err) } - - if err := s.db.SaveWatchProgress(ctx, database.SaveWatchProgressParams{ - CurrentEpisode: sql.NullInt64{Valid: false}, - CurrentTimeSeconds: 0, - UserID: userID, - AnimeID: animeID, - }); err != nil { - return fmt.Errorf("failed to clear watchlist progress: %w", err) - } - - return nil + return s.db.SaveWatchProgress(ctx, clearProgress) } tx, err := s.sqlDB.BeginTx(ctx, nil) @@ -160,27 +159,14 @@ func (s *Service) DeleteContinueWatching(ctx context.Context, userID string, ani defer tx.Rollback() txQueries := database.New(tx) - if err := txQueries.DeleteContinueWatchingEntry(ctx, database.DeleteContinueWatchingEntryParams{ - UserID: userID, - AnimeID: animeID, - }); err != nil { + if err := txQueries.DeleteContinueWatchingEntry(ctx, params); err != nil { return fmt.Errorf("failed to delete continue watching entry: %w", err) } - - if err := txQueries.SaveWatchProgress(ctx, database.SaveWatchProgressParams{ - CurrentEpisode: sql.NullInt64{Valid: false}, - CurrentTimeSeconds: 0, - UserID: userID, - AnimeID: animeID, - }); err != nil { + if err := txQueries.SaveWatchProgress(ctx, clearProgress); err != nil { return fmt.Errorf("failed to clear watchlist progress: %w", err) } - if err := tx.Commit(); err != nil { - return fmt.Errorf("failed to commit continue watching deletion: %w", err) - } - - return nil + return tx.Commit() } type ExportEntry struct {