diff --git a/README.md b/README.md index 62c97da..d29a44f 100644 --- a/README.md +++ b/README.md @@ -98,7 +98,7 @@ The frontend pipeline uses a single source stylesheet (`static/style.css`) and T When the server starts, the app is available at `http://localhost:3000`. Important notes: -- Environment variables are read directly from the process environment (`PORT`, `DATABASE_FILE`, `ENV`); `.env` is not auto-loaded. +- Environment variables are read directly from the process environment (`PORT`, `DATABASE_FILE`, `ENV`, `PLAYBACK_PROXY_SECRET`); `.env` is not auto-loaded. - The web app currently exposes a login route only. If your database has no users yet, create the first user outside the web UI. For containerized usage, the included `Dockerfile` uses a multi-stage build that installs Bun + templ, builds assets, generates templates, compiles `cmd/server`, and ships a slim runtime image with SQLite support. @@ -125,6 +125,8 @@ docker run --rm \ | `PORT` | `3000` | HTTP listen port | | `DATABASE_FILE` | `mal.db` | SQLite database file path | | `ENV` | _(empty)_ | Set to `production` to enable secure session cookies | +| `MIGRATIONS_DIR` | _(auto-discovered)_ | Optional explicit path to migration files | +| `PLAYBACK_PROXY_SECRET` | _(required)_ | HMAC secret for signed playback proxy tokens (min 32 chars) | ## Database and testing diff --git a/cmd/server/main.go b/cmd/server/main.go index 3c96946..c95f1cd 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -3,10 +3,14 @@ package main import ( "context" "database/sql" + "errors" + "fmt" "log" "net/http" "os" "os/signal" + "path/filepath" + "strings" "syscall" "time" @@ -20,23 +24,49 @@ import ( ) func main() { - - dbFile := os.Getenv("DATABASE_FILE") + dbFile := strings.TrimSpace(os.Getenv("DATABASE_FILE")) if dbFile == "" { dbFile = "mal.db" } - db, err := sql.Open("sqlite3", dbFile) + dsn := fmt.Sprintf("file:%s?_foreign_keys=on", dbFile) + db, err := sql.Open("sqlite3", dsn) if err != nil { log.Fatalf("failed to open db: %v", err) } defer db.Close() + if err := db.Ping(); err != nil { + log.Fatalf("failed to ping db: %v", err) + } + + if _, err := db.Exec("PRAGMA foreign_keys = ON"); err != nil { + log.Fatalf("failed to enforce sqlite foreign keys: %v", err) + } + + var fkState int + if err := db.QueryRow("PRAGMA foreign_keys").Scan(&fkState); err != nil { + log.Fatalf("failed to verify sqlite foreign keys: %v", err) + } + if fkState != 1 { + log.Fatal("sqlite foreign keys are disabled") + } + + migrationsDir, err := resolveMigrationsDir() + if err != nil { + log.Fatalf("failed to locate migrations directory: %v", err) + } + // Run migrations with tracking - if err := database.RunMigrations(db); err != nil { + if err := database.RunMigrations(db, migrationsDir); err != nil { log.Fatalf("failed to run migrations: %v", err) } + playbackSecret := strings.TrimSpace(os.Getenv("PLAYBACK_PROXY_SECRET")) + if len(playbackSecret) < 32 { + log.Fatal("PLAYBACK_PROXY_SECRET must be set and at least 32 characters") + } + queries := database.New(db) authService := auth.NewService(queries) jikanClient := jikan.NewClient(queries) @@ -50,9 +80,11 @@ func main() { go relationsWorker.Start(ctx) app := server.Config{ - DB: queries, - JikanClient: jikanClient, - AuthService: authService, + DB: queries, + SQLDB: db, + JikanClient: jikanClient, + AuthService: authService, + PlaybackProxySecret: playbackSecret, } handler := server.NewRouter(app) @@ -86,3 +118,69 @@ func main() { log.Fatalf("Server failed to start: %v", err) } } + +func resolveMigrationsDir() (string, error) { + configured := strings.TrimSpace(os.Getenv("MIGRATIONS_DIR")) + if configured != "" { + hasFiles, err := directoryHasSQLFiles(configured) + if err != nil { + return "", err + } + + if !hasFiles { + return "", fmt.Errorf("MIGRATIONS_DIR has no .sql files: %s", configured) + } + + return configured, nil + } + + workingDir, err := os.Getwd() + if err != nil { + return "", err + } + + executablePath, err := os.Executable() + if err != nil { + return "", err + } + + candidates := []string{ + filepath.Join(workingDir, "migrations"), + filepath.Join(filepath.Dir(executablePath), "migrations"), + } + + for _, candidate := range candidates { + hasFiles, checkErr := directoryHasSQLFiles(candidate) + if checkErr != nil { + if errors.Is(checkErr, os.ErrNotExist) { + continue + } + + return "", checkErr + } + + if hasFiles { + return candidate, nil + } + } + + return "", errors.New("could not find migrations directory") +} + +func directoryHasSQLFiles(dir string) (bool, error) { + info, err := os.Stat(dir) + if err != nil { + return false, err + } + + if !info.IsDir() { + return false, fmt.Errorf("not a directory: %s", dir) + } + + files, err := filepath.Glob(filepath.Join(dir, "*.sql")) + if err != nil { + return false, err + } + + return len(files) > 0, nil +} diff --git a/internal/database/migrate.go b/internal/database/migrate.go index f19e08f..f27af77 100644 --- a/internal/database/migrate.go +++ b/internal/database/migrate.go @@ -2,13 +2,19 @@ package database import ( "database/sql" + "fmt" "log" "os" "path/filepath" "sort" + "strings" ) -func RunMigrations(db *sql.DB) error { +func RunMigrations(db *sql.DB, migrationsDir string) error { + if migrationsDir == "" { + return fmt.Errorf("migrations directory is required") + } + // Create migration tracking table _, err := db.Exec(` CREATE TABLE IF NOT EXISTS migration_version ( @@ -20,21 +26,24 @@ func RunMigrations(db *sql.DB) error { return err } - migrations, err := filepath.Glob("migrations/*.sql") + migrations, err := filepath.Glob(filepath.Join(migrationsDir, "*.sql")) if err != nil { return err } + if len(migrations) == 0 { + return fmt.Errorf("no migration files found in %s", migrationsDir) + } sort.Strings(migrations) + appliedNames, err := loadAppliedMigrationNames(db) + if err != nil { + return err + } + for _, migrationFile := range migrations { - // Check if migration already applied - var exists int - err := db.QueryRow("SELECT COUNT(*) FROM migration_version WHERE name = ?", migrationFile).Scan(&exists) - if err != nil { - return err - } - if exists > 0 { + migrationName := filepath.Base(migrationFile) + if migrationApplied(appliedNames, migrationName) { // already applied, skipping silently continue } @@ -51,13 +60,58 @@ func RunMigrations(db *sql.DB) error { } // Mark as applied - _, err = db.Exec("INSERT INTO migration_version (name) VALUES (?)", migrationFile) + _, err = db.Exec("INSERT INTO migration_version (name) VALUES (?)", migrationName) if err != nil { return err } - log.Printf("migration %s applied successfully", migrationFile) + appliedNames[migrationName] = struct{}{} + + log.Printf("migration %s applied successfully", migrationName) } return nil } + +func loadAppliedMigrationNames(db *sql.DB) (map[string]struct{}, error) { + rows, err := db.Query("SELECT name FROM migration_version") + if err != nil { + return nil, err + } + defer rows.Close() + + applied := make(map[string]struct{}) + for rows.Next() { + var name string + if err := rows.Scan(&name); err != nil { + return nil, err + } + + applied[name] = struct{}{} + } + + if err := rows.Err(); err != nil { + return nil, err + } + + return applied, nil +} + +func migrationApplied(appliedNames map[string]struct{}, migrationName string) bool { + if _, exists := appliedNames[migrationName]; exists { + return true + } + + legacyName := filepath.ToSlash(filepath.Join("migrations", migrationName)) + if _, exists := appliedNames[legacyName]; exists { + return true + } + + for appliedName := range appliedNames { + if strings.EqualFold(filepath.Base(appliedName), migrationName) { + return true + } + } + + return false +} diff --git a/internal/features/playback/handler.go b/internal/features/playback/handler.go index 970eb21..2c50836 100644 --- a/internal/features/playback/handler.go +++ b/internal/features/playback/handler.go @@ -13,8 +13,6 @@ import ( "strings" "time" - "github.com/google/uuid" - "mal/internal/database" "mal/internal/jikan" "mal/internal/shared/middleware" @@ -162,14 +160,12 @@ func convertModeSources(sources map[string]ModeSource) map[string]templates.Mode subtitles := make([]templates.SubtitleItem, len(v.Subtitles)) for i, s := range v.Subtitles { subtitles[i] = templates.SubtitleItem{ - Lang: s.Lang, - URL: s.URL, - Referer: s.Referer, + Lang: s.Lang, + Token: s.Token, } } result[k] = templates.ModeSource{ - URL: v.URL, - Referer: v.Referer, + Token: v.Token, Subtitles: subtitles, } } @@ -199,25 +195,19 @@ func (h *Handler) HandleProxyStream(w http.ResponseWriter, r *http.Request) { mode = "dub" } - state := r.URL.Query().Get("state") - if strings.TrimSpace(state) == "" { - http.Error(w, "missing playback state", http.StatusBadRequest) + token := strings.TrimSpace(r.URL.Query().Get("token")) + if token == "" { + http.Error(w, "missing playback token", http.StatusBadRequest) return } - modeSources := make(map[string]ModeSource) - if err := json.Unmarshal([]byte(state), &modeSources); err != nil { - http.Error(w, "invalid playback state", http.StatusBadRequest) + targetURL, referer, err := h.svc.resolveProxyToken(r.Context(), token, proxyScopeStream) + if err != nil { + http.Error(w, "invalid stream token", http.StatusBadRequest) return } - source, ok := modeSources[mode] - if !ok || strings.TrimSpace(source.URL) == "" { - http.Error(w, "stream mode unavailable", http.StatusBadRequest) - return - } - - h.proxyUpstream(w, r, source.URL, source.Referer) + h.proxyUpstream(w, r, targetURL, referer) } func (h *Handler) HandleProxySegment(w http.ResponseWriter, r *http.Request) { @@ -226,13 +216,19 @@ func (h *Handler) HandleProxySegment(w http.ResponseWriter, r *http.Request) { return } - targetURL := r.URL.Query().Get("u") - if strings.TrimSpace(targetURL) == "" { - http.Error(w, "missing target url", http.StatusBadRequest) + token := strings.TrimSpace(r.URL.Query().Get("token")) + if token == "" { + http.Error(w, "missing segment token", http.StatusBadRequest) return } - h.proxyUpstream(w, r, targetURL, r.URL.Query().Get("r")) + targetURL, referer, err := h.svc.resolveProxyToken(r.Context(), token, proxyScopeSegment) + if err != nil { + http.Error(w, "invalid segment token", http.StatusBadRequest) + return + } + + h.proxyUpstream(w, r, targetURL, referer) } func (h *Handler) HandleProxySubtitle(w http.ResponseWriter, r *http.Request) { @@ -241,13 +237,19 @@ func (h *Handler) HandleProxySubtitle(w http.ResponseWriter, r *http.Request) { return } - targetURL := r.URL.Query().Get("u") - if strings.TrimSpace(targetURL) == "" { - http.Error(w, "missing target url", http.StatusBadRequest) + token := strings.TrimSpace(r.URL.Query().Get("token")) + if token == "" { + http.Error(w, "missing subtitle token", http.StatusBadRequest) return } - h.proxyUpstream(w, r, targetURL, r.URL.Query().Get("r")) + targetURL, referer, err := h.svc.resolveProxyToken(r.Context(), token, proxyScopeSubtitle) + if err != nil { + http.Error(w, "invalid subtitle token", http.StatusBadRequest) + return + } + + h.proxyUpstream(w, r, targetURL, referer) } func (h *Handler) HandleSaveProgress(w http.ResponseWriter, r *http.Request) { @@ -291,62 +293,15 @@ func (h *Handler) HandleSaveProgress(w http.ResponseWriter, r *http.Request) { animeID := int64(payload.MalID) - if _, err := h.svc.db.GetAnime(r.Context(), animeID); err != nil { - anime, fetchErr := h.jikanClient.GetAnimeByID(r.Context(), payload.MalID) - if fetchErr != nil { - log.Printf("save progress failed to fetch anime user_id=%s mal_id=%d err=%v", user.ID, payload.MalID, fetchErr) - http.Error(w, "failed to save progress", http.StatusInternalServerError) - return - } - - if _, upsertErr := h.svc.db.UpsertAnime(r.Context(), database.UpsertAnimeParams{ - ID: animeID, - TitleOriginal: anime.Title, - TitleEnglish: sql.NullString{String: anime.TitleEnglish, Valid: anime.TitleEnglish != ""}, - TitleJapanese: sql.NullString{String: anime.TitleJapanese, Valid: anime.TitleJapanese != ""}, - ImageUrl: anime.ImageURL(), - Airing: sql.NullBool{Bool: anime.Airing, Valid: true}, - }); upsertErr != nil { - log.Printf("save progress failed to upsert anime user_id=%s mal_id=%d err=%v", user.ID, payload.MalID, upsertErr) - http.Error(w, "failed to save progress", http.StatusInternalServerError) - return - } - } - - watchListEntry, watchListErr := h.svc.db.GetWatchListEntry(r.Context(), database.GetWatchListEntryParams{ - UserID: user.ID, - AnimeID: animeID, - }) - if watchListErr == nil && watchListEntry.Status == "completed" { - if err := h.svc.db.DeleteContinueWatchingEntry(r.Context(), database.DeleteContinueWatchingEntryParams{ - UserID: user.ID, - AnimeID: animeID, - }); err != nil { - log.Printf("save progress failed to clear continue entry user_id=%s mal_id=%d err=%v", user.ID, payload.MalID, err) - } - w.WriteHeader(http.StatusNoContent) + animeSeed, err := h.ensureAnimeSeed(r.Context(), payload.MalID) + if err != nil { + log.Printf("save progress failed to resolve anime user_id=%s mal_id=%d err=%v", user.ID, payload.MalID, err) + http.Error(w, "failed to save progress", http.StatusInternalServerError) return } - if err := h.svc.db.SaveWatchProgress(r.Context(), database.SaveWatchProgressParams{ - CurrentEpisode: sql.NullInt64{Int64: int64(payload.Episode), Valid: true}, - CurrentTimeSeconds: timeSeconds, - UserID: user.ID, - AnimeID: animeID, - }); err != nil { - if err.Error() != "sql: no rows in result set" { - log.Printf("save watchlist progress skipped user_id=%s mal_id=%d err=%v", user.ID, payload.MalID, err) - } - } - - if _, err := h.svc.db.UpsertContinueWatchingEntry(r.Context(), database.UpsertContinueWatchingEntryParams{ - ID: uuid.New().String(), - UserID: user.ID, - AnimeID: animeID, - CurrentEpisode: sql.NullInt64{Int64: int64(payload.Episode), Valid: true}, - CurrentTimeSeconds: timeSeconds, - }); err != nil { - log.Printf("save continue watching failed user_id=%s mal_id=%d err=%v", user.ID, payload.MalID, err) + if err := h.svc.SaveProgress(r.Context(), user.ID, animeID, payload.Episode, timeSeconds, animeSeed); err != nil { + log.Printf("save progress failed user_id=%s mal_id=%d err=%v", user.ID, payload.MalID, err) http.Error(w, "failed to save progress", http.StatusInternalServerError) return } @@ -383,91 +338,43 @@ func (h *Handler) HandleCompleteAnime(w http.ResponseWriter, r *http.Request) { } animeID := int64(payload.MalID) - watchListEntry, watchListErr := h.svc.db.GetWatchListEntry(r.Context(), database.GetWatchListEntryParams{ - UserID: user.ID, - AnimeID: animeID, - }) - if watchListErr != nil && !errors.Is(watchListErr, sql.ErrNoRows) { - log.Printf("complete anime failed to load watchlist user_id=%s mal_id=%d err=%v", user.ID, payload.MalID, watchListErr) + animeSeed, err := h.ensureAnimeSeed(r.Context(), payload.MalID) + if err != nil { + log.Printf("complete anime failed to resolve anime user_id=%s mal_id=%d err=%v", user.ID, payload.MalID, err) http.Error(w, "failed to mark anime completed", http.StatusInternalServerError) return } - alreadyCompleted := watchListErr == nil && watchListEntry.Status == "completed" - - if !alreadyCompleted { - if _, err := h.svc.db.GetAnime(r.Context(), animeID); err != nil { - anime, fetchErr := h.jikanClient.GetAnimeByID(r.Context(), payload.MalID) - if fetchErr != nil { - log.Printf("complete anime failed to fetch anime user_id=%s mal_id=%d err=%v", user.ID, payload.MalID, fetchErr) - http.Error(w, "failed to mark anime completed", http.StatusInternalServerError) - return - } - - if _, upsertErr := h.svc.db.UpsertAnime(r.Context(), database.UpsertAnimeParams{ - ID: animeID, - TitleOriginal: anime.Title, - TitleEnglish: sql.NullString{String: anime.TitleEnglish, Valid: anime.TitleEnglish != ""}, - TitleJapanese: sql.NullString{String: anime.TitleJapanese, Valid: anime.TitleJapanese != ""}, - ImageUrl: anime.ImageURL(), - Airing: sql.NullBool{Bool: anime.Airing, Valid: true}, - }); upsertErr != nil { - log.Printf("complete anime failed to upsert anime user_id=%s mal_id=%d err=%v", user.ID, payload.MalID, upsertErr) - http.Error(w, "failed to mark anime completed", http.StatusInternalServerError) - return - } - } - - if _, err := h.svc.db.UpsertWatchListEntry(r.Context(), database.UpsertWatchListEntryParams{ - ID: uuid.New().String(), - UserID: user.ID, - AnimeID: animeID, - Status: "completed", - CurrentEpisode: sql.NullInt64{Int64: int64(payload.Episode), Valid: true}, - CurrentTimeSeconds: 0, - }); err != nil { - log.Printf("complete anime failed to upsert watchlist user_id=%s mal_id=%d err=%v", user.ID, payload.MalID, err) - http.Error(w, "failed to mark anime completed", http.StatusInternalServerError) - return - } - } - - if err := h.svc.db.DeleteContinueWatchingEntry(r.Context(), database.DeleteContinueWatchingEntryParams{ - UserID: user.ID, - AnimeID: animeID, - }); err != nil { - log.Printf("complete anime failed to delete continue entry user_id=%s mal_id=%d err=%v", user.ID, payload.MalID, err) + if err := h.svc.CompleteAnime(r.Context(), user.ID, animeID, payload.Episode, animeSeed); err != nil { + log.Printf("complete anime failed user_id=%s mal_id=%d err=%v", user.ID, payload.MalID, err) http.Error(w, "failed to mark anime completed", http.StatusInternalServerError) return } - if _, err := h.svc.db.GetContinueWatchingEntry(r.Context(), database.GetContinueWatchingEntryParams{ - UserID: user.ID, - AnimeID: animeID, - }); err == nil { - log.Printf("complete anime failed to clear continue entry user_id=%s mal_id=%d", user.ID, payload.MalID) - http.Error(w, "failed to mark anime completed", http.StatusInternalServerError) - return - } else if !errors.Is(err, sql.ErrNoRows) { - log.Printf("complete anime failed to verify continue clear user_id=%s mal_id=%d err=%v", user.ID, payload.MalID, err) - http.Error(w, "failed to mark anime completed", http.StatusInternalServerError) - return - } - - if err := h.svc.db.SaveWatchProgress(r.Context(), database.SaveWatchProgressParams{ - CurrentEpisode: sql.NullInt64{Int64: int64(payload.Episode), Valid: true}, - CurrentTimeSeconds: 0, - UserID: user.ID, - AnimeID: animeID, - }); err != nil { - if !errors.Is(err, sql.ErrNoRows) { - log.Printf("complete anime failed to reset watchlist progress user_id=%s mal_id=%d err=%v", user.ID, payload.MalID, err) - } - } - w.WriteHeader(http.StatusNoContent) } +func (h *Handler) ensureAnimeSeed(ctx context.Context, malID int) (*database.UpsertAnimeParams, error) { + animeID := int64(malID) + if _, err := h.svc.db.GetAnime(ctx, animeID); err == nil { + return nil, nil + } + + anime, err := h.jikanClient.GetAnimeByID(ctx, malID) + if err != nil { + return nil, err + } + + return &database.UpsertAnimeParams{ + ID: animeID, + TitleOriginal: anime.Title, + TitleEnglish: sql.NullString{String: anime.TitleEnglish, Valid: anime.TitleEnglish != ""}, + TitleJapanese: sql.NullString{String: anime.TitleJapanese, Valid: anime.TitleJapanese != ""}, + ImageUrl: anime.ImageURL(), + Airing: sql.NullBool{Bool: anime.Airing, Valid: true}, + }, nil +} + func (h *Handler) proxyUpstream(w http.ResponseWriter, r *http.Request, targetURL string, referer string) { parsed, err := url.Parse(targetURL) if err != nil || (parsed.Scheme != "http" && parsed.Scheme != "https") { diff --git a/internal/features/playback/progress.go b/internal/features/playback/progress.go new file mode 100644 index 0000000..2330f53 --- /dev/null +++ b/internal/features/playback/progress.go @@ -0,0 +1,157 @@ +package playback + +import ( + "context" + "database/sql" + "errors" + "fmt" + "strings" + + "github.com/google/uuid" + + "mal/internal/database" +) + +func (s *Service) SaveProgress(ctx context.Context, userID string, animeID int64, episode int, timeSeconds float64, animeSeed *database.UpsertAnimeParams) error { + if strings.TrimSpace(userID) == "" || animeID <= 0 || episode <= 0 { + return errors.New("invalid save progress input") + } + + txQueries, tx, err := s.beginTx(ctx) + if err != nil { + return err + } + + defer tx.Rollback() + + if animeSeed != nil { + if _, err := txQueries.UpsertAnime(ctx, *animeSeed); err != nil { + return fmt.Errorf("failed to save anime reference: %w", err) + } + } + + watchListEntry, watchListErr := txQueries.GetWatchListEntry(ctx, database.GetWatchListEntryParams{ + UserID: userID, + AnimeID: animeID, + }) + if watchListErr != nil && !errors.Is(watchListErr, sql.ErrNoRows) { + return fmt.Errorf("failed to load watchlist entry: %w", watchListErr) + } + + if watchListErr == nil && watchListEntry.Status == "completed" { + if err := txQueries.DeleteContinueWatchingEntry(ctx, database.DeleteContinueWatchingEntryParams{ + UserID: userID, + AnimeID: animeID, + }); err != nil { + return fmt.Errorf("failed to clear continue entry: %w", err) + } + + if err := tx.Commit(); err != nil { + return fmt.Errorf("failed to commit save progress transaction: %w", err) + } + + return nil + } + + if err := txQueries.SaveWatchProgress(ctx, database.SaveWatchProgressParams{ + CurrentEpisode: sql.NullInt64{Int64: int64(episode), Valid: true}, + CurrentTimeSeconds: timeSeconds, + UserID: userID, + AnimeID: animeID, + }); err != nil && !errors.Is(err, sql.ErrNoRows) { + return fmt.Errorf("failed to save watchlist progress: %w", err) + } + + if _, err := txQueries.UpsertContinueWatchingEntry(ctx, database.UpsertContinueWatchingEntryParams{ + ID: uuid.New().String(), + UserID: userID, + AnimeID: animeID, + CurrentEpisode: sql.NullInt64{Int64: int64(episode), Valid: true}, + CurrentTimeSeconds: timeSeconds, + }); err != nil { + return fmt.Errorf("failed to upsert continue entry: %w", err) + } + + if err := tx.Commit(); err != nil { + return fmt.Errorf("failed to commit save progress transaction: %w", err) + } + + return nil +} + +func (s *Service) CompleteAnime(ctx context.Context, userID string, animeID int64, episode int, animeSeed *database.UpsertAnimeParams) error { + if strings.TrimSpace(userID) == "" || animeID <= 0 || episode <= 0 { + return errors.New("invalid complete anime input") + } + + txQueries, tx, err := s.beginTx(ctx) + if err != nil { + return err + } + + defer tx.Rollback() + + watchListEntry, watchListErr := txQueries.GetWatchListEntry(ctx, database.GetWatchListEntryParams{ + UserID: userID, + AnimeID: animeID, + }) + if watchListErr != nil && !errors.Is(watchListErr, sql.ErrNoRows) { + return fmt.Errorf("failed to load watchlist entry: %w", watchListErr) + } + + alreadyCompleted := watchListErr == nil && watchListEntry.Status == "completed" + + if !alreadyCompleted { + if animeSeed != nil { + if _, err := txQueries.UpsertAnime(ctx, *animeSeed); err != nil { + return fmt.Errorf("failed to save anime reference: %w", err) + } + } + + if _, err := txQueries.UpsertWatchListEntry(ctx, database.UpsertWatchListEntryParams{ + ID: uuid.New().String(), + UserID: userID, + AnimeID: animeID, + Status: "completed", + CurrentEpisode: sql.NullInt64{Int64: int64(episode), Valid: true}, + CurrentTimeSeconds: 0, + }); err != nil { + return fmt.Errorf("failed to mark watchlist as completed: %w", err) + } + } + + if err := txQueries.DeleteContinueWatchingEntry(ctx, database.DeleteContinueWatchingEntryParams{ + UserID: userID, + AnimeID: animeID, + }); err != nil { + return fmt.Errorf("failed to clear continue entry: %w", err) + } + + if err := txQueries.SaveWatchProgress(ctx, database.SaveWatchProgressParams{ + CurrentEpisode: sql.NullInt64{Int64: int64(episode), Valid: true}, + CurrentTimeSeconds: 0, + UserID: userID, + AnimeID: animeID, + }); err != nil && !errors.Is(err, sql.ErrNoRows) { + return fmt.Errorf("failed to reset watch progress: %w", err) + } + + if err := tx.Commit(); err != nil { + return fmt.Errorf("failed to commit complete anime transaction: %w", err) + } + + return nil +} + +func (s *Service) beginTx(ctx context.Context) (*database.Queries, *sql.Tx, error) { + if s.sqlDB == nil { + return nil, nil, errors.New("database unavailable") + } + + tx, err := s.sqlDB.BeginTx(ctx, nil) + if err != nil { + return nil, nil, fmt.Errorf("failed to begin transaction: %w", err) + } + + return database.New(tx), tx, nil +} diff --git a/internal/features/playback/proxy_security.go b/internal/features/playback/proxy_security.go new file mode 100644 index 0000000..a494e43 --- /dev/null +++ b/internal/features/playback/proxy_security.go @@ -0,0 +1,348 @@ +package playback + +import ( + "bufio" + "context" + "crypto/hmac" + "crypto/sha256" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "net" + "net/url" + "strings" + "time" +) + +const ( + proxyStreamTokenTTL = 2 * time.Hour + proxySegmentTokenTTL = 6 * time.Hour + proxySubtitleTokenTTL = 6 * time.Hour +) +const proxyHostCheckTTL = 2 * time.Minute + +type proxyScope string + +const ( + proxyScopeStream proxyScope = "stream" + proxyScopeSegment proxyScope = "segment" + proxyScopeSubtitle proxyScope = "subtitle" +) + +type proxyTokenPayload struct { + TargetURL string `json:"u"` + Referer string `json:"r,omitempty"` + Scope string `json:"s"` + ExpiresAt int64 `json:"exp"` +} + +type proxyTokenSigner struct { + secret []byte +} + +func newProxyTokenSigner(secret string) (*proxyTokenSigner, error) { + trimmed := strings.TrimSpace(secret) + if trimmed == "" { + return nil, errors.New("proxy token secret is required") + } + + if len(trimmed) < 32 { + return nil, errors.New("proxy token secret must be at least 32 characters") + } + + return &proxyTokenSigner{secret: []byte(trimmed)}, nil +} + +func (s *proxyTokenSigner) Sign(payload proxyTokenPayload) (string, error) { + body, err := json.Marshal(payload) + if err != nil { + return "", fmt.Errorf("marshal proxy token payload: %w", err) + } + + mac := hmac.New(sha256.New, s.secret) + mac.Write(body) + signature := mac.Sum(nil) + + encodedBody := base64.RawURLEncoding.EncodeToString(body) + encodedSignature := base64.RawURLEncoding.EncodeToString(signature) + return encodedBody + "." + encodedSignature, nil +} + +func (s *proxyTokenSigner) Verify(token string) (proxyTokenPayload, error) { + parts := strings.Split(token, ".") + if len(parts) != 2 { + return proxyTokenPayload{}, errors.New("invalid proxy token format") + } + + body, err := base64.RawURLEncoding.DecodeString(parts[0]) + if err != nil { + return proxyTokenPayload{}, errors.New("invalid proxy token payload") + } + + signature, err := base64.RawURLEncoding.DecodeString(parts[1]) + if err != nil { + return proxyTokenPayload{}, errors.New("invalid proxy token signature") + } + + mac := hmac.New(sha256.New, s.secret) + mac.Write(body) + expected := mac.Sum(nil) + if !hmac.Equal(signature, expected) { + return proxyTokenPayload{}, errors.New("invalid proxy token signature") + } + + var payload proxyTokenPayload + if err := json.Unmarshal(body, &payload); err != nil { + return proxyTokenPayload{}, errors.New("invalid proxy token payload") + } + + if payload.ExpiresAt <= time.Now().Unix() { + return proxyTokenPayload{}, errors.New("proxy token expired") + } + + return payload, nil +} + +func (s *Service) buildClientModeSources(modeSources map[string]ModeSource) (map[string]ModeSource, error) { + clientModeSources := make(map[string]ModeSource, len(modeSources)) + + for mode, source := range modeSources { + streamToken, err := s.issueProxyToken(source.URL, source.Referer, proxyScopeStream) + if err != nil { + return nil, err + } + + subtitles := make([]SubtitleItem, 0, len(source.Subtitles)) + for _, subtitle := range source.Subtitles { + targetURL := strings.TrimSpace(subtitle.URL) + if targetURL == "" { + continue + } + + token, err := s.issueProxyToken(targetURL, source.Referer, proxyScopeSubtitle) + if err != nil { + return nil, err + } + + subtitles = append(subtitles, SubtitleItem{ + Lang: subtitle.Lang, + Token: token, + }) + } + + clientModeSources[mode] = ModeSource{ + Token: streamToken, + Subtitles: subtitles, + } + } + + return clientModeSources, nil +} + +func (s *Service) issueProxyToken(targetURL string, referer string, scope proxyScope) (string, error) { + normalizedTarget, err := normalizeProxyURL(targetURL) + if err != nil { + return "", err + } + + normalizedReferer := "" + if strings.TrimSpace(referer) != "" { + refererURL, refererErr := normalizeProxyURL(referer) + if refererErr == nil { + normalizedReferer = refererURL + } + } + + return s.proxyTokens.Sign(proxyTokenPayload{ + TargetURL: normalizedTarget, + Referer: normalizedReferer, + Scope: string(scope), + ExpiresAt: time.Now().Add(proxyTokenTTL(scope)).Unix(), + }) +} + +func proxyTokenTTL(scope proxyScope) time.Duration { + switch scope { + case proxyScopeStream: + return proxyStreamTokenTTL + case proxyScopeSegment: + return proxySegmentTokenTTL + case proxyScopeSubtitle: + return proxySubtitleTokenTTL + default: + return proxyStreamTokenTTL + } +} + +func (s *Service) resolveProxyToken(ctx context.Context, token string, scope proxyScope) (string, string, error) { + payload, err := s.proxyTokens.Verify(token) + if err != nil { + return "", "", err + } + + if payload.Scope != string(scope) { + return "", "", errors.New("proxy token scope mismatch") + } + + normalizedTarget, err := normalizeProxyURL(payload.TargetURL) + if err != nil { + return "", "", err + } + + if err := s.ensurePublicProxyTarget(ctx, normalizedTarget); err != nil { + return "", "", err + } + + normalizedReferer := "" + if strings.TrimSpace(payload.Referer) != "" { + refererURL, refererErr := normalizeProxyURL(payload.Referer) + if refererErr == nil { + if ensureErr := s.ensurePublicProxyTarget(ctx, refererURL); ensureErr == nil { + normalizedReferer = refererURL + } + } + } + + return normalizedTarget, normalizedReferer, nil +} + +func normalizeProxyURL(rawURL string) (string, error) { + parsed, err := url.Parse(strings.TrimSpace(rawURL)) + if err != nil { + return "", errors.New("invalid proxy target") + } + + if parsed.Scheme != "http" && parsed.Scheme != "https" { + return "", errors.New("invalid proxy target scheme") + } + + host := strings.ToLower(strings.TrimSpace(parsed.Hostname())) + if host == "" { + return "", errors.New("invalid proxy target host") + } + + if host == "localhost" || strings.HasSuffix(host, ".localhost") || strings.HasSuffix(host, ".local") { + return "", errors.New("localhost targets are not allowed") + } + + ip := net.ParseIP(host) + if ip != nil && isBlockedProxyIP(ip) { + return "", errors.New("private proxy targets are not allowed") + } + + return parsed.String(), nil +} + +func isBlockedProxyIP(ip net.IP) bool { + return ip.IsLoopback() || + ip.IsPrivate() || + ip.IsMulticast() || + ip.IsLinkLocalMulticast() || + ip.IsLinkLocalUnicast() || + ip.IsUnspecified() +} + +func (s *Service) ensurePublicProxyTarget(ctx context.Context, rawURL string) error { + parsed, err := url.Parse(rawURL) + if err != nil { + return errors.New("invalid proxy target") + } + + host := strings.TrimSpace(parsed.Hostname()) + if host == "" { + return errors.New("invalid proxy target host") + } + + if ip := net.ParseIP(host); ip != nil { + if isBlockedProxyIP(ip) { + return errors.New("private proxy targets are not allowed") + } + return nil + } + + now := time.Now() + s.proxyHostMu.RLock() + cached, ok := s.proxyHostCache[host] + s.proxyHostMu.RUnlock() + if ok && now.Before(cached.ExpiresAt) { + if cached.Allowed { + return nil + } + return errors.New("private proxy targets are not allowed") + } + + resolvedIPs, err := net.DefaultResolver.LookupIPAddr(ctx, host) + if err != nil || len(resolvedIPs) == 0 { + return errors.New("proxy target lookup failed") + } + + allowed := true + for _, resolved := range resolvedIPs { + if isBlockedProxyIP(resolved.IP) { + allowed = false + break + } + } + + s.proxyHostMu.Lock() + s.proxyHostCache[host] = proxyHostCacheItem{ + Allowed: allowed, + ExpiresAt: now.Add(proxyHostCheckTTL), + } + s.proxyHostMu.Unlock() + + if !allowed { + return errors.New("private proxy targets are not allowed") + } + + return nil +} + +func (s *Service) rewritePlaylistWithTokens(ctx context.Context, content string, baseURL string, referer string) (string, error) { + base, err := url.Parse(baseURL) + if err != nil { + return "", err + } + + var out strings.Builder + scanner := bufio.NewScanner(strings.NewReader(content)) + for scanner.Scan() { + select { + case <-ctx.Done(): + return "", ctx.Err() + default: + } + + line := scanner.Text() + trimmed := strings.TrimSpace(line) + if trimmed == "" || strings.HasPrefix(trimmed, "#") { + out.WriteString(line) + out.WriteString("\n") + continue + } + + relativeURL, parseErr := url.Parse(trimmed) + if parseErr != nil { + out.WriteString(line) + out.WriteString("\n") + continue + } + + absoluteURL := base.ResolveReference(relativeURL).String() + token, tokenErr := s.issueProxyToken(absoluteURL, referer, proxyScopeSegment) + if tokenErr != nil { + return "", tokenErr + } + + proxied := "/watch/proxy/segment?token=" + url.QueryEscape(token) + out.WriteString(proxied) + out.WriteString("\n") + } + + if err := scanner.Err(); err != nil { + return "", err + } + + return out.String(), nil +} diff --git a/internal/features/playback/proxy_security_test.go b/internal/features/playback/proxy_security_test.go new file mode 100644 index 0000000..9c09ea7 --- /dev/null +++ b/internal/features/playback/proxy_security_test.go @@ -0,0 +1,45 @@ +package playback + +import ( + "context" + "testing" + + "mal/internal/database" +) + +func TestNormalizeProxyURLRejectsLocalhost(t *testing.T) { + t.Parallel() + + _, err := normalizeProxyURL("http://localhost:8080/private") + if err == nil { + t.Fatal("expected localhost URL to be rejected") + } +} + +func TestNormalizeProxyURLRejectsPrivateIP(t *testing.T) { + t.Parallel() + + _, err := normalizeProxyURL("http://192.168.1.10/stream") + if err == nil { + t.Fatal("expected private IP URL to be rejected") + } +} + +func TestProxyTokenScopeValidation(t *testing.T) { + t.Parallel() + + service := NewService(&fakeProxyQuerier{}, nil, Config{ProxyTokenSecret: "0123456789abcdef0123456789abcdef"}) + token, err := service.issueProxyToken("https://example.com/playlist.m3u8", "", proxyScopeStream) + if err != nil { + t.Fatalf("failed to issue token: %v", err) + } + + _, _, err = service.resolveProxyToken(context.Background(), token, proxyScopeSegment) + if err == nil { + t.Fatal("expected scope mismatch error") + } +} + +type fakeProxyQuerier struct { + database.Querier +} diff --git a/internal/features/playback/service.go b/internal/features/playback/service.go deleted file mode 100644 index cf0e06c..0000000 --- a/internal/features/playback/service.go +++ /dev/null @@ -1,1174 +0,0 @@ -package playback - -import ( - "bufio" - "bytes" - "context" - "encoding/json" - "errors" - "fmt" - "io" - "mal/internal/database" - "net/http" - "net/url" - "sort" - "strconv" - "strings" - "sync" - "time" -) - -const ( - showResolutionCacheTTL = 12 * time.Hour - playbackDataCacheTTL = 2 * time.Minute - providerProbeTimeout = 3 * time.Second -) - -type Service struct { - allAnimeClient *allAnimeClient - httpClient *http.Client - db database.Querier - - cacheMu sync.RWMutex - showResolution map[int]showResolutionCacheItem - playbackDataCache map[string]playbackDataCacheItem -} - -type sourceScore struct { - source StreamSource - total int - typeScore int - providerScore int - qualityScore int - refererScore int -} - -type showResolutionCacheItem struct { - ShowID string - Title string - ExpiresAt time.Time -} - -type playbackDataCacheItem struct { - Data playbackBaseData - ExpiresAt time.Time -} - -type playbackBaseData struct { - Title string - AvailableModes []string - ModeSources map[string]ModeSource - Segments []SkipSegment -} - -type modeSourceResult struct { - Mode string - Source ModeSource - OK bool -} - -type searchModeResult struct { - Mode string - Results []searchResult - Err error -} - -type directProbeResult struct { - Playable bool - ContentType string -} - -type userPlaybackState struct { - CurrentStatus string - StartTimeSeconds float64 -} - -func NewService(db database.Querier) *Service { - return &Service{ - allAnimeClient: newAllAnimeClient(), - httpClient: &http.Client{Timeout: 12 * time.Second}, - db: db, - showResolution: make(map[int]showResolutionCacheItem), - playbackDataCache: make(map[string]playbackDataCacheItem), - } -} - -func (s *Service) BuildWatchPageData(ctx context.Context, malID int, titleCandidates []string, episode string, mode string, userID string) (WatchPageData, error) { - if malID <= 0 { - return WatchPageData{}, errors.New("invalid mal id") - } - - normalizedMode := normalizeMode(mode) - if normalizedMode == "" { - normalizedMode = "dub" - } - - normalizedEpisode := strings.TrimSpace(episode) - if normalizedEpisode == "" { - normalizedEpisode = "1" - } - - userStateCh := s.fetchUserPlaybackStateAsync(ctx, userID, malID, normalizedEpisode) - - cacheKey := playbackDataCacheKey(malID, normalizedEpisode) - baseData, cacheHit := s.getPlaybackBaseDataCache(cacheKey) - if !cacheHit { - showID, resolvedTitle, err := s.resolveShowCached(ctx, malID, titleCandidates) - if err != nil { - return WatchPageData{}, err - } - - modeSources, segments := s.fetchPlaybackSourcesAndSegments(ctx, showID, malID, normalizedEpisode) - if len(modeSources) == 0 { - return WatchPageData{}, errors.New("no direct playable sources available") - } - - watchTitle := strings.TrimSpace(resolvedTitle) - if watchTitle == "" { - watchTitle = firstNonEmptyTitle(titleCandidates) - } - if watchTitle == "" { - watchTitle = fmt.Sprintf("MAL #%d", malID) - } - - baseData = playbackBaseData{ - Title: watchTitle, - AvailableModes: availableModes(modeSources), - ModeSources: modeSources, - Segments: segments, - } - - s.setPlaybackBaseDataCache(cacheKey, baseData) - } - - initialMode := selectInitialMode(normalizedMode, baseData.ModeSources) - - userState := userPlaybackState{} - if userStateCh != nil { - userState = <-userStateCh - } - - return WatchPageData{ - MalID: malID, - Title: baseData.Title, - CurrentEpisode: normalizedEpisode, - StartTimeSeconds: userState.StartTimeSeconds, - CurrentStatus: userState.CurrentStatus, - InitialMode: initialMode, - AvailableModes: cloneStringSlice(baseData.AvailableModes), - ModeSources: cloneModeSources(baseData.ModeSources), - Segments: cloneSegments(baseData.Segments), - }, nil -} - -func playbackDataCacheKey(malID int, episode string) string { - return fmt.Sprintf("%d:%s", malID, episode) -} - -func (s *Service) fetchUserPlaybackStateAsync(ctx context.Context, userID string, malID int, episode string) <-chan userPlaybackState { - if userID == "" || s.db == nil { - return nil - } - - resultCh := make(chan userPlaybackState, 1) - go func() { - state := userPlaybackState{} - - entry, err := s.db.GetWatchListEntry(ctx, database.GetWatchListEntryParams{ - UserID: userID, - AnimeID: int64(malID), - }) - if err == nil { - state.CurrentStatus = entry.Status - if entry.CurrentEpisode.Valid && strconv.FormatInt(entry.CurrentEpisode.Int64, 10) == episode && entry.CurrentTimeSeconds > 0 { - state.StartTimeSeconds = entry.CurrentTimeSeconds - } - } - - if state.StartTimeSeconds <= 0 { - continueEntry, continueErr := s.db.GetContinueWatchingEntry(ctx, database.GetContinueWatchingEntryParams{ - UserID: userID, - AnimeID: int64(malID), - }) - if continueErr == nil && continueEntry.CurrentEpisode.Valid && strconv.FormatInt(continueEntry.CurrentEpisode.Int64, 10) == episode && continueEntry.CurrentTimeSeconds > 0 { - state.StartTimeSeconds = continueEntry.CurrentTimeSeconds - } - } - - resultCh <- state - }() - - return resultCh -} - -func (s *Service) getPlaybackBaseDataCache(key string) (playbackBaseData, bool) { - now := time.Now() - - s.cacheMu.RLock() - item, ok := s.playbackDataCache[key] - s.cacheMu.RUnlock() - if !ok { - return playbackBaseData{}, false - } - - if now.After(item.ExpiresAt) { - s.cacheMu.Lock() - current, exists := s.playbackDataCache[key] - if exists && time.Now().After(current.ExpiresAt) { - delete(s.playbackDataCache, key) - } - s.cacheMu.Unlock() - return playbackBaseData{}, false - } - - return clonePlaybackBaseData(item.Data), true -} - -func (s *Service) setPlaybackBaseDataCache(key string, data playbackBaseData) { - s.cacheMu.Lock() - s.playbackDataCache[key] = playbackDataCacheItem{ - Data: clonePlaybackBaseData(data), - ExpiresAt: time.Now().Add(playbackDataCacheTTL), - } - s.cacheMu.Unlock() -} - -func (s *Service) resolveShowCached(ctx context.Context, malID int, titleCandidates []string) (string, string, error) { - now := time.Now() - - s.cacheMu.RLock() - item, ok := s.showResolution[malID] - s.cacheMu.RUnlock() - - if ok && now.Before(item.ExpiresAt) && strings.TrimSpace(item.ShowID) != "" { - return item.ShowID, item.Title, nil - } - - showID, resolvedTitle, err := s.resolveShow(ctx, malID, titleCandidates) - if err != nil { - return "", "", err - } - - s.cacheMu.Lock() - s.showResolution[malID] = showResolutionCacheItem{ - ShowID: showID, - Title: resolvedTitle, - ExpiresAt: time.Now().Add(showResolutionCacheTTL), - } - s.cacheMu.Unlock() - - return showID, resolvedTitle, nil -} - -func (s *Service) fetchPlaybackSourcesAndSegments(ctx context.Context, showID string, malID int, episode string) (map[string]ModeSource, []SkipSegment) { - modeCh := make(chan modeSourceResult, 2) - probeCache := make(map[string]directProbeResult) - probeCacheMu := sync.Mutex{} - - for _, mode := range []string{"dub", "sub"} { - modeValue := mode - go func() { - resolved, err := s.resolveModeSourceWithCache(ctx, showID, episode, modeValue, "best", probeCache, &probeCacheMu) - if err != nil { - modeCh <- modeSourceResult{Mode: modeValue, OK: false} - return - } - - if strings.ToLower(resolved.Type) == "embed" { - modeCh <- modeSourceResult{Mode: modeValue, OK: false} - return - } - - modeCh <- modeSourceResult{ - Mode: modeValue, - Source: ModeSource{ - URL: resolved.URL, - Referer: resolved.Referer, - Subtitles: toSubtitleItems(resolved), - }, - OK: true, - } - }() - } - - segmentsCh := make(chan []SkipSegment, 1) - go func() { - segmentsCh <- s.fetchSkipSegments(ctx, malID, episode) - }() - - modeSources := make(map[string]ModeSource) - for range 2 { - result := <-modeCh - if !result.OK { - continue - } - modeSources[result.Mode] = result.Source - } - - segments := <-segmentsCh - return modeSources, segments -} - -func clonePlaybackBaseData(data playbackBaseData) playbackBaseData { - return playbackBaseData{ - Title: data.Title, - AvailableModes: cloneStringSlice(data.AvailableModes), - ModeSources: cloneModeSources(data.ModeSources), - Segments: cloneSegments(data.Segments), - } -} - -func cloneStringSlice(items []string) []string { - if len(items) == 0 { - return nil - } - - cloned := make([]string, len(items)) - copy(cloned, items) - return cloned -} - -func cloneModeSources(modeSources map[string]ModeSource) map[string]ModeSource { - if len(modeSources) == 0 { - return nil - } - - cloned := make(map[string]ModeSource, len(modeSources)) - for mode, source := range modeSources { - cloned[mode] = ModeSource{ - URL: source.URL, - Referer: source.Referer, - Subtitles: cloneSubtitleItems(source.Subtitles), - } - } - - return cloned -} - -func cloneSubtitleItems(items []SubtitleItem) []SubtitleItem { - if len(items) == 0 { - return nil - } - - cloned := make([]SubtitleItem, len(items)) - copy(cloned, items) - return cloned -} - -func cloneSegments(segments []SkipSegment) []SkipSegment { - if len(segments) == 0 { - return nil - } - - cloned := make([]SkipSegment, len(segments)) - copy(cloned, segments) - return cloned -} - -func (s *Service) resolveShow(ctx context.Context, malID int, titleCandidates []string) (string, string, error) { - malText := strconv.Itoa(malID) - modeCandidates := []string{"sub", "dub"} - queries := buildTitleSearchQueries(titleCandidates) - - for _, query := range queries { - resultsByMode := s.searchShowResultsByMode(ctx, query, modeCandidates) - - for _, mode := range modeCandidates { - for _, result := range resultsByMode[mode] { - if strings.TrimSpace(result.MalID) == malText && strings.TrimSpace(result.ID) != "" { - return result.ID, result.Name, nil - } - } - } - - for _, mode := range modeCandidates { - results := resultsByMode[mode] - if len(results) == 0 { - continue - } - - best := results[0] - if strings.TrimSpace(best.ID) != "" { - return best.ID, best.Name, nil - } - } - } - - return "", "", errors.New("unable to resolve allanime show") -} - -func (s *Service) searchShowResultsByMode(ctx context.Context, query string, modeCandidates []string) map[string][]searchResult { - resultsByMode := make(map[string][]searchResult, len(modeCandidates)) - searchCh := make(chan searchModeResult, len(modeCandidates)) - - var wg sync.WaitGroup - for _, mode := range modeCandidates { - modeValue := mode - wg.Add(1) - go func() { - defer wg.Done() - results, err := s.allAnimeClient.Search(ctx, query, modeValue) - searchCh <- searchModeResult{Mode: modeValue, Results: results, Err: err} - }() - } - - wg.Wait() - close(searchCh) - - for result := range searchCh { - if result.Err != nil { - continue - } - - resultsByMode[result.Mode] = result.Results - } - - return resultsByMode -} - -func buildTitleSearchQueries(titleCandidates []string) []string { - queries := make([]string, 0, len(titleCandidates)*4) - seen := make(map[string]struct{}) - - add := func(raw string) { - normalized := normalizeSearchQuery(raw) - if normalized == "" { - return - } - - key := strings.ToLower(normalized) - if _, exists := seen[key]; exists { - return - } - - seen[key] = struct{}{} - queries = append(queries, normalized) - } - - for _, candidate := range titleCandidates { - normalized := normalizeSearchQuery(candidate) - if normalized == "" { - continue - } - - add(normalized) - add(strings.ReplaceAll(normalized, "+", " ")) - - withoutApostrophes := strings.NewReplacer("'", "", "’", "", "`", "").Replace(normalized) - add(withoutApostrophes) - add(strings.ReplaceAll(withoutApostrophes, "+", " ")) - } - - return queries -} - -func normalizeSearchQuery(raw string) string { - return strings.Join(strings.Fields(strings.TrimSpace(raw)), " ") -} - -func firstNonEmptyTitle(values []string) string { - for _, value := range values { - normalized := strings.TrimSpace(value) - if normalized != "" { - return normalized - } - } - - return "" -} - -func (s *Service) resolveModeSource(ctx context.Context, showID string, episode string, mode string, quality string) (StreamSource, error) { - sources, err := s.allAnimeClient.GetEpisodeSources(ctx, showID, episode, mode) - if err != nil { - return StreamSource{}, err - } - - ranked, err := rankSources(sources, quality) - if err != nil { - return StreamSource{}, err - } - - selected, _, err := s.choosePlaybackSource(ctx, ranked) - if err != nil { - return StreamSource{}, err - } - - return selected, nil -} - -func (s *Service) resolveModeSourceWithCache( - ctx context.Context, - showID string, - episode string, - mode string, - quality string, - probeCache map[string]directProbeResult, - probeCacheMu *sync.Mutex, -) (StreamSource, error) { - sources, err := s.allAnimeClient.GetEpisodeSources(ctx, showID, episode, mode) - if err != nil { - return StreamSource{}, err - } - - ranked, err := rankSources(sources, quality) - if err != nil { - return StreamSource{}, err - } - - selected, _, err := s.choosePlaybackSourceWithCache(ctx, ranked, probeCache, probeCacheMu) - if err != nil { - return StreamSource{}, err - } - - return selected, nil -} - -func (s *Service) choosePlaybackSource(ctx context.Context, ranked []sourceScore) (StreamSource, string, error) { - if len(ranked) == 0 { - return StreamSource{}, "", errors.New("no ranked sources available") - } - - embedCandidates := make([]StreamSource, 0) - for _, candidate := range ranked { - source := candidate.source - sourceType := strings.ToLower(source.Type) - - switch sourceType { - case "mp4", "m3u8": - return source, "direct-media", nil - case "embed": - embedCandidates = append(embedCandidates, source) - default: - playable, contentType := s.probeDirectMedia(ctx, source) - if playable { - return normalizeSourceTypeFromProbe(source, contentType), "probed-media", nil - } - } - } - - for _, embed := range embedCandidates { - if s.probeEmbedSource(ctx, embed) { - return embed, "embed-probed", nil - } - } - - if len(embedCandidates) > 0 { - return embedCandidates[0], "embed-fallback", nil - } - - return ranked[0].source, "ranked-fallback", nil -} - -func (s *Service) choosePlaybackSourceWithCache( - ctx context.Context, - ranked []sourceScore, - probeCache map[string]directProbeResult, - probeCacheMu *sync.Mutex, -) (StreamSource, string, error) { - if len(ranked) == 0 { - return StreamSource{}, "", errors.New("no ranked sources available") - } - - embedCandidates := make([]StreamSource, 0) - for _, candidate := range ranked { - source := candidate.source - sourceType := strings.ToLower(source.Type) - - switch sourceType { - case "mp4", "m3u8": - return source, "direct-media", nil - case "embed": - embedCandidates = append(embedCandidates, source) - default: - playable, contentType := s.probeDirectMediaCached(ctx, source, probeCache, probeCacheMu) - if playable { - return normalizeSourceTypeFromProbe(source, contentType), "probed-media", nil - } - } - } - - for _, embed := range embedCandidates { - if s.probeEmbedSource(ctx, embed) { - return embed, "embed-probed", nil - } - } - - if len(embedCandidates) > 0 { - return embedCandidates[0], "embed-fallback", nil - } - - return ranked[0].source, "ranked-fallback", nil -} - -func (s *Service) probeDirectMediaCached( - ctx context.Context, - source StreamSource, - probeCache map[string]directProbeResult, - probeCacheMu *sync.Mutex, -) (bool, string) { - cacheKey := strings.TrimSpace(source.URL) - if cacheKey == "" { - return s.probeDirectMedia(ctx, source) - } - - probeCacheMu.Lock() - cached, ok := probeCache[cacheKey] - probeCacheMu.Unlock() - if ok { - return cached.Playable, cached.ContentType - } - - playable, contentType := s.probeDirectMedia(ctx, source) - - probeCacheMu.Lock() - probeCache[cacheKey] = directProbeResult{Playable: playable, ContentType: contentType} - probeCacheMu.Unlock() - - return playable, contentType -} - -func (s *Service) probeDirectMedia(ctx context.Context, source StreamSource) (bool, string) { - probeCtx, cancel := context.WithTimeout(ctx, providerProbeTimeout) - defer cancel() - - req, err := http.NewRequestWithContext(probeCtx, http.MethodGet, source.URL, nil) - if err != nil { - return false, "" - } - - if source.Referer != "" { - req.Header.Set("Referer", source.Referer) - } - req.Header.Set("User-Agent", defaultUserAgent) - req.Header.Set("Range", "bytes=0-4095") - - resp, err := s.httpClient.Do(req) - if err != nil { - return false, "" - } - defer resp.Body.Close() - - contentType := strings.ToLower(resp.Header.Get("Content-Type")) - if strings.Contains(contentType, "video/") || strings.Contains(contentType, "mpegurl") { - return true, contentType - } - - prefix, err := io.ReadAll(io.LimitReader(resp.Body, 4096)) - if err == nil { - if isLikelyM3U8(prefix) { - return true, "application/vnd.apple.mpegurl" - } - if isLikelyMP4(prefix) { - return true, "video/mp4" - } - } - - finalURL := "" - if resp.Request != nil && resp.Request.URL != nil { - finalURL = strings.ToLower(resp.Request.URL.String()) - } - - if strings.Contains(finalURL, ".mp4") || strings.Contains(finalURL, ".m3u8") { - return true, contentType - } - - return false, contentType -} - -func (s *Service) probeEmbedSource(ctx context.Context, source StreamSource) bool { - probeCtx, cancel := context.WithTimeout(ctx, providerProbeTimeout) - defer cancel() - - req, err := http.NewRequestWithContext(probeCtx, http.MethodGet, source.URL, nil) - if err != nil { - return false - } - - if source.Referer != "" { - req.Header.Set("Referer", source.Referer) - } - req.Header.Set("User-Agent", defaultUserAgent) - - resp, err := s.httpClient.Do(req) - if err != nil { - return false - } - defer resp.Body.Close() - - if resp.StatusCode >= http.StatusBadRequest { - return false - } - - body, err := io.ReadAll(io.LimitReader(resp.Body, 512*1024)) - if err != nil { - return false - } - - content := strings.ToLower(string(body)) - markers := []string{ - "file was deleted", - "file has been deleted", - "video was deleted", - "video has been deleted", - "video unavailable", - "file not found", - "this file does not exist", - "resource unavailable", - } - for _, marker := range markers { - if strings.Contains(content, marker) { - return false - } - } - - return true -} - -func (s *Service) fetchSkipSegments(ctx context.Context, malID int, episode string) []SkipSegment { - if malID <= 0 || strings.TrimSpace(episode) == "" { - return nil - } - - endpoint := fmt.Sprintf("https://api.aniskip.com/v1/skip-times/%s/%s?types=op&types=ed", url.PathEscape(strconv.Itoa(malID)), url.PathEscape(episode)) - req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil) - if err != nil { - return nil - } - req.Header.Set("User-Agent", defaultUserAgent) - - resp, err := s.httpClient.Do(req) - if err != nil { - return nil - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - return nil - } - - body, err := io.ReadAll(io.LimitReader(resp.Body, 512*1024)) - if err != nil { - return nil - } - - type resultItem struct { - SkipType string `json:"skip_type"` - Interval struct { - StartTime float64 `json:"start_time"` - EndTime float64 `json:"end_time"` - } `json:"interval"` - } - type apiResponse struct { - Found bool `json:"found"` - Result []resultItem `json:"results"` - } - - var parsed apiResponse - if err := json.Unmarshal(body, &parsed); err != nil { - return nil - } - - segments := make([]SkipSegment, 0, len(parsed.Result)) - for _, item := range parsed.Result { - if item.Interval.EndTime <= item.Interval.StartTime { - continue - } - - t := strings.ToLower(item.SkipType) - if t != "op" && t != "ed" { - continue - } - - segments = append(segments, SkipSegment{ - Type: t, - Start: item.Interval.StartTime, - End: item.Interval.EndTime, - }) - } - - return segments -} - -func (s *Service) ProxyStream(ctx context.Context, targetURL string, referer string, rangeHeader string) (int, http.Header, []byte, io.ReadCloser, error) { - req, err := http.NewRequestWithContext(ctx, http.MethodGet, targetURL, nil) - if err != nil { - return 0, nil, nil, nil, fmt.Errorf("invalid upstream url: %w", err) - } - - if referer != "" { - req.Header.Set("Referer", referer) - } - req.Header.Set("User-Agent", defaultUserAgent) - if rangeHeader != "" { - req.Header.Set("Range", rangeHeader) - } - - resp, err := s.httpClient.Do(req) - if err != nil { - return 0, nil, nil, nil, fmt.Errorf("upstream request failed: %w", err) - } - - if isM3U8(targetURL, resp.Header.Get("Content-Type")) { - defer resp.Body.Close() - body, readErr := io.ReadAll(io.LimitReader(resp.Body, 2*1024*1024)) - if readErr != nil { - return 0, nil, nil, nil, fmt.Errorf("read playlist failed: %w", readErr) - } - - rewritten, rewriteErr := rewritePlaylist(string(body), targetURL, referer) - if rewriteErr != nil { - return 0, nil, nil, nil, fmt.Errorf("rewrite playlist failed: %w", rewriteErr) - } - - headers := cloneHeaders(resp.Header) - headers.Set("Content-Type", "application/vnd.apple.mpegurl") - return resp.StatusCode, headers, []byte(rewritten), nil, nil - } - - headers := cloneHeaders(resp.Header) - return resp.StatusCode, headers, nil, resp.Body, nil -} - -func normalizeMode(raw string) string { - lower := strings.ToLower(strings.TrimSpace(raw)) - if lower == "sub" || lower == "dub" { - return lower - } - - return lower -} - -func toSubtitleItems(source StreamSource) []SubtitleItem { - items := make([]SubtitleItem, 0, len(source.Subtitles)) - for _, subtitle := range source.Subtitles { - targetURL := strings.TrimSpace(subtitle.URL) - if targetURL == "" { - continue - } - - items = append(items, SubtitleItem{ - Lang: strings.TrimSpace(subtitle.Lang), - URL: targetURL, - Referer: source.Referer, - }) - } - - return items -} - -func availableModes(modeSources map[string]ModeSource) []string { - ordered := make([]string, 0, len(modeSources)) - if _, ok := modeSources["dub"]; ok { - ordered = append(ordered, "dub") - } - if _, ok := modeSources["sub"]; ok { - ordered = append(ordered, "sub") - } - - extra := make([]string, 0) - for mode := range modeSources { - if mode == "dub" || mode == "sub" { - continue - } - extra = append(extra, mode) - } - sort.Strings(extra) - - return append(ordered, extra...) -} - -func selectInitialMode(requestedMode string, modeSources map[string]ModeSource) string { - normalizedRequested := normalizeMode(requestedMode) - if normalizedRequested != "" { - if _, ok := modeSources[normalizedRequested]; ok { - return normalizedRequested - } - } - - if _, ok := modeSources["dub"]; ok { - return "dub" - } - if _, ok := modeSources["sub"]; ok { - return "sub" - } - - for mode := range modeSources { - return mode - } - - return "dub" -} - -func rankSources(sources []StreamSource, quality string) ([]sourceScore, error) { - filtered := make([]StreamSource, 0, len(sources)) - seen := make(map[string]struct{}) - - for _, source := range sources { - if source.URL == "" { - continue - } - if _, exists := seen[source.URL]; exists { - continue - } - seen[source.URL] = struct{}{} - filtered = append(filtered, source) - } - - if len(filtered) == 0 { - return nil, errors.New("no playable sources available") - } - - targetQuality := normalizeQuality(quality) - scored := make([]sourceScore, 0, len(filtered)) - for _, source := range filtered { - typeScore := sourceTypePriority(source.Type) - providerScore := providerPriority(source.Provider) - qualityScore := sourceQualityPriority(source.Quality, targetQuality) - refererScore := 0 - if source.Referer != "" { - refererScore = 20 - } - - total := typeScore + providerScore + qualityScore + refererScore - scored = append(scored, sourceScore{ - source: source, - total: total, - typeScore: typeScore, - providerScore: providerScore, - qualityScore: qualityScore, - refererScore: refererScore, - }) - } - - sort.SliceStable(scored, func(i int, j int) bool { - return scored[i].total > scored[j].total - }) - - return scored, nil -} - -func normalizeQuality(quality string) string { - lower := strings.ToLower(strings.TrimSpace(quality)) - if lower == "" { - return "best" - } - - return lower -} - -func sourceTypePriority(sourceType string) int { - switch strings.ToLower(sourceType) { - case "mp4": - return 500 - case "m3u8": - return 450 - case "unknown": - return 300 - case "embed": - return 100 - default: - return 200 - } -} - -func providerPriority(provider string) int { - switch strings.ToLower(provider) { - case "s-mp4": - return 120 - case "default": - return 115 - case "luf-mp4": - return 110 - case "vid-mp4": - return 105 - case "yt-mp4": - return 100 - case "mp4": - return 95 - case "uv-mp4": - return 90 - case "hls": - return 80 - case "sw": - return 40 - case "ok": - return 35 - case "ss-hls": - return 30 - default: - return 60 - } -} - -func sourceQualityPriority(sourceQuality string, targetQuality string) int { - qualityValue := parseQualityValue(sourceQuality) - - switch targetQuality { - case "best": - return qualityValue - case "worst": - return -qualityValue - default: - if qualityMatches(sourceQuality, targetQuality) { - return 2000 + qualityValue - } - - return -300 + qualityValue - } -} - -func parseQualityValue(rawQuality string) int { - lower := strings.ToLower(rawQuality) - var digits strings.Builder - - for _, char := range lower { - if char >= '0' && char <= '9' { - digits.WriteRune(char) - continue - } - if digits.Len() > 0 { - break - } - } - - if digits.Len() > 0 { - value, err := strconv.Atoi(digits.String()) - if err == nil { - return value - } - } - - if lower == "auto" { - return 240 - } - - return 0 -} - -func qualityMatches(sourceQuality string, targetQuality string) bool { - sourceLower := strings.ToLower(sourceQuality) - targetLower := strings.ToLower(targetQuality) - - if sourceLower == "" { - return false - } - - if strings.Contains(sourceLower, targetLower) { - return true - } - - sourceDigits := extractDigits(sourceLower) - targetDigits := extractDigits(targetLower) - - return sourceDigits != "" && sourceDigits == targetDigits -} - -func extractDigits(value string) string { - var digits strings.Builder - for _, char := range value { - if char >= '0' && char <= '9' { - digits.WriteRune(char) - continue - } - if digits.Len() > 0 { - break - } - } - - return digits.String() -} - -func normalizeSourceTypeFromProbe(source StreamSource, contentType string) StreamSource { - lower := strings.ToLower(contentType) - if strings.Contains(lower, "video/mp4") { - source.Type = "mp4" - return source - } - - if strings.Contains(lower, "mpegurl") { - source.Type = "m3u8" - return source - } - - return source -} - -func isLikelyMP4(payload []byte) bool { - if len(payload) < 12 { - return false - } - - return bytes.Equal(payload[4:8], []byte("ftyp")) -} - -func isLikelyM3U8(payload []byte) bool { - trimmed := strings.TrimSpace(string(payload)) - return strings.HasPrefix(trimmed, "#EXTM3U") -} - -func isM3U8(targetURL string, contentType string) bool { - lowerURL := strings.ToLower(targetURL) - lowerType := strings.ToLower(contentType) - if strings.Contains(lowerURL, ".m3u8") { - return true - } - - return strings.Contains(lowerType, "application/vnd.apple.mpegurl") || strings.Contains(lowerType, "application/x-mpegurl") -} - -func rewritePlaylist(content string, baseURL string, referer string) (string, error) { - base, err := url.Parse(baseURL) - if err != nil { - return "", err - } - - var out strings.Builder - scanner := bufio.NewScanner(strings.NewReader(content)) - for scanner.Scan() { - line := scanner.Text() - trimmed := strings.TrimSpace(line) - if trimmed == "" || strings.HasPrefix(trimmed, "#") { - out.WriteString(line) - out.WriteString("\n") - continue - } - - relativeURL, parseErr := url.Parse(trimmed) - if parseErr != nil { - out.WriteString(line) - out.WriteString("\n") - continue - } - - absolute := base.ResolveReference(relativeURL).String() - proxied := "/watch/proxy/segment?u=" + url.QueryEscape(absolute) - if referer != "" { - proxied += "&r=" + url.QueryEscape(referer) - } - - out.WriteString(proxied) - out.WriteString("\n") - } - - if err := scanner.Err(); err != nil { - return "", err - } - - return out.String(), nil -} - -func cloneHeaders(src http.Header) http.Header { - dst := make(http.Header) - for key, values := range src { - lower := strings.ToLower(key) - if lower == "connection" || lower == "transfer-encoding" || lower == "keep-alive" || lower == "proxy-authenticate" || lower == "proxy-authorization" || lower == "te" || lower == "trailers" || lower == "upgrade" { - continue - } - - for _, value := range values { - dst.Add(key, value) - } - } - - return dst -} diff --git a/internal/features/playback/service_base.go b/internal/features/playback/service_base.go new file mode 100644 index 0000000..8217c4f --- /dev/null +++ b/internal/features/playback/service_base.go @@ -0,0 +1,391 @@ +package playback + +import ( + "context" + "database/sql" + "errors" + "fmt" + "mal/internal/database" + "net/http" + "strconv" + "strings" + "sync" + "time" +) + +const ( + showResolutionCacheTTL = 12 * time.Hour + playbackDataCacheTTL = 2 * time.Minute + providerProbeTimeout = 3 * time.Second +) + +type Service struct { + allAnimeClient *allAnimeClient + httpClient *http.Client + sqlDB *sql.DB + db database.Querier + proxyTokens *proxyTokenSigner + proxyHostMu sync.RWMutex + proxyHostCache map[string]proxyHostCacheItem + + cacheMu sync.RWMutex + showResolution map[int]showResolutionCacheItem + playbackDataCache map[string]playbackDataCacheItem +} + +type Config struct { + ProxyTokenSecret string +} + +type sourceScore struct { + source StreamSource + total int + typeScore int + providerScore int + qualityScore int + refererScore int +} + +type showResolutionCacheItem struct { + ShowID string + Title string + ExpiresAt time.Time +} + +type playbackDataCacheItem struct { + Data playbackBaseData + ExpiresAt time.Time +} + +type playbackBaseData struct { + Title string + AvailableModes []string + ModeSources map[string]ModeSource + Segments []SkipSegment +} + +type modeSourceResult struct { + Mode string + Source ModeSource + OK bool +} + +type searchModeResult struct { + Mode string + Results []searchResult + Err error +} + +type directProbeResult struct { + Playable bool + ContentType string +} + +type proxyHostCacheItem struct { + Allowed bool + ExpiresAt time.Time +} + +type userPlaybackState struct { + CurrentStatus string + StartTimeSeconds float64 +} + +func NewService(db database.Querier, sqlDB *sql.DB, cfg Config) *Service { + proxyTokens, err := newProxyTokenSigner(cfg.ProxyTokenSecret) + if err != nil { + panic(fmt.Sprintf("failed to initialize proxy token signer: %v", err)) + } + + return &Service{ + allAnimeClient: newAllAnimeClient(), + httpClient: &http.Client{Timeout: 12 * time.Second}, + sqlDB: sqlDB, + db: db, + proxyTokens: proxyTokens, + proxyHostCache: make(map[string]proxyHostCacheItem), + showResolution: make(map[int]showResolutionCacheItem), + playbackDataCache: make(map[string]playbackDataCacheItem), + } +} + +func (s *Service) BuildWatchPageData(ctx context.Context, malID int, titleCandidates []string, episode string, mode string, userID string) (WatchPageData, error) { + if malID <= 0 { + return WatchPageData{}, errors.New("invalid mal id") + } + + normalizedMode := normalizeMode(mode) + if normalizedMode == "" { + normalizedMode = "dub" + } + + normalizedEpisode := strings.TrimSpace(episode) + if normalizedEpisode == "" { + normalizedEpisode = "1" + } + + userStateCh := s.fetchUserPlaybackStateAsync(ctx, userID, malID, normalizedEpisode) + + cacheKey := playbackDataCacheKey(malID, normalizedEpisode) + baseData, cacheHit := s.getPlaybackBaseDataCache(cacheKey) + if !cacheHit { + showID, resolvedTitle, err := s.resolveShowCached(ctx, malID, titleCandidates) + if err != nil { + return WatchPageData{}, err + } + + modeSources, segments := s.fetchPlaybackSourcesAndSegments(ctx, showID, malID, normalizedEpisode) + if len(modeSources) == 0 { + return WatchPageData{}, errors.New("no direct playable sources available") + } + + watchTitle := strings.TrimSpace(resolvedTitle) + if watchTitle == "" { + watchTitle = firstNonEmptyTitle(titleCandidates) + } + if watchTitle == "" { + watchTitle = fmt.Sprintf("MAL #%d", malID) + } + + baseData = playbackBaseData{ + Title: watchTitle, + AvailableModes: availableModes(modeSources), + ModeSources: modeSources, + Segments: segments, + } + + s.setPlaybackBaseDataCache(cacheKey, baseData) + } + + initialMode := selectInitialMode(normalizedMode, baseData.ModeSources) + + clientModeSources, err := s.buildClientModeSources(baseData.ModeSources) + if err != nil { + return WatchPageData{}, err + } + + if _, ok := clientModeSources[initialMode]; !ok { + return WatchPageData{}, errors.New("stream mode unavailable") + } + + userState := userPlaybackState{} + if userStateCh != nil { + userState = <-userStateCh + } + + return WatchPageData{ + MalID: malID, + Title: baseData.Title, + CurrentEpisode: normalizedEpisode, + StartTimeSeconds: userState.StartTimeSeconds, + CurrentStatus: userState.CurrentStatus, + InitialMode: initialMode, + AvailableModes: cloneStringSlice(baseData.AvailableModes), + ModeSources: clientModeSources, + Segments: cloneSegments(baseData.Segments), + }, nil +} + +func playbackDataCacheKey(malID int, episode string) string { + return fmt.Sprintf("%d:%s", malID, episode) +} + +func (s *Service) fetchUserPlaybackStateAsync(ctx context.Context, userID string, malID int, episode string) <-chan userPlaybackState { + if userID == "" || s.db == nil { + return nil + } + + resultCh := make(chan userPlaybackState, 1) + go func() { + state := userPlaybackState{} + + entry, err := s.db.GetWatchListEntry(ctx, database.GetWatchListEntryParams{ + UserID: userID, + AnimeID: int64(malID), + }) + if err == nil { + state.CurrentStatus = entry.Status + if entry.CurrentEpisode.Valid && strconv.FormatInt(entry.CurrentEpisode.Int64, 10) == episode && entry.CurrentTimeSeconds > 0 { + state.StartTimeSeconds = entry.CurrentTimeSeconds + } + } + + if state.StartTimeSeconds <= 0 { + continueEntry, continueErr := s.db.GetContinueWatchingEntry(ctx, database.GetContinueWatchingEntryParams{ + UserID: userID, + AnimeID: int64(malID), + }) + if continueErr == nil && continueEntry.CurrentEpisode.Valid && strconv.FormatInt(continueEntry.CurrentEpisode.Int64, 10) == episode && continueEntry.CurrentTimeSeconds > 0 { + state.StartTimeSeconds = continueEntry.CurrentTimeSeconds + } + } + + resultCh <- state + }() + + return resultCh +} + +func (s *Service) getPlaybackBaseDataCache(key string) (playbackBaseData, bool) { + now := time.Now() + + s.cacheMu.RLock() + item, ok := s.playbackDataCache[key] + s.cacheMu.RUnlock() + if !ok { + return playbackBaseData{}, false + } + + if now.After(item.ExpiresAt) { + s.cacheMu.Lock() + current, exists := s.playbackDataCache[key] + if exists && time.Now().After(current.ExpiresAt) { + delete(s.playbackDataCache, key) + } + s.cacheMu.Unlock() + return playbackBaseData{}, false + } + + return clonePlaybackBaseData(item.Data), true +} + +func (s *Service) setPlaybackBaseDataCache(key string, data playbackBaseData) { + s.cacheMu.Lock() + s.playbackDataCache[key] = playbackDataCacheItem{ + Data: clonePlaybackBaseData(data), + ExpiresAt: time.Now().Add(playbackDataCacheTTL), + } + s.cacheMu.Unlock() +} + +func (s *Service) resolveShowCached(ctx context.Context, malID int, titleCandidates []string) (string, string, error) { + now := time.Now() + + s.cacheMu.RLock() + item, ok := s.showResolution[malID] + s.cacheMu.RUnlock() + + if ok && now.Before(item.ExpiresAt) && strings.TrimSpace(item.ShowID) != "" { + return item.ShowID, item.Title, nil + } + + showID, resolvedTitle, err := s.resolveShow(ctx, malID, titleCandidates) + if err != nil { + return "", "", err + } + + s.cacheMu.Lock() + s.showResolution[malID] = showResolutionCacheItem{ + ShowID: showID, + Title: resolvedTitle, + ExpiresAt: time.Now().Add(showResolutionCacheTTL), + } + s.cacheMu.Unlock() + + return showID, resolvedTitle, nil +} + +func (s *Service) fetchPlaybackSourcesAndSegments(ctx context.Context, showID string, malID int, episode string) (map[string]ModeSource, []SkipSegment) { + modeCh := make(chan modeSourceResult, 2) + probeCache := make(map[string]directProbeResult) + probeCacheMu := sync.Mutex{} + + for _, mode := range []string{"dub", "sub"} { + modeValue := mode + go func() { + resolved, err := s.resolveModeSourceWithCache(ctx, showID, episode, modeValue, "best", probeCache, &probeCacheMu) + if err != nil { + modeCh <- modeSourceResult{Mode: modeValue, OK: false} + return + } + + if strings.ToLower(resolved.Type) == "embed" { + modeCh <- modeSourceResult{Mode: modeValue, OK: false} + return + } + + modeCh <- modeSourceResult{ + Mode: modeValue, + Source: ModeSource{ + URL: resolved.URL, + Referer: resolved.Referer, + Subtitles: toSubtitleItems(resolved), + }, + OK: true, + } + }() + } + + segmentsCh := make(chan []SkipSegment, 1) + go func() { + segmentsCh <- s.fetchSkipSegments(ctx, malID, episode) + }() + + modeSources := make(map[string]ModeSource) + for range 2 { + result := <-modeCh + if !result.OK { + continue + } + modeSources[result.Mode] = result.Source + } + + segments := <-segmentsCh + return modeSources, segments +} + +func clonePlaybackBaseData(data playbackBaseData) playbackBaseData { + return playbackBaseData{ + Title: data.Title, + AvailableModes: cloneStringSlice(data.AvailableModes), + ModeSources: cloneModeSources(data.ModeSources), + Segments: cloneSegments(data.Segments), + } +} + +func cloneStringSlice(items []string) []string { + if len(items) == 0 { + return nil + } + + cloned := make([]string, len(items)) + copy(cloned, items) + return cloned +} + +func cloneModeSources(modeSources map[string]ModeSource) map[string]ModeSource { + if len(modeSources) == 0 { + return nil + } + + cloned := make(map[string]ModeSource, len(modeSources)) + for mode, source := range modeSources { + cloned[mode] = ModeSource{ + URL: source.URL, + Referer: source.Referer, + Subtitles: cloneSubtitleItems(source.Subtitles), + } + } + + return cloned +} + +func cloneSubtitleItems(items []SubtitleItem) []SubtitleItem { + if len(items) == 0 { + return nil + } + + cloned := make([]SubtitleItem, len(items)) + copy(cloned, items) + return cloned +} + +func cloneSegments(segments []SkipSegment) []SkipSegment { + if len(segments) == 0 { + return nil + } + + cloned := make([]SkipSegment, len(segments)) + copy(cloned, segments) + return cloned +} diff --git a/internal/features/playback/service_http.go b/internal/features/playback/service_http.go new file mode 100644 index 0000000..7515a0e --- /dev/null +++ b/internal/features/playback/service_http.go @@ -0,0 +1,77 @@ +package playback + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "strconv" + "strings" +) + +func (s *Service) fetchSkipSegments(ctx context.Context, malID int, episode string) []SkipSegment { + if malID <= 0 || strings.TrimSpace(episode) == "" { + return nil + } + + endpoint := fmt.Sprintf("https://api.aniskip.com/v1/skip-times/%s/%s?types=op&types=ed", url.PathEscape(strconv.Itoa(malID)), url.PathEscape(episode)) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil) + if err != nil { + return nil + } + req.Header.Set("User-Agent", defaultUserAgent) + + resp, err := s.httpClient.Do(req) + if err != nil { + return nil + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil + } + + body, err := io.ReadAll(io.LimitReader(resp.Body, 512*1024)) + if err != nil { + return nil + } + + type resultItem struct { + SkipType string `json:"skip_type"` + Interval struct { + StartTime float64 `json:"start_time"` + EndTime float64 `json:"end_time"` + } `json:"interval"` + } + type apiResponse struct { + Found bool `json:"found"` + Result []resultItem `json:"results"` + } + + var parsed apiResponse + if err := json.Unmarshal(body, &parsed); err != nil { + return nil + } + + segments := make([]SkipSegment, 0, len(parsed.Result)) + for _, item := range parsed.Result { + if item.Interval.EndTime <= item.Interval.StartTime { + continue + } + + t := strings.ToLower(item.SkipType) + if t != "op" && t != "ed" { + continue + } + + segments = append(segments, SkipSegment{ + Type: t, + Start: item.Interval.StartTime, + End: item.Interval.EndTime, + }) + } + + return segments +} diff --git a/internal/features/playback/service_proxy.go b/internal/features/playback/service_proxy.go new file mode 100644 index 0000000..f8c3f85 --- /dev/null +++ b/internal/features/playback/service_proxy.go @@ -0,0 +1,75 @@ +package playback + +import ( + "context" + "fmt" + "io" + "net/http" + "strings" +) + +func (s *Service) ProxyStream(ctx context.Context, targetURL string, referer string, rangeHeader string) (int, http.Header, []byte, io.ReadCloser, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, targetURL, nil) + if err != nil { + return 0, nil, nil, nil, fmt.Errorf("invalid upstream url: %w", err) + } + + if referer != "" { + req.Header.Set("Referer", referer) + } + req.Header.Set("User-Agent", defaultUserAgent) + if rangeHeader != "" { + req.Header.Set("Range", rangeHeader) + } + + resp, err := s.httpClient.Do(req) + if err != nil { + return 0, nil, nil, nil, fmt.Errorf("upstream request failed: %w", err) + } + + if isM3U8(targetURL, resp.Header.Get("Content-Type")) { + defer resp.Body.Close() + body, readErr := io.ReadAll(io.LimitReader(resp.Body, 2*1024*1024)) + if readErr != nil { + return 0, nil, nil, nil, fmt.Errorf("read playlist failed: %w", readErr) + } + + rewritten, rewriteErr := s.rewritePlaylistWithTokens(ctx, string(body), targetURL, referer) + if rewriteErr != nil { + return 0, nil, nil, nil, fmt.Errorf("rewrite playlist failed: %w", rewriteErr) + } + + headers := cloneHeaders(resp.Header) + headers.Set("Content-Type", "application/vnd.apple.mpegurl") + return resp.StatusCode, headers, []byte(rewritten), nil, nil + } + + headers := cloneHeaders(resp.Header) + return resp.StatusCode, headers, nil, resp.Body, nil +} + +func isM3U8(targetURL string, contentType string) bool { + lowerURL := strings.ToLower(targetURL) + lowerType := strings.ToLower(contentType) + if strings.Contains(lowerURL, ".m3u8") { + return true + } + + return strings.Contains(lowerType, "application/vnd.apple.mpegurl") || strings.Contains(lowerType, "application/x-mpegurl") +} + +func cloneHeaders(src http.Header) http.Header { + dst := make(http.Header) + for key, values := range src { + lower := strings.ToLower(key) + if lower == "connection" || lower == "transfer-encoding" || lower == "keep-alive" || lower == "proxy-authenticate" || lower == "proxy-authorization" || lower == "te" || lower == "trailers" || lower == "upgrade" { + continue + } + + for _, value := range values { + dst.Add(key, value) + } + } + + return dst +} diff --git a/internal/features/playback/service_ranking.go b/internal/features/playback/service_ranking.go new file mode 100644 index 0000000..6435752 --- /dev/null +++ b/internal/features/playback/service_ranking.go @@ -0,0 +1,216 @@ +package playback + +import ( + "bytes" + "errors" + "sort" + "strconv" + "strings" +) + +func rankSources(sources []StreamSource, quality string) ([]sourceScore, error) { + filtered := make([]StreamSource, 0, len(sources)) + seen := make(map[string]struct{}) + + for _, source := range sources { + if source.URL == "" { + continue + } + if _, exists := seen[source.URL]; exists { + continue + } + seen[source.URL] = struct{}{} + filtered = append(filtered, source) + } + + if len(filtered) == 0 { + return nil, errors.New("no playable sources available") + } + + targetQuality := normalizeQuality(quality) + scored := make([]sourceScore, 0, len(filtered)) + for _, source := range filtered { + typeScore := sourceTypePriority(source.Type) + providerScore := providerPriority(source.Provider) + qualityScore := sourceQualityPriority(source.Quality, targetQuality) + refererScore := 0 + if source.Referer != "" { + refererScore = 20 + } + + total := typeScore + providerScore + qualityScore + refererScore + scored = append(scored, sourceScore{ + source: source, + total: total, + typeScore: typeScore, + providerScore: providerScore, + qualityScore: qualityScore, + refererScore: refererScore, + }) + } + + sort.SliceStable(scored, func(i int, j int) bool { + return scored[i].total > scored[j].total + }) + + return scored, nil +} + +func normalizeQuality(quality string) string { + lower := strings.ToLower(strings.TrimSpace(quality)) + if lower == "" { + return "best" + } + + return lower +} + +func sourceTypePriority(sourceType string) int { + switch strings.ToLower(sourceType) { + case "mp4": + return 500 + case "m3u8": + return 450 + case "unknown": + return 300 + case "embed": + return 100 + default: + return 200 + } +} + +func providerPriority(provider string) int { + switch strings.ToLower(provider) { + case "s-mp4": + return 120 + case "default": + return 115 + case "luf-mp4": + return 110 + case "vid-mp4": + return 105 + case "yt-mp4": + return 100 + case "mp4": + return 95 + case "uv-mp4": + return 90 + case "hls": + return 80 + case "sw": + return 40 + case "ok": + return 35 + case "ss-hls": + return 30 + default: + return 60 + } +} + +func sourceQualityPriority(sourceQuality string, targetQuality string) int { + qualityValue := parseQualityValue(sourceQuality) + + switch targetQuality { + case "best": + return qualityValue + case "worst": + return -qualityValue + default: + if qualityMatches(sourceQuality, targetQuality) { + return 2000 + qualityValue + } + + return -300 + qualityValue + } +} + +func parseQualityValue(rawQuality string) int { + lower := strings.ToLower(rawQuality) + var digits strings.Builder + + for _, char := range lower { + if char >= '0' && char <= '9' { + digits.WriteRune(char) + continue + } + if digits.Len() > 0 { + break + } + } + + if digits.Len() > 0 { + value, err := strconv.Atoi(digits.String()) + if err == nil { + return value + } + } + + if lower == "auto" { + return 240 + } + + return 0 +} + +func qualityMatches(sourceQuality string, targetQuality string) bool { + sourceLower := strings.ToLower(sourceQuality) + targetLower := strings.ToLower(targetQuality) + + if sourceLower == "" { + return false + } + + if strings.Contains(sourceLower, targetLower) { + return true + } + + sourceDigits := extractDigits(sourceLower) + targetDigits := extractDigits(targetLower) + + return sourceDigits != "" && sourceDigits == targetDigits +} + +func extractDigits(value string) string { + var digits strings.Builder + for _, char := range value { + if char >= '0' && char <= '9' { + digits.WriteRune(char) + continue + } + if digits.Len() > 0 { + break + } + } + + return digits.String() +} + +func normalizeSourceTypeFromProbe(source StreamSource, contentType string) StreamSource { + lower := strings.ToLower(contentType) + if strings.Contains(lower, "video/mp4") { + source.Type = "mp4" + return source + } + + if strings.Contains(lower, "mpegurl") { + source.Type = "m3u8" + return source + } + + return source +} + +func isLikelyMP4(payload []byte) bool { + if len(payload) < 12 { + return false + } + + return bytes.Equal(payload[4:8], []byte("ftyp")) +} + +func isLikelyM3U8(payload []byte) bool { + trimmed := strings.TrimSpace(string(payload)) + return strings.HasPrefix(trimmed, "#EXTM3U") +} diff --git a/internal/features/playback/service_resolution.go b/internal/features/playback/service_resolution.go new file mode 100644 index 0000000..f166eea --- /dev/null +++ b/internal/features/playback/service_resolution.go @@ -0,0 +1,174 @@ +package playback + +import ( + "context" + "errors" + "sort" + "strconv" + "strings" + "sync" +) + +func (s *Service) resolveShow(ctx context.Context, malID int, titleCandidates []string) (string, string, error) { + malText := strconv.Itoa(malID) + modeCandidates := []string{"sub", "dub"} + queries := buildTitleSearchQueries(titleCandidates) + + for _, query := range queries { + resultsByMode := s.searchShowResultsByMode(ctx, query, modeCandidates) + + for _, mode := range modeCandidates { + for _, result := range resultsByMode[mode] { + if strings.TrimSpace(result.MalID) == malText && strings.TrimSpace(result.ID) != "" { + return result.ID, result.Name, nil + } + } + } + + for _, mode := range modeCandidates { + results := resultsByMode[mode] + if len(results) == 0 { + continue + } + + best := results[0] + if strings.TrimSpace(best.ID) != "" { + return best.ID, best.Name, nil + } + } + } + + return "", "", errors.New("unable to resolve allanime show") +} + +func (s *Service) searchShowResultsByMode(ctx context.Context, query string, modeCandidates []string) map[string][]searchResult { + resultsByMode := make(map[string][]searchResult, len(modeCandidates)) + searchCh := make(chan searchModeResult, len(modeCandidates)) + + var wg sync.WaitGroup + for _, mode := range modeCandidates { + modeValue := mode + wg.Add(1) + go func() { + defer wg.Done() + results, err := s.allAnimeClient.Search(ctx, query, modeValue) + searchCh <- searchModeResult{Mode: modeValue, Results: results, Err: err} + }() + } + + wg.Wait() + close(searchCh) + + for result := range searchCh { + if result.Err != nil { + continue + } + + resultsByMode[result.Mode] = result.Results + } + + return resultsByMode +} + +func buildTitleSearchQueries(titleCandidates []string) []string { + queries := make([]string, 0, len(titleCandidates)*4) + seen := make(map[string]struct{}) + + add := func(raw string) { + normalized := normalizeSearchQuery(raw) + if normalized == "" { + return + } + + key := strings.ToLower(normalized) + if _, exists := seen[key]; exists { + return + } + + seen[key] = struct{}{} + queries = append(queries, normalized) + } + + for _, candidate := range titleCandidates { + normalized := normalizeSearchQuery(candidate) + if normalized == "" { + continue + } + + add(normalized) + add(strings.ReplaceAll(normalized, "+", " ")) + + withoutApostrophes := strings.NewReplacer("'", "", "’", "", "`", "").Replace(normalized) + add(withoutApostrophes) + add(strings.ReplaceAll(withoutApostrophes, "+", " ")) + } + + return queries +} + +func normalizeSearchQuery(raw string) string { + return strings.Join(strings.Fields(strings.TrimSpace(raw)), " ") +} + +func firstNonEmptyTitle(values []string) string { + for _, value := range values { + normalized := strings.TrimSpace(value) + if normalized != "" { + return normalized + } + } + + return "" +} + +func normalizeMode(raw string) string { + lower := strings.ToLower(strings.TrimSpace(raw)) + if lower == "sub" || lower == "dub" { + return lower + } + + return lower +} + +func availableModes(modeSources map[string]ModeSource) []string { + ordered := make([]string, 0, len(modeSources)) + if _, ok := modeSources["dub"]; ok { + ordered = append(ordered, "dub") + } + if _, ok := modeSources["sub"]; ok { + ordered = append(ordered, "sub") + } + + extra := make([]string, 0) + for mode := range modeSources { + if mode == "dub" || mode == "sub" { + continue + } + extra = append(extra, mode) + } + sort.Strings(extra) + + return append(ordered, extra...) +} + +func selectInitialMode(requestedMode string, modeSources map[string]ModeSource) string { + normalizedRequested := normalizeMode(requestedMode) + if normalizedRequested != "" { + if _, ok := modeSources[normalizedRequested]; ok { + return normalizedRequested + } + } + + if _, ok := modeSources["dub"]; ok { + return "dub" + } + if _, ok := modeSources["sub"]; ok { + return "sub" + } + + for mode := range modeSources { + return mode + } + + return "dub" +} diff --git a/internal/features/playback/service_sources.go b/internal/features/playback/service_sources.go new file mode 100644 index 0000000..8dcb8a1 --- /dev/null +++ b/internal/features/playback/service_sources.go @@ -0,0 +1,257 @@ +package playback + +import ( + "context" + "errors" + "io" + "net/http" + "strings" + "sync" +) + +func (s *Service) resolveModeSource(ctx context.Context, showID string, episode string, mode string, quality string) (StreamSource, error) { + sources, err := s.allAnimeClient.GetEpisodeSources(ctx, showID, episode, mode) + if err != nil { + return StreamSource{}, err + } + + ranked, err := rankSources(sources, quality) + if err != nil { + return StreamSource{}, err + } + + selected, _, err := s.choosePlaybackSource(ctx, ranked) + if err != nil { + return StreamSource{}, err + } + + return selected, nil +} + +func (s *Service) resolveModeSourceWithCache( + ctx context.Context, + showID string, + episode string, + mode string, + quality string, + probeCache map[string]directProbeResult, + probeCacheMu *sync.Mutex, +) (StreamSource, error) { + sources, err := s.allAnimeClient.GetEpisodeSources(ctx, showID, episode, mode) + if err != nil { + return StreamSource{}, err + } + + ranked, err := rankSources(sources, quality) + if err != nil { + return StreamSource{}, err + } + + selected, _, err := s.choosePlaybackSourceWithCache(ctx, ranked, probeCache, probeCacheMu) + if err != nil { + return StreamSource{}, err + } + + return selected, nil +} + +func (s *Service) choosePlaybackSource(ctx context.Context, ranked []sourceScore) (StreamSource, string, error) { + if len(ranked) == 0 { + return StreamSource{}, "", errors.New("no ranked sources available") + } + + embedCandidates := make([]StreamSource, 0) + for _, candidate := range ranked { + source := candidate.source + sourceType := strings.ToLower(source.Type) + + switch sourceType { + case "mp4", "m3u8": + return source, "direct-media", nil + case "embed": + embedCandidates = append(embedCandidates, source) + default: + playable, contentType := s.probeDirectMedia(ctx, source) + if playable { + return normalizeSourceTypeFromProbe(source, contentType), "probed-media", nil + } + } + } + + for _, embed := range embedCandidates { + if s.probeEmbedSource(ctx, embed) { + return embed, "embed-probed", nil + } + } + + if len(embedCandidates) > 0 { + return embedCandidates[0], "embed-fallback", nil + } + + return ranked[0].source, "ranked-fallback", nil +} + +func (s *Service) choosePlaybackSourceWithCache( + ctx context.Context, + ranked []sourceScore, + probeCache map[string]directProbeResult, + probeCacheMu *sync.Mutex, +) (StreamSource, string, error) { + if len(ranked) == 0 { + return StreamSource{}, "", errors.New("no ranked sources available") + } + + embedCandidates := make([]StreamSource, 0) + for _, candidate := range ranked { + source := candidate.source + sourceType := strings.ToLower(source.Type) + + switch sourceType { + case "mp4", "m3u8": + return source, "direct-media", nil + case "embed": + embedCandidates = append(embedCandidates, source) + default: + playable, contentType := s.probeDirectMediaCached(ctx, source, probeCache, probeCacheMu) + if playable { + return normalizeSourceTypeFromProbe(source, contentType), "probed-media", nil + } + } + } + + for _, embed := range embedCandidates { + if s.probeEmbedSource(ctx, embed) { + return embed, "embed-probed", nil + } + } + + if len(embedCandidates) > 0 { + return embedCandidates[0], "embed-fallback", nil + } + + return ranked[0].source, "ranked-fallback", nil +} + +func (s *Service) probeDirectMediaCached( + ctx context.Context, + source StreamSource, + probeCache map[string]directProbeResult, + probeCacheMu *sync.Mutex, +) (bool, string) { + cacheKey := strings.TrimSpace(source.URL) + if cacheKey == "" { + return s.probeDirectMedia(ctx, source) + } + + probeCacheMu.Lock() + cached, ok := probeCache[cacheKey] + probeCacheMu.Unlock() + if ok { + return cached.Playable, cached.ContentType + } + + playable, contentType := s.probeDirectMedia(ctx, source) + + probeCacheMu.Lock() + probeCache[cacheKey] = directProbeResult{Playable: playable, ContentType: contentType} + probeCacheMu.Unlock() + + return playable, contentType +} + +func (s *Service) probeDirectMedia(ctx context.Context, source StreamSource) (bool, string) { + probeCtx, cancel := context.WithTimeout(ctx, providerProbeTimeout) + defer cancel() + + req, err := http.NewRequestWithContext(probeCtx, http.MethodGet, source.URL, nil) + if err != nil { + return false, "" + } + + if source.Referer != "" { + req.Header.Set("Referer", source.Referer) + } + req.Header.Set("User-Agent", defaultUserAgent) + req.Header.Set("Range", "bytes=0-4095") + + resp, err := s.httpClient.Do(req) + if err != nil { + return false, "" + } + defer resp.Body.Close() + + contentType := strings.ToLower(resp.Header.Get("Content-Type")) + if strings.Contains(contentType, "video/") || strings.Contains(contentType, "mpegurl") { + return true, contentType + } + + prefix, err := io.ReadAll(io.LimitReader(resp.Body, 4096)) + if err == nil { + if isLikelyM3U8(prefix) { + return true, "application/vnd.apple.mpegurl" + } + if isLikelyMP4(prefix) { + return true, "video/mp4" + } + } + + finalURL := "" + if resp.Request != nil && resp.Request.URL != nil { + finalURL = strings.ToLower(resp.Request.URL.String()) + } + + if strings.Contains(finalURL, ".mp4") || strings.Contains(finalURL, ".m3u8") { + return true, contentType + } + + return false, contentType +} + +func (s *Service) probeEmbedSource(ctx context.Context, source StreamSource) bool { + probeCtx, cancel := context.WithTimeout(ctx, providerProbeTimeout) + defer cancel() + + req, err := http.NewRequestWithContext(probeCtx, http.MethodGet, source.URL, nil) + if err != nil { + return false + } + + if source.Referer != "" { + req.Header.Set("Referer", source.Referer) + } + req.Header.Set("User-Agent", defaultUserAgent) + + resp, err := s.httpClient.Do(req) + if err != nil { + return false + } + defer resp.Body.Close() + + if resp.StatusCode >= http.StatusBadRequest { + return false + } + + body, err := io.ReadAll(io.LimitReader(resp.Body, 512*1024)) + if err != nil { + return false + } + + content := strings.ToLower(string(body)) + markers := []string{ + "file was deleted", + "file has been deleted", + "video was deleted", + "video has been deleted", + "video unavailable", + "file not found", + "this file does not exist", + "resource unavailable", + } + for _, marker := range markers { + if strings.Contains(content, marker) { + return false + } + } + + return true +} diff --git a/internal/features/playback/service_utils.go b/internal/features/playback/service_utils.go new file mode 100644 index 0000000..1e749d9 --- /dev/null +++ b/internal/features/playback/service_utils.go @@ -0,0 +1,23 @@ +package playback + +import ( + "strings" +) + +func toSubtitleItems(source StreamSource) []SubtitleItem { + items := make([]SubtitleItem, 0, len(source.Subtitles)) + for _, subtitle := range source.Subtitles { + targetURL := strings.TrimSpace(subtitle.URL) + if targetURL == "" { + continue + } + + items = append(items, SubtitleItem{ + Lang: strings.TrimSpace(subtitle.Lang), + URL: targetURL, + Referer: source.Referer, + }) + } + + return items +} diff --git a/internal/features/playback/types.go b/internal/features/playback/types.go index 8ac7159..56d7a52 100644 --- a/internal/features/playback/types.go +++ b/internal/features/playback/types.go @@ -15,15 +15,17 @@ type Subtitle struct { } type ModeSource struct { - URL string `json:"url"` - Referer string `json:"referer"` + URL string `json:"url,omitempty"` + Referer string `json:"referer,omitempty"` + Token string `json:"token"` Subtitles []SubtitleItem `json:"subtitles"` } type SubtitleItem struct { Lang string `json:"lang"` - URL string `json:"url"` - Referer string `json:"referer"` + URL string `json:"url,omitempty"` + Referer string `json:"referer,omitempty"` + Token string `json:"token"` } type SkipSegment struct { diff --git a/internal/features/watchlist/handler.go b/internal/features/watchlist/handler.go index 3e4dbc5..a4d89ff 100644 --- a/internal/features/watchlist/handler.go +++ b/internal/features/watchlist/handler.go @@ -1,7 +1,6 @@ package watchlist import ( - "database/sql" "encoding/json" "errors" "log" @@ -202,7 +201,7 @@ func (h *Handler) HandleContinueWatching(w http.ResponseWriter, r *http.Request) return } - entries, err := h.svc.db.GetContinueWatchingEntries(r.Context(), user.ID) + entries, err := h.svc.GetContinueWatching(r.Context(), user.ID) if err != nil { log.Printf("continue watching fetch failed: user_id=%s err=%v", user.ID, err) http.Error(w, "failed to fetch continue watching", http.StatusInternalServerError) @@ -231,27 +230,12 @@ func (h *Handler) HandleDeleteContinueWatching(w http.ResponseWriter, r *http.Re return } - err = h.svc.db.DeleteContinueWatchingEntry(r.Context(), database.DeleteContinueWatchingEntryParams{ - UserID: user.ID, - AnimeID: animeID, - }) - if err != nil { + if err := h.svc.DeleteContinueWatching(r.Context(), user.ID, animeID); err != nil { log.Printf("continue watching delete failed: user_id=%s anime_id=%d err=%v", user.ID, animeID, err) http.Error(w, "failed to delete continue watching entry", http.StatusInternalServerError) return } - if err := h.svc.db.SaveWatchProgress(r.Context(), database.SaveWatchProgressParams{ - CurrentEpisode: sql.NullInt64{Valid: false}, - CurrentTimeSeconds: 0, - UserID: user.ID, - AnimeID: animeID, - }); err != nil { - log.Printf("continue watching delete failed to clear watchlist progress: user_id=%s anime_id=%d err=%v", user.ID, animeID, err) - http.Error(w, "failed to delete continue watching entry", http.StatusInternalServerError) - return - } - w.WriteHeader(http.StatusOK) } diff --git a/internal/features/watchlist/service.go b/internal/features/watchlist/service.go index baafc79..c4b35c6 100644 --- a/internal/features/watchlist/service.go +++ b/internal/features/watchlist/service.go @@ -5,6 +5,7 @@ import ( "database/sql" "errors" "fmt" + "strings" "time" "github.com/google/uuid" @@ -13,7 +14,8 @@ import ( ) type Service struct { - db database.Querier + db database.Querier + sqlDB *sql.DB } var ( @@ -29,8 +31,8 @@ var validStatuses = map[string]struct{}{ "plan_to_watch": {}, } -func NewService(db database.Querier) *Service { - return &Service{db: db} +func NewService(db database.Querier, sqlDB *sql.DB) *Service { + return &Service{db: db, sqlDB: sqlDB} } type AddRequest struct { @@ -66,11 +68,11 @@ func (s *Service) AddEntry(ctx context.Context, userID string, req AddRequest) e entryID := uuid.New().String() _, err = s.db.UpsertWatchListEntry(ctx, database.UpsertWatchListEntryParams{ - ID: entryID, - UserID: userID, - AnimeID: req.AnimeID, - Status: req.Status, - CurrentEpisode: sql.NullInt64{Int64: 0, Valid: false}, + ID: entryID, + UserID: userID, + AnimeID: req.AnimeID, + Status: req.Status, + CurrentEpisode: sql.NullInt64{Int64: 0, Valid: false}, CurrentTimeSeconds: 0, }) if err != nil { @@ -109,6 +111,78 @@ func (s *Service) GetUserWatchlist(ctx context.Context, userID string) ([]databa return entries, nil } +func (s *Service) GetContinueWatching(ctx context.Context, userID string) ([]database.GetContinueWatchingEntriesRow, error) { + if strings.TrimSpace(userID) == "" { + return nil, errors.New("invalid user id") + } + + entries, err := s.db.GetContinueWatchingEntries(ctx, userID) + if err != nil { + return nil, fmt.Errorf("failed to fetch continue watching: %w", err) + } + + return entries, nil +} + +func (s *Service) DeleteContinueWatching(ctx context.Context, userID string, animeID int64) error { + if strings.TrimSpace(userID) == "" { + return errors.New("invalid user id") + } + + if animeID <= 0 { + return ErrInvalidAnimeID + } + + if s.sqlDB == nil { + if err := s.db.DeleteContinueWatchingEntry(ctx, database.DeleteContinueWatchingEntryParams{ + UserID: userID, + AnimeID: animeID, + }); err != nil { + return fmt.Errorf("failed to delete continue watching entry: %w", err) + } + + if err := s.db.SaveWatchProgress(ctx, database.SaveWatchProgressParams{ + CurrentEpisode: sql.NullInt64{Valid: false}, + CurrentTimeSeconds: 0, + UserID: userID, + AnimeID: animeID, + }); err != nil { + return fmt.Errorf("failed to clear watchlist progress: %w", err) + } + + return nil + } + + tx, err := s.sqlDB.BeginTx(ctx, nil) + if err != nil { + return fmt.Errorf("failed to begin transaction: %w", err) + } + defer tx.Rollback() + + txQueries := database.New(tx) + if err := txQueries.DeleteContinueWatchingEntry(ctx, database.DeleteContinueWatchingEntryParams{ + UserID: userID, + AnimeID: animeID, + }); err != nil { + return fmt.Errorf("failed to delete continue watching entry: %w", err) + } + + if err := txQueries.SaveWatchProgress(ctx, database.SaveWatchProgressParams{ + CurrentEpisode: sql.NullInt64{Valid: false}, + CurrentTimeSeconds: 0, + UserID: userID, + AnimeID: animeID, + }); err != nil { + return fmt.Errorf("failed to clear watchlist progress: %w", err) + } + + if err := tx.Commit(); err != nil { + return fmt.Errorf("failed to commit continue watching deletion: %w", err) + } + + return nil +} + type ExportEntry struct { AnimeID int64 `json:"anime_id"` Title string `json:"title"` @@ -161,11 +235,11 @@ func (s *Service) Import(ctx context.Context, userID string, export ExportData) } _, err = s.db.UpsertWatchListEntry(ctx, database.UpsertWatchListEntryParams{ - ID: uuid.New().String(), - UserID: userID, - AnimeID: entry.AnimeID, - Status: entry.Status, - CurrentEpisode: sql.NullInt64{Int64: 0, Valid: false}, + ID: uuid.New().String(), + UserID: userID, + AnimeID: entry.AnimeID, + Status: entry.Status, + CurrentEpisode: sql.NullInt64{Int64: 0, Valid: false}, CurrentTimeSeconds: 0, }) if err != nil { diff --git a/internal/features/watchlist/service_test.go b/internal/features/watchlist/service_test.go index faf5f3e..5a69be5 100644 --- a/internal/features/watchlist/service_test.go +++ b/internal/features/watchlist/service_test.go @@ -34,7 +34,7 @@ func TestAddEntry_RejectsInvalidAnimeID(t *testing.T) { t.Parallel() q := &fakeQuerier{} - svc := NewService(q) + svc := NewService(q, nil) err := svc.AddEntry(context.Background(), "user-1", AddRequest{ AnimeID: 0, @@ -54,7 +54,7 @@ func TestAddEntry_RejectsInvalidStatus(t *testing.T) { t.Parallel() q := &fakeQuerier{} - svc := NewService(q) + svc := NewService(q, nil) err := svc.AddEntry(context.Background(), "user-1", AddRequest{ AnimeID: 1, @@ -101,7 +101,7 @@ func TestExport_UsesDisplayTitleFallbackOrder(t *testing.T) { }, } - svc := NewService(q) + svc := NewService(q, nil) export, err := svc.Export(context.Background(), "user-1") if err != nil { t.Fatalf("expected no error, got %v", err) diff --git a/internal/jikan/studio.go b/internal/jikan/studio.go index 352a6d7..f27ff53 100644 --- a/internal/jikan/studio.go +++ b/internal/jikan/studio.go @@ -7,7 +7,7 @@ import ( type ProducerResponse struct { Data struct { - MalID int `json:"mal_id"` + MalID int `json:"mal_id"` Titles []struct { Type string `json:"type"` Title string `json:"title"` @@ -40,7 +40,7 @@ func (c *Client) GetAnimeByProducer(ctx context.Context, producerID int, page in } var stale StudioAnimeResult - hasStale := c.getStaleCache(ctx, cacheKey, &cached) + hasStale := c.getStaleCache(ctx, cacheKey, &stale) var result SearchResponse reqURL := fmt.Sprintf("%s/anime?producers=%d&page=%d", c.baseURL, producerID, page) @@ -84,7 +84,7 @@ func (c *Client) GetProducerByID(ctx context.Context, producerID int) (ProducerR } var stale ProducerResponse - hasStale := c.getStaleCache(ctx, cacheKey, &cached) + hasStale := c.getStaleCache(ctx, cacheKey, &stale) var result ProducerResponse reqURL := fmt.Sprintf("%s/producers/%d/full", c.baseURL, producerID) diff --git a/internal/jikan/studio_test.go b/internal/jikan/studio_test.go new file mode 100644 index 0000000..34da87c --- /dev/null +++ b/internal/jikan/studio_test.go @@ -0,0 +1,96 @@ +package jikan + +import ( + "context" + "database/sql" + "net/http" + "net/http/httptest" + "testing" + + "mal/internal/database" +) + +type staleCacheQuerier struct { + database.Querier + staleJSON string +} + +func (q *staleCacheQuerier) GetJikanCache(ctx context.Context, key string) (string, error) { + return "", sql.ErrNoRows +} + +func (q *staleCacheQuerier) GetJikanCacheStale(ctx context.Context, key string) (string, error) { + if q.staleJSON == "" { + return "", sql.ErrNoRows + } + + return q.staleJSON, nil +} + +func TestGetProducerByID_UsesStaleCacheOnFetchFailure(t *testing.T) { + t.Parallel() + + q := &staleCacheQuerier{ + staleJSON: `{"data":{"mal_id":7,"about":"stale about"}}`, + } + + client := NewClient(q) + testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotFound) + })) + defer testServer.Close() + + client.baseURL = testServer.URL + client.httpClient = testServer.Client() + + result, err := client.GetProducerByID(context.Background(), 7) + if err != nil { + t.Fatalf("expected stale cache result, got error: %v", err) + } + + if result.Data.MalID != 7 { + t.Fatalf("expected stale mal_id 7, got %d", result.Data.MalID) + } + + if result.Data.About != "stale about" { + t.Fatalf("expected stale about field, got %q", result.Data.About) + } +} + +func TestGetAnimeByProducer_UsesStaleCacheOnFetchFailure(t *testing.T) { + t.Parallel() + + q := &staleCacheQuerier{ + staleJSON: `{"Animes":[{"mal_id":42,"title":"Stale Anime"}],"HasNextPage":true,"StudioName":"Stale Studio"}`, + } + + client := NewClient(q) + testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotFound) + })) + defer testServer.Close() + + client.baseURL = testServer.URL + client.httpClient = testServer.Client() + + result, err := client.GetAnimeByProducer(context.Background(), 9, 1) + if err != nil { + t.Fatalf("expected stale cache result, got error: %v", err) + } + + if len(result.Animes) != 1 { + t.Fatalf("expected one stale anime, got %d", len(result.Animes)) + } + + if result.Animes[0].MalID != 42 { + t.Fatalf("expected stale anime mal_id 42, got %d", result.Animes[0].MalID) + } + + if !result.HasNextPage { + t.Fatal("expected stale has_next_page=true") + } + + if result.StudioName != "Stale Studio" { + t.Fatalf("expected stale studio name, got %q", result.StudioName) + } +} diff --git a/internal/server/routes.go b/internal/server/routes.go index 83c09c3..b6608fb 100644 --- a/internal/server/routes.go +++ b/internal/server/routes.go @@ -1,6 +1,7 @@ package server import ( + "database/sql" "net/http" "mal/internal/database" @@ -13,9 +14,11 @@ import ( ) type Config struct { - DB *database.Queries - JikanClient *jikan.Client - AuthService *auth.Service + DB *database.Queries + SQLDB *sql.DB + JikanClient *jikan.Client + AuthService *auth.Service + PlaybackProxySecret string } func NewRouter(cfg Config) http.Handler { @@ -23,12 +26,12 @@ func NewRouter(cfg Config) http.Handler { authHandler := auth.NewHandler(cfg.AuthService) - watchlistSvc := watchlist.NewService(cfg.DB) + watchlistSvc := watchlist.NewService(cfg.DB, cfg.SQLDB) watchlistHandler := watchlist.NewHandler(watchlistSvc) animeSvc := anime.NewService(cfg.JikanClient, cfg.DB) animeHandler := anime.NewHandler(animeSvc) - playbackSvc := playback.NewService(cfg.DB) + playbackSvc := playback.NewService(cfg.DB, cfg.SQLDB, playback.Config{ProxyTokenSecret: cfg.PlaybackProxySecret}) playbackHandler := playback.NewHandler(playbackSvc, cfg.JikanClient) // Serve static files @@ -83,5 +86,6 @@ func NewRouter(cfg Config) http.Handler { // Wrap mux with global CSRF origin verification and auth checking, // THEN auth context parsing. protectedHandler := middleware.RequireGlobalAuth(middleware.VerifyOrigin(mux)) - return middleware.Auth(cfg.AuthService)(protectedHandler) + authenticatedHandler := middleware.Auth(cfg.AuthService)(protectedHandler) + return middleware.RequestLogger(authenticatedHandler) } diff --git a/internal/templates/watch.templ b/internal/templates/watch.templ index f9151ad..eb801df 100644 --- a/internal/templates/watch.templ +++ b/internal/templates/watch.templ @@ -29,16 +29,14 @@ type WatchPageData struct { // ModeSource represents a stream source for a specific mode (dub/sub) type ModeSource struct { - URL string `json:"url"` - Referer string `json:"referer"` + Token string `json:"token"` Subtitles []SubtitleItem `json:"subtitles"` } // SubtitleItem represents a subtitle track type SubtitleItem struct { Lang string `json:"lang"` - URL string `json:"url"` - Referer string `json:"referer"` + Token string `json:"token"` } // SkipSegment represents a skippable segment (intro/outro) @@ -199,7 +197,7 @@ templ EpisodeItem(episode jikan.Episode, currentEpisode string, animeID int) { } templ VideoPlayer(data WatchPageData) { - {{ streamURL := buildStreamURL(data.InitialMode, data.ModeSources) }} + {{ streamToken := modeToken(data.InitialMode, data.ModeSources) }} {{ hasDub := modeAvailable(data.AvailableModes, "dub") }} {{ hasSub := modeAvailable(data.AvailableModes, "sub") }}
@@ -339,9 +338,35 @@ templ VideoPlayer(data WatchPageData) {
} -func buildStreamURL(mode string, modeSources map[string]ModeSource) string { - stateJSON, _ := json.Marshal(modeSources) - return fmt.Sprintf("/watch/proxy/stream?mode=%s&state=%s", url.QueryEscape(mode), url.QueryEscape(string(stateJSON))) +func buildStreamURL(mode string, token string) string { + if token == "" { + return "" + } + + return fmt.Sprintf("/watch/proxy/stream?mode=%s&token=%s", url.QueryEscape(mode), url.QueryEscape(token)) +} + +func modeToken(mode string, modeSources map[string]ModeSource) string { + normalizedMode := mode + if _, ok := modeSources[normalizedMode]; !ok { + if _, ok := modeSources["dub"]; ok { + normalizedMode = "dub" + } else if _, ok := modeSources["sub"]; ok { + normalizedMode = "sub" + } else { + for key := range modeSources { + normalizedMode = key + break + } + } + } + + source, ok := modeSources[normalizedMode] + if !ok { + return "" + } + + return source.Token } func toJSON(v interface{}) string { diff --git a/static/player.ts b/static/player.ts index 3cef0c4..48cbcb5 100644 --- a/static/player.ts +++ b/static/player.ts @@ -1,15 +1,13 @@ export {} interface ModeSource { - url: string - referer: string + token: string subtitles: SubtitleItem[] } interface SubtitleItem { lang: string - url: string - referer: string + token: string } interface SkipSegment { @@ -46,6 +44,7 @@ const initPlayer = (): void => { const subtitleText = container.querySelector('[data-subtitle-text]') as HTMLElement const streamURL = container.getAttribute('data-stream-url') || '/watch/proxy/stream' + const initialStreamToken = container.getAttribute('data-stream-token') || '' const currentEpisode = container.getAttribute('data-current-episode') || '1' const malID = Number.parseInt(container.getAttribute('data-mal-id') || '', 10) const totalEpisodes = Number.parseInt(container.getAttribute('data-total-episodes') || '0', 10) @@ -79,6 +78,10 @@ const initPlayer = (): void => { .sort((a: { start: number }, b: { start: number }) => a.start - b.start) let currentMode = availableModes.includes(initialMode) ? initialMode : (availableModes[0] || 'dub') + const fallbackMode = Object.keys(modeSources).find((mode) => typeof modeSources[mode]?.token === 'string' && modeSources[mode].token !== '') + if ((!modeSources[currentMode] || !modeSources[currentMode].token) && fallbackMode) { + currentMode = fallbackMode + } let controlsTimeout: number | undefined let isScrubbing = false let isHoveringVolume = false @@ -97,21 +100,18 @@ const initPlayer = (): void => { const previewPopover = container.querySelector('[data-preview-popover]') as HTMLElement const previewTime = container.querySelector('[data-preview-time]') as HTMLElement - const encodedModeState = encodeURIComponent(JSON.stringify(modeSources)) - const streamUrlForMode = (mode: string): string => { const modeParam = encodeURIComponent(mode) - const stateParam = encodedModeState - return `${streamURL}?mode=${modeParam}&state=${stateParam}` + const modeSource = modeSources[mode] + const token = modeSource?.token + if (!token) return '' + const tokenParam = encodeURIComponent(token) + return `${streamURL}?mode=${modeParam}&token=${tokenParam}` } const subtitleProxyURL = (track: SubtitleItem): string => { - if (!track || !track.url) return '' - let proxied = `/watch/proxy/subtitle?u=${encodeURIComponent(track.url)}` - if (track.referer) { - proxied += `&r=${encodeURIComponent(track.referer)}` - } - return proxied + if (!track || !track.token) return '' + return `/watch/proxy/subtitle?token=${encodeURIComponent(track.token)}` } const subtitlesForMode = (mode: string): Array<{ lang: string, label: string, url: string }> => { @@ -560,11 +560,13 @@ const initPlayer = (): void => { const switchMode = (mode: string): void => { if (!availableModes.includes(mode) || mode === currentMode) return + const nextURL = streamUrlForMode(mode) + if (!nextURL) return const wasPlaying = !video.paused const previousTime = displayTimeFromAbsolute(video.currentTime) currentMode = mode hidePreviewPopover() - video.src = streamUrlForMode(currentMode) + video.src = nextURL video.load() pendingSeekTime = previousTime if (wasPlaying) video.play().catch(() => {}) @@ -653,6 +655,13 @@ const initPlayer = (): void => { updateSubtitleOptions() updateModeButtons(currentMode) + const startingURL = streamUrlForMode(currentMode) + if (startingURL) { + video.src = startingURL + } else if (initialStreamToken) { + video.src = `${streamURL}?mode=${encodeURIComponent(currentMode)}&token=${encodeURIComponent(initialStreamToken)}` + } + if (video) { video.addEventListener('loadedmetadata', () => { if (loading) loading.style.display = 'none'