fix: harden playback and migrations
This commit is contained in:
@@ -98,7 +98,7 @@ The frontend pipeline uses a single source stylesheet (`static/style.css`) and T
|
||||
When the server starts, the app is available at `http://localhost:3000`.
|
||||
|
||||
Important notes:
|
||||
- Environment variables are read directly from the process environment (`PORT`, `DATABASE_FILE`, `ENV`); `.env` is not auto-loaded.
|
||||
- Environment variables are read directly from the process environment (`PORT`, `DATABASE_FILE`, `ENV`, `PLAYBACK_PROXY_SECRET`); `.env` is not auto-loaded.
|
||||
- The web app currently exposes a login route only. If your database has no users yet, create the first user outside the web UI.
|
||||
|
||||
For containerized usage, the included `Dockerfile` uses a multi-stage build that installs Bun + templ, builds assets, generates templates, compiles `cmd/server`, and ships a slim runtime image with SQLite support.
|
||||
@@ -125,6 +125,8 @@ docker run --rm \
|
||||
| `PORT` | `3000` | HTTP listen port |
|
||||
| `DATABASE_FILE` | `mal.db` | SQLite database file path |
|
||||
| `ENV` | _(empty)_ | Set to `production` to enable secure session cookies |
|
||||
| `MIGRATIONS_DIR` | _(auto-discovered)_ | Optional explicit path to migration files |
|
||||
| `PLAYBACK_PROXY_SECRET` | _(required)_ | HMAC secret for signed playback proxy tokens (min 32 chars) |
|
||||
|
||||
## Database and testing
|
||||
|
||||
|
||||
@@ -3,10 +3,14 @@ package main
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
@@ -20,23 +24,49 @@ import (
|
||||
)
|
||||
|
||||
func main() {
|
||||
|
||||
dbFile := os.Getenv("DATABASE_FILE")
|
||||
dbFile := strings.TrimSpace(os.Getenv("DATABASE_FILE"))
|
||||
if dbFile == "" {
|
||||
dbFile = "mal.db"
|
||||
}
|
||||
|
||||
db, err := sql.Open("sqlite3", dbFile)
|
||||
dsn := fmt.Sprintf("file:%s?_foreign_keys=on", dbFile)
|
||||
db, err := sql.Open("sqlite3", dsn)
|
||||
if err != nil {
|
||||
log.Fatalf("failed to open db: %v", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
if err := db.Ping(); err != nil {
|
||||
log.Fatalf("failed to ping db: %v", err)
|
||||
}
|
||||
|
||||
if _, err := db.Exec("PRAGMA foreign_keys = ON"); err != nil {
|
||||
log.Fatalf("failed to enforce sqlite foreign keys: %v", err)
|
||||
}
|
||||
|
||||
var fkState int
|
||||
if err := db.QueryRow("PRAGMA foreign_keys").Scan(&fkState); err != nil {
|
||||
log.Fatalf("failed to verify sqlite foreign keys: %v", err)
|
||||
}
|
||||
if fkState != 1 {
|
||||
log.Fatal("sqlite foreign keys are disabled")
|
||||
}
|
||||
|
||||
migrationsDir, err := resolveMigrationsDir()
|
||||
if err != nil {
|
||||
log.Fatalf("failed to locate migrations directory: %v", err)
|
||||
}
|
||||
|
||||
// Run migrations with tracking
|
||||
if err := database.RunMigrations(db); err != nil {
|
||||
if err := database.RunMigrations(db, migrationsDir); err != nil {
|
||||
log.Fatalf("failed to run migrations: %v", err)
|
||||
}
|
||||
|
||||
playbackSecret := strings.TrimSpace(os.Getenv("PLAYBACK_PROXY_SECRET"))
|
||||
if len(playbackSecret) < 32 {
|
||||
log.Fatal("PLAYBACK_PROXY_SECRET must be set and at least 32 characters")
|
||||
}
|
||||
|
||||
queries := database.New(db)
|
||||
authService := auth.NewService(queries)
|
||||
jikanClient := jikan.NewClient(queries)
|
||||
@@ -51,8 +81,10 @@ func main() {
|
||||
|
||||
app := server.Config{
|
||||
DB: queries,
|
||||
SQLDB: db,
|
||||
JikanClient: jikanClient,
|
||||
AuthService: authService,
|
||||
PlaybackProxySecret: playbackSecret,
|
||||
}
|
||||
|
||||
handler := server.NewRouter(app)
|
||||
@@ -86,3 +118,69 @@ func main() {
|
||||
log.Fatalf("Server failed to start: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func resolveMigrationsDir() (string, error) {
|
||||
configured := strings.TrimSpace(os.Getenv("MIGRATIONS_DIR"))
|
||||
if configured != "" {
|
||||
hasFiles, err := directoryHasSQLFiles(configured)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if !hasFiles {
|
||||
return "", fmt.Errorf("MIGRATIONS_DIR has no .sql files: %s", configured)
|
||||
}
|
||||
|
||||
return configured, nil
|
||||
}
|
||||
|
||||
workingDir, err := os.Getwd()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
executablePath, err := os.Executable()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
candidates := []string{
|
||||
filepath.Join(workingDir, "migrations"),
|
||||
filepath.Join(filepath.Dir(executablePath), "migrations"),
|
||||
}
|
||||
|
||||
for _, candidate := range candidates {
|
||||
hasFiles, checkErr := directoryHasSQLFiles(candidate)
|
||||
if checkErr != nil {
|
||||
if errors.Is(checkErr, os.ErrNotExist) {
|
||||
continue
|
||||
}
|
||||
|
||||
return "", checkErr
|
||||
}
|
||||
|
||||
if hasFiles {
|
||||
return candidate, nil
|
||||
}
|
||||
}
|
||||
|
||||
return "", errors.New("could not find migrations directory")
|
||||
}
|
||||
|
||||
func directoryHasSQLFiles(dir string) (bool, error) {
|
||||
info, err := os.Stat(dir)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
if !info.IsDir() {
|
||||
return false, fmt.Errorf("not a directory: %s", dir)
|
||||
}
|
||||
|
||||
files, err := filepath.Glob(filepath.Join(dir, "*.sql"))
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return len(files) > 0, nil
|
||||
}
|
||||
|
||||
@@ -2,13 +2,19 @@ package database
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func RunMigrations(db *sql.DB) error {
|
||||
func RunMigrations(db *sql.DB, migrationsDir string) error {
|
||||
if migrationsDir == "" {
|
||||
return fmt.Errorf("migrations directory is required")
|
||||
}
|
||||
|
||||
// Create migration tracking table
|
||||
_, err := db.Exec(`
|
||||
CREATE TABLE IF NOT EXISTS migration_version (
|
||||
@@ -20,21 +26,24 @@ func RunMigrations(db *sql.DB) error {
|
||||
return err
|
||||
}
|
||||
|
||||
migrations, err := filepath.Glob("migrations/*.sql")
|
||||
migrations, err := filepath.Glob(filepath.Join(migrationsDir, "*.sql"))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(migrations) == 0 {
|
||||
return fmt.Errorf("no migration files found in %s", migrationsDir)
|
||||
}
|
||||
|
||||
sort.Strings(migrations)
|
||||
|
||||
for _, migrationFile := range migrations {
|
||||
// Check if migration already applied
|
||||
var exists int
|
||||
err := db.QueryRow("SELECT COUNT(*) FROM migration_version WHERE name = ?", migrationFile).Scan(&exists)
|
||||
appliedNames, err := loadAppliedMigrationNames(db)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if exists > 0 {
|
||||
|
||||
for _, migrationFile := range migrations {
|
||||
migrationName := filepath.Base(migrationFile)
|
||||
if migrationApplied(appliedNames, migrationName) {
|
||||
// already applied, skipping silently
|
||||
continue
|
||||
}
|
||||
@@ -51,13 +60,58 @@ func RunMigrations(db *sql.DB) error {
|
||||
}
|
||||
|
||||
// Mark as applied
|
||||
_, err = db.Exec("INSERT INTO migration_version (name) VALUES (?)", migrationFile)
|
||||
_, err = db.Exec("INSERT INTO migration_version (name) VALUES (?)", migrationName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.Printf("migration %s applied successfully", migrationFile)
|
||||
appliedNames[migrationName] = struct{}{}
|
||||
|
||||
log.Printf("migration %s applied successfully", migrationName)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func loadAppliedMigrationNames(db *sql.DB) (map[string]struct{}, error) {
|
||||
rows, err := db.Query("SELECT name FROM migration_version")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
applied := make(map[string]struct{})
|
||||
for rows.Next() {
|
||||
var name string
|
||||
if err := rows.Scan(&name); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
applied[name] = struct{}{}
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return applied, nil
|
||||
}
|
||||
|
||||
func migrationApplied(appliedNames map[string]struct{}, migrationName string) bool {
|
||||
if _, exists := appliedNames[migrationName]; exists {
|
||||
return true
|
||||
}
|
||||
|
||||
legacyName := filepath.ToSlash(filepath.Join("migrations", migrationName))
|
||||
if _, exists := appliedNames[legacyName]; exists {
|
||||
return true
|
||||
}
|
||||
|
||||
for appliedName := range appliedNames {
|
||||
if strings.EqualFold(filepath.Base(appliedName), migrationName) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -13,8 +13,6 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
"mal/internal/database"
|
||||
"mal/internal/jikan"
|
||||
"mal/internal/shared/middleware"
|
||||
@@ -163,13 +161,11 @@ func convertModeSources(sources map[string]ModeSource) map[string]templates.Mode
|
||||
for i, s := range v.Subtitles {
|
||||
subtitles[i] = templates.SubtitleItem{
|
||||
Lang: s.Lang,
|
||||
URL: s.URL,
|
||||
Referer: s.Referer,
|
||||
Token: s.Token,
|
||||
}
|
||||
}
|
||||
result[k] = templates.ModeSource{
|
||||
URL: v.URL,
|
||||
Referer: v.Referer,
|
||||
Token: v.Token,
|
||||
Subtitles: subtitles,
|
||||
}
|
||||
}
|
||||
@@ -199,25 +195,19 @@ func (h *Handler) HandleProxyStream(w http.ResponseWriter, r *http.Request) {
|
||||
mode = "dub"
|
||||
}
|
||||
|
||||
state := r.URL.Query().Get("state")
|
||||
if strings.TrimSpace(state) == "" {
|
||||
http.Error(w, "missing playback state", http.StatusBadRequest)
|
||||
token := strings.TrimSpace(r.URL.Query().Get("token"))
|
||||
if token == "" {
|
||||
http.Error(w, "missing playback token", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
modeSources := make(map[string]ModeSource)
|
||||
if err := json.Unmarshal([]byte(state), &modeSources); err != nil {
|
||||
http.Error(w, "invalid playback state", http.StatusBadRequest)
|
||||
targetURL, referer, err := h.svc.resolveProxyToken(r.Context(), token, proxyScopeStream)
|
||||
if err != nil {
|
||||
http.Error(w, "invalid stream token", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
source, ok := modeSources[mode]
|
||||
if !ok || strings.TrimSpace(source.URL) == "" {
|
||||
http.Error(w, "stream mode unavailable", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
h.proxyUpstream(w, r, source.URL, source.Referer)
|
||||
h.proxyUpstream(w, r, targetURL, referer)
|
||||
}
|
||||
|
||||
func (h *Handler) HandleProxySegment(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -226,13 +216,19 @@ func (h *Handler) HandleProxySegment(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
targetURL := r.URL.Query().Get("u")
|
||||
if strings.TrimSpace(targetURL) == "" {
|
||||
http.Error(w, "missing target url", http.StatusBadRequest)
|
||||
token := strings.TrimSpace(r.URL.Query().Get("token"))
|
||||
if token == "" {
|
||||
http.Error(w, "missing segment token", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
h.proxyUpstream(w, r, targetURL, r.URL.Query().Get("r"))
|
||||
targetURL, referer, err := h.svc.resolveProxyToken(r.Context(), token, proxyScopeSegment)
|
||||
if err != nil {
|
||||
http.Error(w, "invalid segment token", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
h.proxyUpstream(w, r, targetURL, referer)
|
||||
}
|
||||
|
||||
func (h *Handler) HandleProxySubtitle(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -241,13 +237,19 @@ func (h *Handler) HandleProxySubtitle(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
targetURL := r.URL.Query().Get("u")
|
||||
if strings.TrimSpace(targetURL) == "" {
|
||||
http.Error(w, "missing target url", http.StatusBadRequest)
|
||||
token := strings.TrimSpace(r.URL.Query().Get("token"))
|
||||
if token == "" {
|
||||
http.Error(w, "missing subtitle token", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
h.proxyUpstream(w, r, targetURL, r.URL.Query().Get("r"))
|
||||
targetURL, referer, err := h.svc.resolveProxyToken(r.Context(), token, proxyScopeSubtitle)
|
||||
if err != nil {
|
||||
http.Error(w, "invalid subtitle token", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
h.proxyUpstream(w, r, targetURL, referer)
|
||||
}
|
||||
|
||||
func (h *Handler) HandleSaveProgress(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -291,62 +293,15 @@ func (h *Handler) HandleSaveProgress(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
animeID := int64(payload.MalID)
|
||||
|
||||
if _, err := h.svc.db.GetAnime(r.Context(), animeID); err != nil {
|
||||
anime, fetchErr := h.jikanClient.GetAnimeByID(r.Context(), payload.MalID)
|
||||
if fetchErr != nil {
|
||||
log.Printf("save progress failed to fetch anime user_id=%s mal_id=%d err=%v", user.ID, payload.MalID, fetchErr)
|
||||
animeSeed, err := h.ensureAnimeSeed(r.Context(), payload.MalID)
|
||||
if err != nil {
|
||||
log.Printf("save progress failed to resolve anime user_id=%s mal_id=%d err=%v", user.ID, payload.MalID, err)
|
||||
http.Error(w, "failed to save progress", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
if _, upsertErr := h.svc.db.UpsertAnime(r.Context(), database.UpsertAnimeParams{
|
||||
ID: animeID,
|
||||
TitleOriginal: anime.Title,
|
||||
TitleEnglish: sql.NullString{String: anime.TitleEnglish, Valid: anime.TitleEnglish != ""},
|
||||
TitleJapanese: sql.NullString{String: anime.TitleJapanese, Valid: anime.TitleJapanese != ""},
|
||||
ImageUrl: anime.ImageURL(),
|
||||
Airing: sql.NullBool{Bool: anime.Airing, Valid: true},
|
||||
}); upsertErr != nil {
|
||||
log.Printf("save progress failed to upsert anime user_id=%s mal_id=%d err=%v", user.ID, payload.MalID, upsertErr)
|
||||
http.Error(w, "failed to save progress", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
watchListEntry, watchListErr := h.svc.db.GetWatchListEntry(r.Context(), database.GetWatchListEntryParams{
|
||||
UserID: user.ID,
|
||||
AnimeID: animeID,
|
||||
})
|
||||
if watchListErr == nil && watchListEntry.Status == "completed" {
|
||||
if err := h.svc.db.DeleteContinueWatchingEntry(r.Context(), database.DeleteContinueWatchingEntryParams{
|
||||
UserID: user.ID,
|
||||
AnimeID: animeID,
|
||||
}); err != nil {
|
||||
log.Printf("save progress failed to clear continue entry user_id=%s mal_id=%d err=%v", user.ID, payload.MalID, err)
|
||||
}
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.svc.db.SaveWatchProgress(r.Context(), database.SaveWatchProgressParams{
|
||||
CurrentEpisode: sql.NullInt64{Int64: int64(payload.Episode), Valid: true},
|
||||
CurrentTimeSeconds: timeSeconds,
|
||||
UserID: user.ID,
|
||||
AnimeID: animeID,
|
||||
}); err != nil {
|
||||
if err.Error() != "sql: no rows in result set" {
|
||||
log.Printf("save watchlist progress skipped user_id=%s mal_id=%d err=%v", user.ID, payload.MalID, err)
|
||||
}
|
||||
}
|
||||
|
||||
if _, err := h.svc.db.UpsertContinueWatchingEntry(r.Context(), database.UpsertContinueWatchingEntryParams{
|
||||
ID: uuid.New().String(),
|
||||
UserID: user.ID,
|
||||
AnimeID: animeID,
|
||||
CurrentEpisode: sql.NullInt64{Int64: int64(payload.Episode), Valid: true},
|
||||
CurrentTimeSeconds: timeSeconds,
|
||||
}); err != nil {
|
||||
log.Printf("save continue watching failed user_id=%s mal_id=%d err=%v", user.ID, payload.MalID, err)
|
||||
if err := h.svc.SaveProgress(r.Context(), user.ID, animeID, payload.Episode, timeSeconds, animeSeed); err != nil {
|
||||
log.Printf("save progress failed user_id=%s mal_id=%d err=%v", user.ID, payload.MalID, err)
|
||||
http.Error(w, "failed to save progress", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
@@ -383,89 +338,41 @@ func (h *Handler) HandleCompleteAnime(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
animeID := int64(payload.MalID)
|
||||
watchListEntry, watchListErr := h.svc.db.GetWatchListEntry(r.Context(), database.GetWatchListEntryParams{
|
||||
UserID: user.ID,
|
||||
AnimeID: animeID,
|
||||
})
|
||||
if watchListErr != nil && !errors.Is(watchListErr, sql.ErrNoRows) {
|
||||
log.Printf("complete anime failed to load watchlist user_id=%s mal_id=%d err=%v", user.ID, payload.MalID, watchListErr)
|
||||
animeSeed, err := h.ensureAnimeSeed(r.Context(), payload.MalID)
|
||||
if err != nil {
|
||||
log.Printf("complete anime failed to resolve anime user_id=%s mal_id=%d err=%v", user.ID, payload.MalID, err)
|
||||
http.Error(w, "failed to mark anime completed", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
alreadyCompleted := watchListErr == nil && watchListEntry.Status == "completed"
|
||||
|
||||
if !alreadyCompleted {
|
||||
if _, err := h.svc.db.GetAnime(r.Context(), animeID); err != nil {
|
||||
anime, fetchErr := h.jikanClient.GetAnimeByID(r.Context(), payload.MalID)
|
||||
if fetchErr != nil {
|
||||
log.Printf("complete anime failed to fetch anime user_id=%s mal_id=%d err=%v", user.ID, payload.MalID, fetchErr)
|
||||
if err := h.svc.CompleteAnime(r.Context(), user.ID, animeID, payload.Episode, animeSeed); err != nil {
|
||||
log.Printf("complete anime failed user_id=%s mal_id=%d err=%v", user.ID, payload.MalID, err)
|
||||
http.Error(w, "failed to mark anime completed", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
if _, upsertErr := h.svc.db.UpsertAnime(r.Context(), database.UpsertAnimeParams{
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
func (h *Handler) ensureAnimeSeed(ctx context.Context, malID int) (*database.UpsertAnimeParams, error) {
|
||||
animeID := int64(malID)
|
||||
if _, err := h.svc.db.GetAnime(ctx, animeID); err == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
anime, err := h.jikanClient.GetAnimeByID(ctx, malID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &database.UpsertAnimeParams{
|
||||
ID: animeID,
|
||||
TitleOriginal: anime.Title,
|
||||
TitleEnglish: sql.NullString{String: anime.TitleEnglish, Valid: anime.TitleEnglish != ""},
|
||||
TitleJapanese: sql.NullString{String: anime.TitleJapanese, Valid: anime.TitleJapanese != ""},
|
||||
ImageUrl: anime.ImageURL(),
|
||||
Airing: sql.NullBool{Bool: anime.Airing, Valid: true},
|
||||
}); upsertErr != nil {
|
||||
log.Printf("complete anime failed to upsert anime user_id=%s mal_id=%d err=%v", user.ID, payload.MalID, upsertErr)
|
||||
http.Error(w, "failed to mark anime completed", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if _, err := h.svc.db.UpsertWatchListEntry(r.Context(), database.UpsertWatchListEntryParams{
|
||||
ID: uuid.New().String(),
|
||||
UserID: user.ID,
|
||||
AnimeID: animeID,
|
||||
Status: "completed",
|
||||
CurrentEpisode: sql.NullInt64{Int64: int64(payload.Episode), Valid: true},
|
||||
CurrentTimeSeconds: 0,
|
||||
}); err != nil {
|
||||
log.Printf("complete anime failed to upsert watchlist user_id=%s mal_id=%d err=%v", user.ID, payload.MalID, err)
|
||||
http.Error(w, "failed to mark anime completed", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if err := h.svc.db.DeleteContinueWatchingEntry(r.Context(), database.DeleteContinueWatchingEntryParams{
|
||||
UserID: user.ID,
|
||||
AnimeID: animeID,
|
||||
}); err != nil {
|
||||
log.Printf("complete anime failed to delete continue entry user_id=%s mal_id=%d err=%v", user.ID, payload.MalID, err)
|
||||
http.Error(w, "failed to mark anime completed", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
if _, err := h.svc.db.GetContinueWatchingEntry(r.Context(), database.GetContinueWatchingEntryParams{
|
||||
UserID: user.ID,
|
||||
AnimeID: animeID,
|
||||
}); err == nil {
|
||||
log.Printf("complete anime failed to clear continue entry user_id=%s mal_id=%d", user.ID, payload.MalID)
|
||||
http.Error(w, "failed to mark anime completed", http.StatusInternalServerError)
|
||||
return
|
||||
} else if !errors.Is(err, sql.ErrNoRows) {
|
||||
log.Printf("complete anime failed to verify continue clear user_id=%s mal_id=%d err=%v", user.ID, payload.MalID, err)
|
||||
http.Error(w, "failed to mark anime completed", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.svc.db.SaveWatchProgress(r.Context(), database.SaveWatchProgressParams{
|
||||
CurrentEpisode: sql.NullInt64{Int64: int64(payload.Episode), Valid: true},
|
||||
CurrentTimeSeconds: 0,
|
||||
UserID: user.ID,
|
||||
AnimeID: animeID,
|
||||
}); err != nil {
|
||||
if !errors.Is(err, sql.ErrNoRows) {
|
||||
log.Printf("complete anime failed to reset watchlist progress user_id=%s mal_id=%d err=%v", user.ID, payload.MalID, err)
|
||||
}
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (h *Handler) proxyUpstream(w http.ResponseWriter, r *http.Request, targetURL string, referer string) {
|
||||
|
||||
157
internal/features/playback/progress.go
Normal file
157
internal/features/playback/progress.go
Normal file
@@ -0,0 +1,157 @@
|
||||
package playback
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
"mal/internal/database"
|
||||
)
|
||||
|
||||
func (s *Service) SaveProgress(ctx context.Context, userID string, animeID int64, episode int, timeSeconds float64, animeSeed *database.UpsertAnimeParams) error {
|
||||
if strings.TrimSpace(userID) == "" || animeID <= 0 || episode <= 0 {
|
||||
return errors.New("invalid save progress input")
|
||||
}
|
||||
|
||||
txQueries, tx, err := s.beginTx(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
defer tx.Rollback()
|
||||
|
||||
if animeSeed != nil {
|
||||
if _, err := txQueries.UpsertAnime(ctx, *animeSeed); err != nil {
|
||||
return fmt.Errorf("failed to save anime reference: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
watchListEntry, watchListErr := txQueries.GetWatchListEntry(ctx, database.GetWatchListEntryParams{
|
||||
UserID: userID,
|
||||
AnimeID: animeID,
|
||||
})
|
||||
if watchListErr != nil && !errors.Is(watchListErr, sql.ErrNoRows) {
|
||||
return fmt.Errorf("failed to load watchlist entry: %w", watchListErr)
|
||||
}
|
||||
|
||||
if watchListErr == nil && watchListEntry.Status == "completed" {
|
||||
if err := txQueries.DeleteContinueWatchingEntry(ctx, database.DeleteContinueWatchingEntryParams{
|
||||
UserID: userID,
|
||||
AnimeID: animeID,
|
||||
}); err != nil {
|
||||
return fmt.Errorf("failed to clear continue entry: %w", err)
|
||||
}
|
||||
|
||||
if err := tx.Commit(); err != nil {
|
||||
return fmt.Errorf("failed to commit save progress transaction: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := txQueries.SaveWatchProgress(ctx, database.SaveWatchProgressParams{
|
||||
CurrentEpisode: sql.NullInt64{Int64: int64(episode), Valid: true},
|
||||
CurrentTimeSeconds: timeSeconds,
|
||||
UserID: userID,
|
||||
AnimeID: animeID,
|
||||
}); err != nil && !errors.Is(err, sql.ErrNoRows) {
|
||||
return fmt.Errorf("failed to save watchlist progress: %w", err)
|
||||
}
|
||||
|
||||
if _, err := txQueries.UpsertContinueWatchingEntry(ctx, database.UpsertContinueWatchingEntryParams{
|
||||
ID: uuid.New().String(),
|
||||
UserID: userID,
|
||||
AnimeID: animeID,
|
||||
CurrentEpisode: sql.NullInt64{Int64: int64(episode), Valid: true},
|
||||
CurrentTimeSeconds: timeSeconds,
|
||||
}); err != nil {
|
||||
return fmt.Errorf("failed to upsert continue entry: %w", err)
|
||||
}
|
||||
|
||||
if err := tx.Commit(); err != nil {
|
||||
return fmt.Errorf("failed to commit save progress transaction: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Service) CompleteAnime(ctx context.Context, userID string, animeID int64, episode int, animeSeed *database.UpsertAnimeParams) error {
|
||||
if strings.TrimSpace(userID) == "" || animeID <= 0 || episode <= 0 {
|
||||
return errors.New("invalid complete anime input")
|
||||
}
|
||||
|
||||
txQueries, tx, err := s.beginTx(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
defer tx.Rollback()
|
||||
|
||||
watchListEntry, watchListErr := txQueries.GetWatchListEntry(ctx, database.GetWatchListEntryParams{
|
||||
UserID: userID,
|
||||
AnimeID: animeID,
|
||||
})
|
||||
if watchListErr != nil && !errors.Is(watchListErr, sql.ErrNoRows) {
|
||||
return fmt.Errorf("failed to load watchlist entry: %w", watchListErr)
|
||||
}
|
||||
|
||||
alreadyCompleted := watchListErr == nil && watchListEntry.Status == "completed"
|
||||
|
||||
if !alreadyCompleted {
|
||||
if animeSeed != nil {
|
||||
if _, err := txQueries.UpsertAnime(ctx, *animeSeed); err != nil {
|
||||
return fmt.Errorf("failed to save anime reference: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if _, err := txQueries.UpsertWatchListEntry(ctx, database.UpsertWatchListEntryParams{
|
||||
ID: uuid.New().String(),
|
||||
UserID: userID,
|
||||
AnimeID: animeID,
|
||||
Status: "completed",
|
||||
CurrentEpisode: sql.NullInt64{Int64: int64(episode), Valid: true},
|
||||
CurrentTimeSeconds: 0,
|
||||
}); err != nil {
|
||||
return fmt.Errorf("failed to mark watchlist as completed: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := txQueries.DeleteContinueWatchingEntry(ctx, database.DeleteContinueWatchingEntryParams{
|
||||
UserID: userID,
|
||||
AnimeID: animeID,
|
||||
}); err != nil {
|
||||
return fmt.Errorf("failed to clear continue entry: %w", err)
|
||||
}
|
||||
|
||||
if err := txQueries.SaveWatchProgress(ctx, database.SaveWatchProgressParams{
|
||||
CurrentEpisode: sql.NullInt64{Int64: int64(episode), Valid: true},
|
||||
CurrentTimeSeconds: 0,
|
||||
UserID: userID,
|
||||
AnimeID: animeID,
|
||||
}); err != nil && !errors.Is(err, sql.ErrNoRows) {
|
||||
return fmt.Errorf("failed to reset watch progress: %w", err)
|
||||
}
|
||||
|
||||
if err := tx.Commit(); err != nil {
|
||||
return fmt.Errorf("failed to commit complete anime transaction: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Service) beginTx(ctx context.Context) (*database.Queries, *sql.Tx, error) {
|
||||
if s.sqlDB == nil {
|
||||
return nil, nil, errors.New("database unavailable")
|
||||
}
|
||||
|
||||
tx, err := s.sqlDB.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to begin transaction: %w", err)
|
||||
}
|
||||
|
||||
return database.New(tx), tx, nil
|
||||
}
|
||||
348
internal/features/playback/proxy_security.go
Normal file
348
internal/features/playback/proxy_security.go
Normal file
@@ -0,0 +1,348 @@
|
||||
package playback
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"crypto/hmac"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
proxyStreamTokenTTL = 2 * time.Hour
|
||||
proxySegmentTokenTTL = 6 * time.Hour
|
||||
proxySubtitleTokenTTL = 6 * time.Hour
|
||||
)
|
||||
const proxyHostCheckTTL = 2 * time.Minute
|
||||
|
||||
type proxyScope string
|
||||
|
||||
const (
|
||||
proxyScopeStream proxyScope = "stream"
|
||||
proxyScopeSegment proxyScope = "segment"
|
||||
proxyScopeSubtitle proxyScope = "subtitle"
|
||||
)
|
||||
|
||||
type proxyTokenPayload struct {
|
||||
TargetURL string `json:"u"`
|
||||
Referer string `json:"r,omitempty"`
|
||||
Scope string `json:"s"`
|
||||
ExpiresAt int64 `json:"exp"`
|
||||
}
|
||||
|
||||
type proxyTokenSigner struct {
|
||||
secret []byte
|
||||
}
|
||||
|
||||
func newProxyTokenSigner(secret string) (*proxyTokenSigner, error) {
|
||||
trimmed := strings.TrimSpace(secret)
|
||||
if trimmed == "" {
|
||||
return nil, errors.New("proxy token secret is required")
|
||||
}
|
||||
|
||||
if len(trimmed) < 32 {
|
||||
return nil, errors.New("proxy token secret must be at least 32 characters")
|
||||
}
|
||||
|
||||
return &proxyTokenSigner{secret: []byte(trimmed)}, nil
|
||||
}
|
||||
|
||||
func (s *proxyTokenSigner) Sign(payload proxyTokenPayload) (string, error) {
|
||||
body, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("marshal proxy token payload: %w", err)
|
||||
}
|
||||
|
||||
mac := hmac.New(sha256.New, s.secret)
|
||||
mac.Write(body)
|
||||
signature := mac.Sum(nil)
|
||||
|
||||
encodedBody := base64.RawURLEncoding.EncodeToString(body)
|
||||
encodedSignature := base64.RawURLEncoding.EncodeToString(signature)
|
||||
return encodedBody + "." + encodedSignature, nil
|
||||
}
|
||||
|
||||
func (s *proxyTokenSigner) Verify(token string) (proxyTokenPayload, error) {
|
||||
parts := strings.Split(token, ".")
|
||||
if len(parts) != 2 {
|
||||
return proxyTokenPayload{}, errors.New("invalid proxy token format")
|
||||
}
|
||||
|
||||
body, err := base64.RawURLEncoding.DecodeString(parts[0])
|
||||
if err != nil {
|
||||
return proxyTokenPayload{}, errors.New("invalid proxy token payload")
|
||||
}
|
||||
|
||||
signature, err := base64.RawURLEncoding.DecodeString(parts[1])
|
||||
if err != nil {
|
||||
return proxyTokenPayload{}, errors.New("invalid proxy token signature")
|
||||
}
|
||||
|
||||
mac := hmac.New(sha256.New, s.secret)
|
||||
mac.Write(body)
|
||||
expected := mac.Sum(nil)
|
||||
if !hmac.Equal(signature, expected) {
|
||||
return proxyTokenPayload{}, errors.New("invalid proxy token signature")
|
||||
}
|
||||
|
||||
var payload proxyTokenPayload
|
||||
if err := json.Unmarshal(body, &payload); err != nil {
|
||||
return proxyTokenPayload{}, errors.New("invalid proxy token payload")
|
||||
}
|
||||
|
||||
if payload.ExpiresAt <= time.Now().Unix() {
|
||||
return proxyTokenPayload{}, errors.New("proxy token expired")
|
||||
}
|
||||
|
||||
return payload, nil
|
||||
}
|
||||
|
||||
func (s *Service) buildClientModeSources(modeSources map[string]ModeSource) (map[string]ModeSource, error) {
|
||||
clientModeSources := make(map[string]ModeSource, len(modeSources))
|
||||
|
||||
for mode, source := range modeSources {
|
||||
streamToken, err := s.issueProxyToken(source.URL, source.Referer, proxyScopeStream)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
subtitles := make([]SubtitleItem, 0, len(source.Subtitles))
|
||||
for _, subtitle := range source.Subtitles {
|
||||
targetURL := strings.TrimSpace(subtitle.URL)
|
||||
if targetURL == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
token, err := s.issueProxyToken(targetURL, source.Referer, proxyScopeSubtitle)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
subtitles = append(subtitles, SubtitleItem{
|
||||
Lang: subtitle.Lang,
|
||||
Token: token,
|
||||
})
|
||||
}
|
||||
|
||||
clientModeSources[mode] = ModeSource{
|
||||
Token: streamToken,
|
||||
Subtitles: subtitles,
|
||||
}
|
||||
}
|
||||
|
||||
return clientModeSources, nil
|
||||
}
|
||||
|
||||
func (s *Service) issueProxyToken(targetURL string, referer string, scope proxyScope) (string, error) {
|
||||
normalizedTarget, err := normalizeProxyURL(targetURL)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
normalizedReferer := ""
|
||||
if strings.TrimSpace(referer) != "" {
|
||||
refererURL, refererErr := normalizeProxyURL(referer)
|
||||
if refererErr == nil {
|
||||
normalizedReferer = refererURL
|
||||
}
|
||||
}
|
||||
|
||||
return s.proxyTokens.Sign(proxyTokenPayload{
|
||||
TargetURL: normalizedTarget,
|
||||
Referer: normalizedReferer,
|
||||
Scope: string(scope),
|
||||
ExpiresAt: time.Now().Add(proxyTokenTTL(scope)).Unix(),
|
||||
})
|
||||
}
|
||||
|
||||
func proxyTokenTTL(scope proxyScope) time.Duration {
|
||||
switch scope {
|
||||
case proxyScopeStream:
|
||||
return proxyStreamTokenTTL
|
||||
case proxyScopeSegment:
|
||||
return proxySegmentTokenTTL
|
||||
case proxyScopeSubtitle:
|
||||
return proxySubtitleTokenTTL
|
||||
default:
|
||||
return proxyStreamTokenTTL
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) resolveProxyToken(ctx context.Context, token string, scope proxyScope) (string, string, error) {
|
||||
payload, err := s.proxyTokens.Verify(token)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
if payload.Scope != string(scope) {
|
||||
return "", "", errors.New("proxy token scope mismatch")
|
||||
}
|
||||
|
||||
normalizedTarget, err := normalizeProxyURL(payload.TargetURL)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
if err := s.ensurePublicProxyTarget(ctx, normalizedTarget); err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
normalizedReferer := ""
|
||||
if strings.TrimSpace(payload.Referer) != "" {
|
||||
refererURL, refererErr := normalizeProxyURL(payload.Referer)
|
||||
if refererErr == nil {
|
||||
if ensureErr := s.ensurePublicProxyTarget(ctx, refererURL); ensureErr == nil {
|
||||
normalizedReferer = refererURL
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return normalizedTarget, normalizedReferer, nil
|
||||
}
|
||||
|
||||
func normalizeProxyURL(rawURL string) (string, error) {
|
||||
parsed, err := url.Parse(strings.TrimSpace(rawURL))
|
||||
if err != nil {
|
||||
return "", errors.New("invalid proxy target")
|
||||
}
|
||||
|
||||
if parsed.Scheme != "http" && parsed.Scheme != "https" {
|
||||
return "", errors.New("invalid proxy target scheme")
|
||||
}
|
||||
|
||||
host := strings.ToLower(strings.TrimSpace(parsed.Hostname()))
|
||||
if host == "" {
|
||||
return "", errors.New("invalid proxy target host")
|
||||
}
|
||||
|
||||
if host == "localhost" || strings.HasSuffix(host, ".localhost") || strings.HasSuffix(host, ".local") {
|
||||
return "", errors.New("localhost targets are not allowed")
|
||||
}
|
||||
|
||||
ip := net.ParseIP(host)
|
||||
if ip != nil && isBlockedProxyIP(ip) {
|
||||
return "", errors.New("private proxy targets are not allowed")
|
||||
}
|
||||
|
||||
return parsed.String(), nil
|
||||
}
|
||||
|
||||
func isBlockedProxyIP(ip net.IP) bool {
|
||||
return ip.IsLoopback() ||
|
||||
ip.IsPrivate() ||
|
||||
ip.IsMulticast() ||
|
||||
ip.IsLinkLocalMulticast() ||
|
||||
ip.IsLinkLocalUnicast() ||
|
||||
ip.IsUnspecified()
|
||||
}
|
||||
|
||||
func (s *Service) ensurePublicProxyTarget(ctx context.Context, rawURL string) error {
|
||||
parsed, err := url.Parse(rawURL)
|
||||
if err != nil {
|
||||
return errors.New("invalid proxy target")
|
||||
}
|
||||
|
||||
host := strings.TrimSpace(parsed.Hostname())
|
||||
if host == "" {
|
||||
return errors.New("invalid proxy target host")
|
||||
}
|
||||
|
||||
if ip := net.ParseIP(host); ip != nil {
|
||||
if isBlockedProxyIP(ip) {
|
||||
return errors.New("private proxy targets are not allowed")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
s.proxyHostMu.RLock()
|
||||
cached, ok := s.proxyHostCache[host]
|
||||
s.proxyHostMu.RUnlock()
|
||||
if ok && now.Before(cached.ExpiresAt) {
|
||||
if cached.Allowed {
|
||||
return nil
|
||||
}
|
||||
return errors.New("private proxy targets are not allowed")
|
||||
}
|
||||
|
||||
resolvedIPs, err := net.DefaultResolver.LookupIPAddr(ctx, host)
|
||||
if err != nil || len(resolvedIPs) == 0 {
|
||||
return errors.New("proxy target lookup failed")
|
||||
}
|
||||
|
||||
allowed := true
|
||||
for _, resolved := range resolvedIPs {
|
||||
if isBlockedProxyIP(resolved.IP) {
|
||||
allowed = false
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
s.proxyHostMu.Lock()
|
||||
s.proxyHostCache[host] = proxyHostCacheItem{
|
||||
Allowed: allowed,
|
||||
ExpiresAt: now.Add(proxyHostCheckTTL),
|
||||
}
|
||||
s.proxyHostMu.Unlock()
|
||||
|
||||
if !allowed {
|
||||
return errors.New("private proxy targets are not allowed")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Service) rewritePlaylistWithTokens(ctx context.Context, content string, baseURL string, referer string) (string, error) {
|
||||
base, err := url.Parse(baseURL)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
var out strings.Builder
|
||||
scanner := bufio.NewScanner(strings.NewReader(content))
|
||||
for scanner.Scan() {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return "", ctx.Err()
|
||||
default:
|
||||
}
|
||||
|
||||
line := scanner.Text()
|
||||
trimmed := strings.TrimSpace(line)
|
||||
if trimmed == "" || strings.HasPrefix(trimmed, "#") {
|
||||
out.WriteString(line)
|
||||
out.WriteString("\n")
|
||||
continue
|
||||
}
|
||||
|
||||
relativeURL, parseErr := url.Parse(trimmed)
|
||||
if parseErr != nil {
|
||||
out.WriteString(line)
|
||||
out.WriteString("\n")
|
||||
continue
|
||||
}
|
||||
|
||||
absoluteURL := base.ResolveReference(relativeURL).String()
|
||||
token, tokenErr := s.issueProxyToken(absoluteURL, referer, proxyScopeSegment)
|
||||
if tokenErr != nil {
|
||||
return "", tokenErr
|
||||
}
|
||||
|
||||
proxied := "/watch/proxy/segment?token=" + url.QueryEscape(token)
|
||||
out.WriteString(proxied)
|
||||
out.WriteString("\n")
|
||||
}
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return out.String(), nil
|
||||
}
|
||||
45
internal/features/playback/proxy_security_test.go
Normal file
45
internal/features/playback/proxy_security_test.go
Normal file
@@ -0,0 +1,45 @@
|
||||
package playback
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"mal/internal/database"
|
||||
)
|
||||
|
||||
func TestNormalizeProxyURLRejectsLocalhost(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
_, err := normalizeProxyURL("http://localhost:8080/private")
|
||||
if err == nil {
|
||||
t.Fatal("expected localhost URL to be rejected")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeProxyURLRejectsPrivateIP(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
_, err := normalizeProxyURL("http://192.168.1.10/stream")
|
||||
if err == nil {
|
||||
t.Fatal("expected private IP URL to be rejected")
|
||||
}
|
||||
}
|
||||
|
||||
func TestProxyTokenScopeValidation(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
service := NewService(&fakeProxyQuerier{}, nil, Config{ProxyTokenSecret: "0123456789abcdef0123456789abcdef"})
|
||||
token, err := service.issueProxyToken("https://example.com/playlist.m3u8", "", proxyScopeStream)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to issue token: %v", err)
|
||||
}
|
||||
|
||||
_, _, err = service.resolveProxyToken(context.Background(), token, proxyScopeSegment)
|
||||
if err == nil {
|
||||
t.Fatal("expected scope mismatch error")
|
||||
}
|
||||
}
|
||||
|
||||
type fakeProxyQuerier struct {
|
||||
database.Querier
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
391
internal/features/playback/service_base.go
Normal file
391
internal/features/playback/service_base.go
Normal file
@@ -0,0 +1,391 @@
|
||||
package playback
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"mal/internal/database"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
showResolutionCacheTTL = 12 * time.Hour
|
||||
playbackDataCacheTTL = 2 * time.Minute
|
||||
providerProbeTimeout = 3 * time.Second
|
||||
)
|
||||
|
||||
type Service struct {
|
||||
allAnimeClient *allAnimeClient
|
||||
httpClient *http.Client
|
||||
sqlDB *sql.DB
|
||||
db database.Querier
|
||||
proxyTokens *proxyTokenSigner
|
||||
proxyHostMu sync.RWMutex
|
||||
proxyHostCache map[string]proxyHostCacheItem
|
||||
|
||||
cacheMu sync.RWMutex
|
||||
showResolution map[int]showResolutionCacheItem
|
||||
playbackDataCache map[string]playbackDataCacheItem
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
ProxyTokenSecret string
|
||||
}
|
||||
|
||||
type sourceScore struct {
|
||||
source StreamSource
|
||||
total int
|
||||
typeScore int
|
||||
providerScore int
|
||||
qualityScore int
|
||||
refererScore int
|
||||
}
|
||||
|
||||
type showResolutionCacheItem struct {
|
||||
ShowID string
|
||||
Title string
|
||||
ExpiresAt time.Time
|
||||
}
|
||||
|
||||
type playbackDataCacheItem struct {
|
||||
Data playbackBaseData
|
||||
ExpiresAt time.Time
|
||||
}
|
||||
|
||||
type playbackBaseData struct {
|
||||
Title string
|
||||
AvailableModes []string
|
||||
ModeSources map[string]ModeSource
|
||||
Segments []SkipSegment
|
||||
}
|
||||
|
||||
type modeSourceResult struct {
|
||||
Mode string
|
||||
Source ModeSource
|
||||
OK bool
|
||||
}
|
||||
|
||||
type searchModeResult struct {
|
||||
Mode string
|
||||
Results []searchResult
|
||||
Err error
|
||||
}
|
||||
|
||||
type directProbeResult struct {
|
||||
Playable bool
|
||||
ContentType string
|
||||
}
|
||||
|
||||
type proxyHostCacheItem struct {
|
||||
Allowed bool
|
||||
ExpiresAt time.Time
|
||||
}
|
||||
|
||||
type userPlaybackState struct {
|
||||
CurrentStatus string
|
||||
StartTimeSeconds float64
|
||||
}
|
||||
|
||||
func NewService(db database.Querier, sqlDB *sql.DB, cfg Config) *Service {
|
||||
proxyTokens, err := newProxyTokenSigner(cfg.ProxyTokenSecret)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("failed to initialize proxy token signer: %v", err))
|
||||
}
|
||||
|
||||
return &Service{
|
||||
allAnimeClient: newAllAnimeClient(),
|
||||
httpClient: &http.Client{Timeout: 12 * time.Second},
|
||||
sqlDB: sqlDB,
|
||||
db: db,
|
||||
proxyTokens: proxyTokens,
|
||||
proxyHostCache: make(map[string]proxyHostCacheItem),
|
||||
showResolution: make(map[int]showResolutionCacheItem),
|
||||
playbackDataCache: make(map[string]playbackDataCacheItem),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) BuildWatchPageData(ctx context.Context, malID int, titleCandidates []string, episode string, mode string, userID string) (WatchPageData, error) {
|
||||
if malID <= 0 {
|
||||
return WatchPageData{}, errors.New("invalid mal id")
|
||||
}
|
||||
|
||||
normalizedMode := normalizeMode(mode)
|
||||
if normalizedMode == "" {
|
||||
normalizedMode = "dub"
|
||||
}
|
||||
|
||||
normalizedEpisode := strings.TrimSpace(episode)
|
||||
if normalizedEpisode == "" {
|
||||
normalizedEpisode = "1"
|
||||
}
|
||||
|
||||
userStateCh := s.fetchUserPlaybackStateAsync(ctx, userID, malID, normalizedEpisode)
|
||||
|
||||
cacheKey := playbackDataCacheKey(malID, normalizedEpisode)
|
||||
baseData, cacheHit := s.getPlaybackBaseDataCache(cacheKey)
|
||||
if !cacheHit {
|
||||
showID, resolvedTitle, err := s.resolveShowCached(ctx, malID, titleCandidates)
|
||||
if err != nil {
|
||||
return WatchPageData{}, err
|
||||
}
|
||||
|
||||
modeSources, segments := s.fetchPlaybackSourcesAndSegments(ctx, showID, malID, normalizedEpisode)
|
||||
if len(modeSources) == 0 {
|
||||
return WatchPageData{}, errors.New("no direct playable sources available")
|
||||
}
|
||||
|
||||
watchTitle := strings.TrimSpace(resolvedTitle)
|
||||
if watchTitle == "" {
|
||||
watchTitle = firstNonEmptyTitle(titleCandidates)
|
||||
}
|
||||
if watchTitle == "" {
|
||||
watchTitle = fmt.Sprintf("MAL #%d", malID)
|
||||
}
|
||||
|
||||
baseData = playbackBaseData{
|
||||
Title: watchTitle,
|
||||
AvailableModes: availableModes(modeSources),
|
||||
ModeSources: modeSources,
|
||||
Segments: segments,
|
||||
}
|
||||
|
||||
s.setPlaybackBaseDataCache(cacheKey, baseData)
|
||||
}
|
||||
|
||||
initialMode := selectInitialMode(normalizedMode, baseData.ModeSources)
|
||||
|
||||
clientModeSources, err := s.buildClientModeSources(baseData.ModeSources)
|
||||
if err != nil {
|
||||
return WatchPageData{}, err
|
||||
}
|
||||
|
||||
if _, ok := clientModeSources[initialMode]; !ok {
|
||||
return WatchPageData{}, errors.New("stream mode unavailable")
|
||||
}
|
||||
|
||||
userState := userPlaybackState{}
|
||||
if userStateCh != nil {
|
||||
userState = <-userStateCh
|
||||
}
|
||||
|
||||
return WatchPageData{
|
||||
MalID: malID,
|
||||
Title: baseData.Title,
|
||||
CurrentEpisode: normalizedEpisode,
|
||||
StartTimeSeconds: userState.StartTimeSeconds,
|
||||
CurrentStatus: userState.CurrentStatus,
|
||||
InitialMode: initialMode,
|
||||
AvailableModes: cloneStringSlice(baseData.AvailableModes),
|
||||
ModeSources: clientModeSources,
|
||||
Segments: cloneSegments(baseData.Segments),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func playbackDataCacheKey(malID int, episode string) string {
|
||||
return fmt.Sprintf("%d:%s", malID, episode)
|
||||
}
|
||||
|
||||
func (s *Service) fetchUserPlaybackStateAsync(ctx context.Context, userID string, malID int, episode string) <-chan userPlaybackState {
|
||||
if userID == "" || s.db == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
resultCh := make(chan userPlaybackState, 1)
|
||||
go func() {
|
||||
state := userPlaybackState{}
|
||||
|
||||
entry, err := s.db.GetWatchListEntry(ctx, database.GetWatchListEntryParams{
|
||||
UserID: userID,
|
||||
AnimeID: int64(malID),
|
||||
})
|
||||
if err == nil {
|
||||
state.CurrentStatus = entry.Status
|
||||
if entry.CurrentEpisode.Valid && strconv.FormatInt(entry.CurrentEpisode.Int64, 10) == episode && entry.CurrentTimeSeconds > 0 {
|
||||
state.StartTimeSeconds = entry.CurrentTimeSeconds
|
||||
}
|
||||
}
|
||||
|
||||
if state.StartTimeSeconds <= 0 {
|
||||
continueEntry, continueErr := s.db.GetContinueWatchingEntry(ctx, database.GetContinueWatchingEntryParams{
|
||||
UserID: userID,
|
||||
AnimeID: int64(malID),
|
||||
})
|
||||
if continueErr == nil && continueEntry.CurrentEpisode.Valid && strconv.FormatInt(continueEntry.CurrentEpisode.Int64, 10) == episode && continueEntry.CurrentTimeSeconds > 0 {
|
||||
state.StartTimeSeconds = continueEntry.CurrentTimeSeconds
|
||||
}
|
||||
}
|
||||
|
||||
resultCh <- state
|
||||
}()
|
||||
|
||||
return resultCh
|
||||
}
|
||||
|
||||
func (s *Service) getPlaybackBaseDataCache(key string) (playbackBaseData, bool) {
|
||||
now := time.Now()
|
||||
|
||||
s.cacheMu.RLock()
|
||||
item, ok := s.playbackDataCache[key]
|
||||
s.cacheMu.RUnlock()
|
||||
if !ok {
|
||||
return playbackBaseData{}, false
|
||||
}
|
||||
|
||||
if now.After(item.ExpiresAt) {
|
||||
s.cacheMu.Lock()
|
||||
current, exists := s.playbackDataCache[key]
|
||||
if exists && time.Now().After(current.ExpiresAt) {
|
||||
delete(s.playbackDataCache, key)
|
||||
}
|
||||
s.cacheMu.Unlock()
|
||||
return playbackBaseData{}, false
|
||||
}
|
||||
|
||||
return clonePlaybackBaseData(item.Data), true
|
||||
}
|
||||
|
||||
func (s *Service) setPlaybackBaseDataCache(key string, data playbackBaseData) {
|
||||
s.cacheMu.Lock()
|
||||
s.playbackDataCache[key] = playbackDataCacheItem{
|
||||
Data: clonePlaybackBaseData(data),
|
||||
ExpiresAt: time.Now().Add(playbackDataCacheTTL),
|
||||
}
|
||||
s.cacheMu.Unlock()
|
||||
}
|
||||
|
||||
func (s *Service) resolveShowCached(ctx context.Context, malID int, titleCandidates []string) (string, string, error) {
|
||||
now := time.Now()
|
||||
|
||||
s.cacheMu.RLock()
|
||||
item, ok := s.showResolution[malID]
|
||||
s.cacheMu.RUnlock()
|
||||
|
||||
if ok && now.Before(item.ExpiresAt) && strings.TrimSpace(item.ShowID) != "" {
|
||||
return item.ShowID, item.Title, nil
|
||||
}
|
||||
|
||||
showID, resolvedTitle, err := s.resolveShow(ctx, malID, titleCandidates)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
s.cacheMu.Lock()
|
||||
s.showResolution[malID] = showResolutionCacheItem{
|
||||
ShowID: showID,
|
||||
Title: resolvedTitle,
|
||||
ExpiresAt: time.Now().Add(showResolutionCacheTTL),
|
||||
}
|
||||
s.cacheMu.Unlock()
|
||||
|
||||
return showID, resolvedTitle, nil
|
||||
}
|
||||
|
||||
func (s *Service) fetchPlaybackSourcesAndSegments(ctx context.Context, showID string, malID int, episode string) (map[string]ModeSource, []SkipSegment) {
|
||||
modeCh := make(chan modeSourceResult, 2)
|
||||
probeCache := make(map[string]directProbeResult)
|
||||
probeCacheMu := sync.Mutex{}
|
||||
|
||||
for _, mode := range []string{"dub", "sub"} {
|
||||
modeValue := mode
|
||||
go func() {
|
||||
resolved, err := s.resolveModeSourceWithCache(ctx, showID, episode, modeValue, "best", probeCache, &probeCacheMu)
|
||||
if err != nil {
|
||||
modeCh <- modeSourceResult{Mode: modeValue, OK: false}
|
||||
return
|
||||
}
|
||||
|
||||
if strings.ToLower(resolved.Type) == "embed" {
|
||||
modeCh <- modeSourceResult{Mode: modeValue, OK: false}
|
||||
return
|
||||
}
|
||||
|
||||
modeCh <- modeSourceResult{
|
||||
Mode: modeValue,
|
||||
Source: ModeSource{
|
||||
URL: resolved.URL,
|
||||
Referer: resolved.Referer,
|
||||
Subtitles: toSubtitleItems(resolved),
|
||||
},
|
||||
OK: true,
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
segmentsCh := make(chan []SkipSegment, 1)
|
||||
go func() {
|
||||
segmentsCh <- s.fetchSkipSegments(ctx, malID, episode)
|
||||
}()
|
||||
|
||||
modeSources := make(map[string]ModeSource)
|
||||
for range 2 {
|
||||
result := <-modeCh
|
||||
if !result.OK {
|
||||
continue
|
||||
}
|
||||
modeSources[result.Mode] = result.Source
|
||||
}
|
||||
|
||||
segments := <-segmentsCh
|
||||
return modeSources, segments
|
||||
}
|
||||
|
||||
func clonePlaybackBaseData(data playbackBaseData) playbackBaseData {
|
||||
return playbackBaseData{
|
||||
Title: data.Title,
|
||||
AvailableModes: cloneStringSlice(data.AvailableModes),
|
||||
ModeSources: cloneModeSources(data.ModeSources),
|
||||
Segments: cloneSegments(data.Segments),
|
||||
}
|
||||
}
|
||||
|
||||
func cloneStringSlice(items []string) []string {
|
||||
if len(items) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
cloned := make([]string, len(items))
|
||||
copy(cloned, items)
|
||||
return cloned
|
||||
}
|
||||
|
||||
func cloneModeSources(modeSources map[string]ModeSource) map[string]ModeSource {
|
||||
if len(modeSources) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
cloned := make(map[string]ModeSource, len(modeSources))
|
||||
for mode, source := range modeSources {
|
||||
cloned[mode] = ModeSource{
|
||||
URL: source.URL,
|
||||
Referer: source.Referer,
|
||||
Subtitles: cloneSubtitleItems(source.Subtitles),
|
||||
}
|
||||
}
|
||||
|
||||
return cloned
|
||||
}
|
||||
|
||||
func cloneSubtitleItems(items []SubtitleItem) []SubtitleItem {
|
||||
if len(items) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
cloned := make([]SubtitleItem, len(items))
|
||||
copy(cloned, items)
|
||||
return cloned
|
||||
}
|
||||
|
||||
func cloneSegments(segments []SkipSegment) []SkipSegment {
|
||||
if len(segments) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
cloned := make([]SkipSegment, len(segments))
|
||||
copy(cloned, segments)
|
||||
return cloned
|
||||
}
|
||||
77
internal/features/playback/service_http.go
Normal file
77
internal/features/playback/service_http.go
Normal file
@@ -0,0 +1,77 @@
|
||||
package playback
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func (s *Service) fetchSkipSegments(ctx context.Context, malID int, episode string) []SkipSegment {
|
||||
if malID <= 0 || strings.TrimSpace(episode) == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
endpoint := fmt.Sprintf("https://api.aniskip.com/v1/skip-times/%s/%s?types=op&types=ed", url.PathEscape(strconv.Itoa(malID)), url.PathEscape(episode))
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
req.Header.Set("User-Agent", defaultUserAgent)
|
||||
|
||||
resp, err := s.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(io.LimitReader(resp.Body, 512*1024))
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
type resultItem struct {
|
||||
SkipType string `json:"skip_type"`
|
||||
Interval struct {
|
||||
StartTime float64 `json:"start_time"`
|
||||
EndTime float64 `json:"end_time"`
|
||||
} `json:"interval"`
|
||||
}
|
||||
type apiResponse struct {
|
||||
Found bool `json:"found"`
|
||||
Result []resultItem `json:"results"`
|
||||
}
|
||||
|
||||
var parsed apiResponse
|
||||
if err := json.Unmarshal(body, &parsed); err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
segments := make([]SkipSegment, 0, len(parsed.Result))
|
||||
for _, item := range parsed.Result {
|
||||
if item.Interval.EndTime <= item.Interval.StartTime {
|
||||
continue
|
||||
}
|
||||
|
||||
t := strings.ToLower(item.SkipType)
|
||||
if t != "op" && t != "ed" {
|
||||
continue
|
||||
}
|
||||
|
||||
segments = append(segments, SkipSegment{
|
||||
Type: t,
|
||||
Start: item.Interval.StartTime,
|
||||
End: item.Interval.EndTime,
|
||||
})
|
||||
}
|
||||
|
||||
return segments
|
||||
}
|
||||
75
internal/features/playback/service_proxy.go
Normal file
75
internal/features/playback/service_proxy.go
Normal file
@@ -0,0 +1,75 @@
|
||||
package playback
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func (s *Service) ProxyStream(ctx context.Context, targetURL string, referer string, rangeHeader string) (int, http.Header, []byte, io.ReadCloser, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, targetURL, nil)
|
||||
if err != nil {
|
||||
return 0, nil, nil, nil, fmt.Errorf("invalid upstream url: %w", err)
|
||||
}
|
||||
|
||||
if referer != "" {
|
||||
req.Header.Set("Referer", referer)
|
||||
}
|
||||
req.Header.Set("User-Agent", defaultUserAgent)
|
||||
if rangeHeader != "" {
|
||||
req.Header.Set("Range", rangeHeader)
|
||||
}
|
||||
|
||||
resp, err := s.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return 0, nil, nil, nil, fmt.Errorf("upstream request failed: %w", err)
|
||||
}
|
||||
|
||||
if isM3U8(targetURL, resp.Header.Get("Content-Type")) {
|
||||
defer resp.Body.Close()
|
||||
body, readErr := io.ReadAll(io.LimitReader(resp.Body, 2*1024*1024))
|
||||
if readErr != nil {
|
||||
return 0, nil, nil, nil, fmt.Errorf("read playlist failed: %w", readErr)
|
||||
}
|
||||
|
||||
rewritten, rewriteErr := s.rewritePlaylistWithTokens(ctx, string(body), targetURL, referer)
|
||||
if rewriteErr != nil {
|
||||
return 0, nil, nil, nil, fmt.Errorf("rewrite playlist failed: %w", rewriteErr)
|
||||
}
|
||||
|
||||
headers := cloneHeaders(resp.Header)
|
||||
headers.Set("Content-Type", "application/vnd.apple.mpegurl")
|
||||
return resp.StatusCode, headers, []byte(rewritten), nil, nil
|
||||
}
|
||||
|
||||
headers := cloneHeaders(resp.Header)
|
||||
return resp.StatusCode, headers, nil, resp.Body, nil
|
||||
}
|
||||
|
||||
func isM3U8(targetURL string, contentType string) bool {
|
||||
lowerURL := strings.ToLower(targetURL)
|
||||
lowerType := strings.ToLower(contentType)
|
||||
if strings.Contains(lowerURL, ".m3u8") {
|
||||
return true
|
||||
}
|
||||
|
||||
return strings.Contains(lowerType, "application/vnd.apple.mpegurl") || strings.Contains(lowerType, "application/x-mpegurl")
|
||||
}
|
||||
|
||||
func cloneHeaders(src http.Header) http.Header {
|
||||
dst := make(http.Header)
|
||||
for key, values := range src {
|
||||
lower := strings.ToLower(key)
|
||||
if lower == "connection" || lower == "transfer-encoding" || lower == "keep-alive" || lower == "proxy-authenticate" || lower == "proxy-authorization" || lower == "te" || lower == "trailers" || lower == "upgrade" {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, value := range values {
|
||||
dst.Add(key, value)
|
||||
}
|
||||
}
|
||||
|
||||
return dst
|
||||
}
|
||||
216
internal/features/playback/service_ranking.go
Normal file
216
internal/features/playback/service_ranking.go
Normal file
@@ -0,0 +1,216 @@
|
||||
package playback
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func rankSources(sources []StreamSource, quality string) ([]sourceScore, error) {
|
||||
filtered := make([]StreamSource, 0, len(sources))
|
||||
seen := make(map[string]struct{})
|
||||
|
||||
for _, source := range sources {
|
||||
if source.URL == "" {
|
||||
continue
|
||||
}
|
||||
if _, exists := seen[source.URL]; exists {
|
||||
continue
|
||||
}
|
||||
seen[source.URL] = struct{}{}
|
||||
filtered = append(filtered, source)
|
||||
}
|
||||
|
||||
if len(filtered) == 0 {
|
||||
return nil, errors.New("no playable sources available")
|
||||
}
|
||||
|
||||
targetQuality := normalizeQuality(quality)
|
||||
scored := make([]sourceScore, 0, len(filtered))
|
||||
for _, source := range filtered {
|
||||
typeScore := sourceTypePriority(source.Type)
|
||||
providerScore := providerPriority(source.Provider)
|
||||
qualityScore := sourceQualityPriority(source.Quality, targetQuality)
|
||||
refererScore := 0
|
||||
if source.Referer != "" {
|
||||
refererScore = 20
|
||||
}
|
||||
|
||||
total := typeScore + providerScore + qualityScore + refererScore
|
||||
scored = append(scored, sourceScore{
|
||||
source: source,
|
||||
total: total,
|
||||
typeScore: typeScore,
|
||||
providerScore: providerScore,
|
||||
qualityScore: qualityScore,
|
||||
refererScore: refererScore,
|
||||
})
|
||||
}
|
||||
|
||||
sort.SliceStable(scored, func(i int, j int) bool {
|
||||
return scored[i].total > scored[j].total
|
||||
})
|
||||
|
||||
return scored, nil
|
||||
}
|
||||
|
||||
func normalizeQuality(quality string) string {
|
||||
lower := strings.ToLower(strings.TrimSpace(quality))
|
||||
if lower == "" {
|
||||
return "best"
|
||||
}
|
||||
|
||||
return lower
|
||||
}
|
||||
|
||||
func sourceTypePriority(sourceType string) int {
|
||||
switch strings.ToLower(sourceType) {
|
||||
case "mp4":
|
||||
return 500
|
||||
case "m3u8":
|
||||
return 450
|
||||
case "unknown":
|
||||
return 300
|
||||
case "embed":
|
||||
return 100
|
||||
default:
|
||||
return 200
|
||||
}
|
||||
}
|
||||
|
||||
func providerPriority(provider string) int {
|
||||
switch strings.ToLower(provider) {
|
||||
case "s-mp4":
|
||||
return 120
|
||||
case "default":
|
||||
return 115
|
||||
case "luf-mp4":
|
||||
return 110
|
||||
case "vid-mp4":
|
||||
return 105
|
||||
case "yt-mp4":
|
||||
return 100
|
||||
case "mp4":
|
||||
return 95
|
||||
case "uv-mp4":
|
||||
return 90
|
||||
case "hls":
|
||||
return 80
|
||||
case "sw":
|
||||
return 40
|
||||
case "ok":
|
||||
return 35
|
||||
case "ss-hls":
|
||||
return 30
|
||||
default:
|
||||
return 60
|
||||
}
|
||||
}
|
||||
|
||||
func sourceQualityPriority(sourceQuality string, targetQuality string) int {
|
||||
qualityValue := parseQualityValue(sourceQuality)
|
||||
|
||||
switch targetQuality {
|
||||
case "best":
|
||||
return qualityValue
|
||||
case "worst":
|
||||
return -qualityValue
|
||||
default:
|
||||
if qualityMatches(sourceQuality, targetQuality) {
|
||||
return 2000 + qualityValue
|
||||
}
|
||||
|
||||
return -300 + qualityValue
|
||||
}
|
||||
}
|
||||
|
||||
func parseQualityValue(rawQuality string) int {
|
||||
lower := strings.ToLower(rawQuality)
|
||||
var digits strings.Builder
|
||||
|
||||
for _, char := range lower {
|
||||
if char >= '0' && char <= '9' {
|
||||
digits.WriteRune(char)
|
||||
continue
|
||||
}
|
||||
if digits.Len() > 0 {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if digits.Len() > 0 {
|
||||
value, err := strconv.Atoi(digits.String())
|
||||
if err == nil {
|
||||
return value
|
||||
}
|
||||
}
|
||||
|
||||
if lower == "auto" {
|
||||
return 240
|
||||
}
|
||||
|
||||
return 0
|
||||
}
|
||||
|
||||
func qualityMatches(sourceQuality string, targetQuality string) bool {
|
||||
sourceLower := strings.ToLower(sourceQuality)
|
||||
targetLower := strings.ToLower(targetQuality)
|
||||
|
||||
if sourceLower == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
if strings.Contains(sourceLower, targetLower) {
|
||||
return true
|
||||
}
|
||||
|
||||
sourceDigits := extractDigits(sourceLower)
|
||||
targetDigits := extractDigits(targetLower)
|
||||
|
||||
return sourceDigits != "" && sourceDigits == targetDigits
|
||||
}
|
||||
|
||||
func extractDigits(value string) string {
|
||||
var digits strings.Builder
|
||||
for _, char := range value {
|
||||
if char >= '0' && char <= '9' {
|
||||
digits.WriteRune(char)
|
||||
continue
|
||||
}
|
||||
if digits.Len() > 0 {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return digits.String()
|
||||
}
|
||||
|
||||
func normalizeSourceTypeFromProbe(source StreamSource, contentType string) StreamSource {
|
||||
lower := strings.ToLower(contentType)
|
||||
if strings.Contains(lower, "video/mp4") {
|
||||
source.Type = "mp4"
|
||||
return source
|
||||
}
|
||||
|
||||
if strings.Contains(lower, "mpegurl") {
|
||||
source.Type = "m3u8"
|
||||
return source
|
||||
}
|
||||
|
||||
return source
|
||||
}
|
||||
|
||||
func isLikelyMP4(payload []byte) bool {
|
||||
if len(payload) < 12 {
|
||||
return false
|
||||
}
|
||||
|
||||
return bytes.Equal(payload[4:8], []byte("ftyp"))
|
||||
}
|
||||
|
||||
func isLikelyM3U8(payload []byte) bool {
|
||||
trimmed := strings.TrimSpace(string(payload))
|
||||
return strings.HasPrefix(trimmed, "#EXTM3U")
|
||||
}
|
||||
174
internal/features/playback/service_resolution.go
Normal file
174
internal/features/playback/service_resolution.go
Normal file
@@ -0,0 +1,174 @@
|
||||
package playback
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
func (s *Service) resolveShow(ctx context.Context, malID int, titleCandidates []string) (string, string, error) {
|
||||
malText := strconv.Itoa(malID)
|
||||
modeCandidates := []string{"sub", "dub"}
|
||||
queries := buildTitleSearchQueries(titleCandidates)
|
||||
|
||||
for _, query := range queries {
|
||||
resultsByMode := s.searchShowResultsByMode(ctx, query, modeCandidates)
|
||||
|
||||
for _, mode := range modeCandidates {
|
||||
for _, result := range resultsByMode[mode] {
|
||||
if strings.TrimSpace(result.MalID) == malText && strings.TrimSpace(result.ID) != "" {
|
||||
return result.ID, result.Name, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, mode := range modeCandidates {
|
||||
results := resultsByMode[mode]
|
||||
if len(results) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
best := results[0]
|
||||
if strings.TrimSpace(best.ID) != "" {
|
||||
return best.ID, best.Name, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return "", "", errors.New("unable to resolve allanime show")
|
||||
}
|
||||
|
||||
func (s *Service) searchShowResultsByMode(ctx context.Context, query string, modeCandidates []string) map[string][]searchResult {
|
||||
resultsByMode := make(map[string][]searchResult, len(modeCandidates))
|
||||
searchCh := make(chan searchModeResult, len(modeCandidates))
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for _, mode := range modeCandidates {
|
||||
modeValue := mode
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
results, err := s.allAnimeClient.Search(ctx, query, modeValue)
|
||||
searchCh <- searchModeResult{Mode: modeValue, Results: results, Err: err}
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
close(searchCh)
|
||||
|
||||
for result := range searchCh {
|
||||
if result.Err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
resultsByMode[result.Mode] = result.Results
|
||||
}
|
||||
|
||||
return resultsByMode
|
||||
}
|
||||
|
||||
func buildTitleSearchQueries(titleCandidates []string) []string {
|
||||
queries := make([]string, 0, len(titleCandidates)*4)
|
||||
seen := make(map[string]struct{})
|
||||
|
||||
add := func(raw string) {
|
||||
normalized := normalizeSearchQuery(raw)
|
||||
if normalized == "" {
|
||||
return
|
||||
}
|
||||
|
||||
key := strings.ToLower(normalized)
|
||||
if _, exists := seen[key]; exists {
|
||||
return
|
||||
}
|
||||
|
||||
seen[key] = struct{}{}
|
||||
queries = append(queries, normalized)
|
||||
}
|
||||
|
||||
for _, candidate := range titleCandidates {
|
||||
normalized := normalizeSearchQuery(candidate)
|
||||
if normalized == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
add(normalized)
|
||||
add(strings.ReplaceAll(normalized, "+", " "))
|
||||
|
||||
withoutApostrophes := strings.NewReplacer("'", "", "’", "", "`", "").Replace(normalized)
|
||||
add(withoutApostrophes)
|
||||
add(strings.ReplaceAll(withoutApostrophes, "+", " "))
|
||||
}
|
||||
|
||||
return queries
|
||||
}
|
||||
|
||||
func normalizeSearchQuery(raw string) string {
|
||||
return strings.Join(strings.Fields(strings.TrimSpace(raw)), " ")
|
||||
}
|
||||
|
||||
func firstNonEmptyTitle(values []string) string {
|
||||
for _, value := range values {
|
||||
normalized := strings.TrimSpace(value)
|
||||
if normalized != "" {
|
||||
return normalized
|
||||
}
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
func normalizeMode(raw string) string {
|
||||
lower := strings.ToLower(strings.TrimSpace(raw))
|
||||
if lower == "sub" || lower == "dub" {
|
||||
return lower
|
||||
}
|
||||
|
||||
return lower
|
||||
}
|
||||
|
||||
func availableModes(modeSources map[string]ModeSource) []string {
|
||||
ordered := make([]string, 0, len(modeSources))
|
||||
if _, ok := modeSources["dub"]; ok {
|
||||
ordered = append(ordered, "dub")
|
||||
}
|
||||
if _, ok := modeSources["sub"]; ok {
|
||||
ordered = append(ordered, "sub")
|
||||
}
|
||||
|
||||
extra := make([]string, 0)
|
||||
for mode := range modeSources {
|
||||
if mode == "dub" || mode == "sub" {
|
||||
continue
|
||||
}
|
||||
extra = append(extra, mode)
|
||||
}
|
||||
sort.Strings(extra)
|
||||
|
||||
return append(ordered, extra...)
|
||||
}
|
||||
|
||||
func selectInitialMode(requestedMode string, modeSources map[string]ModeSource) string {
|
||||
normalizedRequested := normalizeMode(requestedMode)
|
||||
if normalizedRequested != "" {
|
||||
if _, ok := modeSources[normalizedRequested]; ok {
|
||||
return normalizedRequested
|
||||
}
|
||||
}
|
||||
|
||||
if _, ok := modeSources["dub"]; ok {
|
||||
return "dub"
|
||||
}
|
||||
if _, ok := modeSources["sub"]; ok {
|
||||
return "sub"
|
||||
}
|
||||
|
||||
for mode := range modeSources {
|
||||
return mode
|
||||
}
|
||||
|
||||
return "dub"
|
||||
}
|
||||
257
internal/features/playback/service_sources.go
Normal file
257
internal/features/playback/service_sources.go
Normal file
@@ -0,0 +1,257 @@
|
||||
package playback
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
func (s *Service) resolveModeSource(ctx context.Context, showID string, episode string, mode string, quality string) (StreamSource, error) {
|
||||
sources, err := s.allAnimeClient.GetEpisodeSources(ctx, showID, episode, mode)
|
||||
if err != nil {
|
||||
return StreamSource{}, err
|
||||
}
|
||||
|
||||
ranked, err := rankSources(sources, quality)
|
||||
if err != nil {
|
||||
return StreamSource{}, err
|
||||
}
|
||||
|
||||
selected, _, err := s.choosePlaybackSource(ctx, ranked)
|
||||
if err != nil {
|
||||
return StreamSource{}, err
|
||||
}
|
||||
|
||||
return selected, nil
|
||||
}
|
||||
|
||||
func (s *Service) resolveModeSourceWithCache(
|
||||
ctx context.Context,
|
||||
showID string,
|
||||
episode string,
|
||||
mode string,
|
||||
quality string,
|
||||
probeCache map[string]directProbeResult,
|
||||
probeCacheMu *sync.Mutex,
|
||||
) (StreamSource, error) {
|
||||
sources, err := s.allAnimeClient.GetEpisodeSources(ctx, showID, episode, mode)
|
||||
if err != nil {
|
||||
return StreamSource{}, err
|
||||
}
|
||||
|
||||
ranked, err := rankSources(sources, quality)
|
||||
if err != nil {
|
||||
return StreamSource{}, err
|
||||
}
|
||||
|
||||
selected, _, err := s.choosePlaybackSourceWithCache(ctx, ranked, probeCache, probeCacheMu)
|
||||
if err != nil {
|
||||
return StreamSource{}, err
|
||||
}
|
||||
|
||||
return selected, nil
|
||||
}
|
||||
|
||||
func (s *Service) choosePlaybackSource(ctx context.Context, ranked []sourceScore) (StreamSource, string, error) {
|
||||
if len(ranked) == 0 {
|
||||
return StreamSource{}, "", errors.New("no ranked sources available")
|
||||
}
|
||||
|
||||
embedCandidates := make([]StreamSource, 0)
|
||||
for _, candidate := range ranked {
|
||||
source := candidate.source
|
||||
sourceType := strings.ToLower(source.Type)
|
||||
|
||||
switch sourceType {
|
||||
case "mp4", "m3u8":
|
||||
return source, "direct-media", nil
|
||||
case "embed":
|
||||
embedCandidates = append(embedCandidates, source)
|
||||
default:
|
||||
playable, contentType := s.probeDirectMedia(ctx, source)
|
||||
if playable {
|
||||
return normalizeSourceTypeFromProbe(source, contentType), "probed-media", nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, embed := range embedCandidates {
|
||||
if s.probeEmbedSource(ctx, embed) {
|
||||
return embed, "embed-probed", nil
|
||||
}
|
||||
}
|
||||
|
||||
if len(embedCandidates) > 0 {
|
||||
return embedCandidates[0], "embed-fallback", nil
|
||||
}
|
||||
|
||||
return ranked[0].source, "ranked-fallback", nil
|
||||
}
|
||||
|
||||
func (s *Service) choosePlaybackSourceWithCache(
|
||||
ctx context.Context,
|
||||
ranked []sourceScore,
|
||||
probeCache map[string]directProbeResult,
|
||||
probeCacheMu *sync.Mutex,
|
||||
) (StreamSource, string, error) {
|
||||
if len(ranked) == 0 {
|
||||
return StreamSource{}, "", errors.New("no ranked sources available")
|
||||
}
|
||||
|
||||
embedCandidates := make([]StreamSource, 0)
|
||||
for _, candidate := range ranked {
|
||||
source := candidate.source
|
||||
sourceType := strings.ToLower(source.Type)
|
||||
|
||||
switch sourceType {
|
||||
case "mp4", "m3u8":
|
||||
return source, "direct-media", nil
|
||||
case "embed":
|
||||
embedCandidates = append(embedCandidates, source)
|
||||
default:
|
||||
playable, contentType := s.probeDirectMediaCached(ctx, source, probeCache, probeCacheMu)
|
||||
if playable {
|
||||
return normalizeSourceTypeFromProbe(source, contentType), "probed-media", nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, embed := range embedCandidates {
|
||||
if s.probeEmbedSource(ctx, embed) {
|
||||
return embed, "embed-probed", nil
|
||||
}
|
||||
}
|
||||
|
||||
if len(embedCandidates) > 0 {
|
||||
return embedCandidates[0], "embed-fallback", nil
|
||||
}
|
||||
|
||||
return ranked[0].source, "ranked-fallback", nil
|
||||
}
|
||||
|
||||
func (s *Service) probeDirectMediaCached(
|
||||
ctx context.Context,
|
||||
source StreamSource,
|
||||
probeCache map[string]directProbeResult,
|
||||
probeCacheMu *sync.Mutex,
|
||||
) (bool, string) {
|
||||
cacheKey := strings.TrimSpace(source.URL)
|
||||
if cacheKey == "" {
|
||||
return s.probeDirectMedia(ctx, source)
|
||||
}
|
||||
|
||||
probeCacheMu.Lock()
|
||||
cached, ok := probeCache[cacheKey]
|
||||
probeCacheMu.Unlock()
|
||||
if ok {
|
||||
return cached.Playable, cached.ContentType
|
||||
}
|
||||
|
||||
playable, contentType := s.probeDirectMedia(ctx, source)
|
||||
|
||||
probeCacheMu.Lock()
|
||||
probeCache[cacheKey] = directProbeResult{Playable: playable, ContentType: contentType}
|
||||
probeCacheMu.Unlock()
|
||||
|
||||
return playable, contentType
|
||||
}
|
||||
|
||||
func (s *Service) probeDirectMedia(ctx context.Context, source StreamSource) (bool, string) {
|
||||
probeCtx, cancel := context.WithTimeout(ctx, providerProbeTimeout)
|
||||
defer cancel()
|
||||
|
||||
req, err := http.NewRequestWithContext(probeCtx, http.MethodGet, source.URL, nil)
|
||||
if err != nil {
|
||||
return false, ""
|
||||
}
|
||||
|
||||
if source.Referer != "" {
|
||||
req.Header.Set("Referer", source.Referer)
|
||||
}
|
||||
req.Header.Set("User-Agent", defaultUserAgent)
|
||||
req.Header.Set("Range", "bytes=0-4095")
|
||||
|
||||
resp, err := s.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return false, ""
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
contentType := strings.ToLower(resp.Header.Get("Content-Type"))
|
||||
if strings.Contains(contentType, "video/") || strings.Contains(contentType, "mpegurl") {
|
||||
return true, contentType
|
||||
}
|
||||
|
||||
prefix, err := io.ReadAll(io.LimitReader(resp.Body, 4096))
|
||||
if err == nil {
|
||||
if isLikelyM3U8(prefix) {
|
||||
return true, "application/vnd.apple.mpegurl"
|
||||
}
|
||||
if isLikelyMP4(prefix) {
|
||||
return true, "video/mp4"
|
||||
}
|
||||
}
|
||||
|
||||
finalURL := ""
|
||||
if resp.Request != nil && resp.Request.URL != nil {
|
||||
finalURL = strings.ToLower(resp.Request.URL.String())
|
||||
}
|
||||
|
||||
if strings.Contains(finalURL, ".mp4") || strings.Contains(finalURL, ".m3u8") {
|
||||
return true, contentType
|
||||
}
|
||||
|
||||
return false, contentType
|
||||
}
|
||||
|
||||
func (s *Service) probeEmbedSource(ctx context.Context, source StreamSource) bool {
|
||||
probeCtx, cancel := context.WithTimeout(ctx, providerProbeTimeout)
|
||||
defer cancel()
|
||||
|
||||
req, err := http.NewRequestWithContext(probeCtx, http.MethodGet, source.URL, nil)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if source.Referer != "" {
|
||||
req.Header.Set("Referer", source.Referer)
|
||||
}
|
||||
req.Header.Set("User-Agent", defaultUserAgent)
|
||||
|
||||
resp, err := s.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode >= http.StatusBadRequest {
|
||||
return false
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(io.LimitReader(resp.Body, 512*1024))
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
content := strings.ToLower(string(body))
|
||||
markers := []string{
|
||||
"file was deleted",
|
||||
"file has been deleted",
|
||||
"video was deleted",
|
||||
"video has been deleted",
|
||||
"video unavailable",
|
||||
"file not found",
|
||||
"this file does not exist",
|
||||
"resource unavailable",
|
||||
}
|
||||
for _, marker := range markers {
|
||||
if strings.Contains(content, marker) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
23
internal/features/playback/service_utils.go
Normal file
23
internal/features/playback/service_utils.go
Normal file
@@ -0,0 +1,23 @@
|
||||
package playback
|
||||
|
||||
import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
func toSubtitleItems(source StreamSource) []SubtitleItem {
|
||||
items := make([]SubtitleItem, 0, len(source.Subtitles))
|
||||
for _, subtitle := range source.Subtitles {
|
||||
targetURL := strings.TrimSpace(subtitle.URL)
|
||||
if targetURL == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
items = append(items, SubtitleItem{
|
||||
Lang: strings.TrimSpace(subtitle.Lang),
|
||||
URL: targetURL,
|
||||
Referer: source.Referer,
|
||||
})
|
||||
}
|
||||
|
||||
return items
|
||||
}
|
||||
@@ -15,15 +15,17 @@ type Subtitle struct {
|
||||
}
|
||||
|
||||
type ModeSource struct {
|
||||
URL string `json:"url"`
|
||||
Referer string `json:"referer"`
|
||||
URL string `json:"url,omitempty"`
|
||||
Referer string `json:"referer,omitempty"`
|
||||
Token string `json:"token"`
|
||||
Subtitles []SubtitleItem `json:"subtitles"`
|
||||
}
|
||||
|
||||
type SubtitleItem struct {
|
||||
Lang string `json:"lang"`
|
||||
URL string `json:"url"`
|
||||
Referer string `json:"referer"`
|
||||
URL string `json:"url,omitempty"`
|
||||
Referer string `json:"referer,omitempty"`
|
||||
Token string `json:"token"`
|
||||
}
|
||||
|
||||
type SkipSegment struct {
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package watchlist
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"log"
|
||||
@@ -202,7 +201,7 @@ func (h *Handler) HandleContinueWatching(w http.ResponseWriter, r *http.Request)
|
||||
return
|
||||
}
|
||||
|
||||
entries, err := h.svc.db.GetContinueWatchingEntries(r.Context(), user.ID)
|
||||
entries, err := h.svc.GetContinueWatching(r.Context(), user.ID)
|
||||
if err != nil {
|
||||
log.Printf("continue watching fetch failed: user_id=%s err=%v", user.ID, err)
|
||||
http.Error(w, "failed to fetch continue watching", http.StatusInternalServerError)
|
||||
@@ -231,27 +230,12 @@ func (h *Handler) HandleDeleteContinueWatching(w http.ResponseWriter, r *http.Re
|
||||
return
|
||||
}
|
||||
|
||||
err = h.svc.db.DeleteContinueWatchingEntry(r.Context(), database.DeleteContinueWatchingEntryParams{
|
||||
UserID: user.ID,
|
||||
AnimeID: animeID,
|
||||
})
|
||||
if err != nil {
|
||||
if err := h.svc.DeleteContinueWatching(r.Context(), user.ID, animeID); err != nil {
|
||||
log.Printf("continue watching delete failed: user_id=%s anime_id=%d err=%v", user.ID, animeID, err)
|
||||
http.Error(w, "failed to delete continue watching entry", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.svc.db.SaveWatchProgress(r.Context(), database.SaveWatchProgressParams{
|
||||
CurrentEpisode: sql.NullInt64{Valid: false},
|
||||
CurrentTimeSeconds: 0,
|
||||
UserID: user.ID,
|
||||
AnimeID: animeID,
|
||||
}); err != nil {
|
||||
log.Printf("continue watching delete failed to clear watchlist progress: user_id=%s anime_id=%d err=%v", user.ID, animeID, err)
|
||||
http.Error(w, "failed to delete continue watching entry", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
@@ -14,6 +15,7 @@ import (
|
||||
|
||||
type Service struct {
|
||||
db database.Querier
|
||||
sqlDB *sql.DB
|
||||
}
|
||||
|
||||
var (
|
||||
@@ -29,8 +31,8 @@ var validStatuses = map[string]struct{}{
|
||||
"plan_to_watch": {},
|
||||
}
|
||||
|
||||
func NewService(db database.Querier) *Service {
|
||||
return &Service{db: db}
|
||||
func NewService(db database.Querier, sqlDB *sql.DB) *Service {
|
||||
return &Service{db: db, sqlDB: sqlDB}
|
||||
}
|
||||
|
||||
type AddRequest struct {
|
||||
@@ -109,6 +111,78 @@ func (s *Service) GetUserWatchlist(ctx context.Context, userID string) ([]databa
|
||||
return entries, nil
|
||||
}
|
||||
|
||||
func (s *Service) GetContinueWatching(ctx context.Context, userID string) ([]database.GetContinueWatchingEntriesRow, error) {
|
||||
if strings.TrimSpace(userID) == "" {
|
||||
return nil, errors.New("invalid user id")
|
||||
}
|
||||
|
||||
entries, err := s.db.GetContinueWatchingEntries(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to fetch continue watching: %w", err)
|
||||
}
|
||||
|
||||
return entries, nil
|
||||
}
|
||||
|
||||
func (s *Service) DeleteContinueWatching(ctx context.Context, userID string, animeID int64) error {
|
||||
if strings.TrimSpace(userID) == "" {
|
||||
return errors.New("invalid user id")
|
||||
}
|
||||
|
||||
if animeID <= 0 {
|
||||
return ErrInvalidAnimeID
|
||||
}
|
||||
|
||||
if s.sqlDB == nil {
|
||||
if err := s.db.DeleteContinueWatchingEntry(ctx, database.DeleteContinueWatchingEntryParams{
|
||||
UserID: userID,
|
||||
AnimeID: animeID,
|
||||
}); err != nil {
|
||||
return fmt.Errorf("failed to delete continue watching entry: %w", err)
|
||||
}
|
||||
|
||||
if err := s.db.SaveWatchProgress(ctx, database.SaveWatchProgressParams{
|
||||
CurrentEpisode: sql.NullInt64{Valid: false},
|
||||
CurrentTimeSeconds: 0,
|
||||
UserID: userID,
|
||||
AnimeID: animeID,
|
||||
}); err != nil {
|
||||
return fmt.Errorf("failed to clear watchlist progress: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
tx, err := s.sqlDB.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to begin transaction: %w", err)
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
txQueries := database.New(tx)
|
||||
if err := txQueries.DeleteContinueWatchingEntry(ctx, database.DeleteContinueWatchingEntryParams{
|
||||
UserID: userID,
|
||||
AnimeID: animeID,
|
||||
}); err != nil {
|
||||
return fmt.Errorf("failed to delete continue watching entry: %w", err)
|
||||
}
|
||||
|
||||
if err := txQueries.SaveWatchProgress(ctx, database.SaveWatchProgressParams{
|
||||
CurrentEpisode: sql.NullInt64{Valid: false},
|
||||
CurrentTimeSeconds: 0,
|
||||
UserID: userID,
|
||||
AnimeID: animeID,
|
||||
}); err != nil {
|
||||
return fmt.Errorf("failed to clear watchlist progress: %w", err)
|
||||
}
|
||||
|
||||
if err := tx.Commit(); err != nil {
|
||||
return fmt.Errorf("failed to commit continue watching deletion: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type ExportEntry struct {
|
||||
AnimeID int64 `json:"anime_id"`
|
||||
Title string `json:"title"`
|
||||
|
||||
@@ -34,7 +34,7 @@ func TestAddEntry_RejectsInvalidAnimeID(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
q := &fakeQuerier{}
|
||||
svc := NewService(q)
|
||||
svc := NewService(q, nil)
|
||||
|
||||
err := svc.AddEntry(context.Background(), "user-1", AddRequest{
|
||||
AnimeID: 0,
|
||||
@@ -54,7 +54,7 @@ func TestAddEntry_RejectsInvalidStatus(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
q := &fakeQuerier{}
|
||||
svc := NewService(q)
|
||||
svc := NewService(q, nil)
|
||||
|
||||
err := svc.AddEntry(context.Background(), "user-1", AddRequest{
|
||||
AnimeID: 1,
|
||||
@@ -101,7 +101,7 @@ func TestExport_UsesDisplayTitleFallbackOrder(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
svc := NewService(q)
|
||||
svc := NewService(q, nil)
|
||||
export, err := svc.Export(context.Background(), "user-1")
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error, got %v", err)
|
||||
|
||||
@@ -40,7 +40,7 @@ func (c *Client) GetAnimeByProducer(ctx context.Context, producerID int, page in
|
||||
}
|
||||
|
||||
var stale StudioAnimeResult
|
||||
hasStale := c.getStaleCache(ctx, cacheKey, &cached)
|
||||
hasStale := c.getStaleCache(ctx, cacheKey, &stale)
|
||||
|
||||
var result SearchResponse
|
||||
reqURL := fmt.Sprintf("%s/anime?producers=%d&page=%d", c.baseURL, producerID, page)
|
||||
@@ -84,7 +84,7 @@ func (c *Client) GetProducerByID(ctx context.Context, producerID int) (ProducerR
|
||||
}
|
||||
|
||||
var stale ProducerResponse
|
||||
hasStale := c.getStaleCache(ctx, cacheKey, &cached)
|
||||
hasStale := c.getStaleCache(ctx, cacheKey, &stale)
|
||||
|
||||
var result ProducerResponse
|
||||
reqURL := fmt.Sprintf("%s/producers/%d/full", c.baseURL, producerID)
|
||||
|
||||
96
internal/jikan/studio_test.go
Normal file
96
internal/jikan/studio_test.go
Normal file
@@ -0,0 +1,96 @@
|
||||
package jikan
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"mal/internal/database"
|
||||
)
|
||||
|
||||
type staleCacheQuerier struct {
|
||||
database.Querier
|
||||
staleJSON string
|
||||
}
|
||||
|
||||
func (q *staleCacheQuerier) GetJikanCache(ctx context.Context, key string) (string, error) {
|
||||
return "", sql.ErrNoRows
|
||||
}
|
||||
|
||||
func (q *staleCacheQuerier) GetJikanCacheStale(ctx context.Context, key string) (string, error) {
|
||||
if q.staleJSON == "" {
|
||||
return "", sql.ErrNoRows
|
||||
}
|
||||
|
||||
return q.staleJSON, nil
|
||||
}
|
||||
|
||||
func TestGetProducerByID_UsesStaleCacheOnFetchFailure(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
q := &staleCacheQuerier{
|
||||
staleJSON: `{"data":{"mal_id":7,"about":"stale about"}}`,
|
||||
}
|
||||
|
||||
client := NewClient(q)
|
||||
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}))
|
||||
defer testServer.Close()
|
||||
|
||||
client.baseURL = testServer.URL
|
||||
client.httpClient = testServer.Client()
|
||||
|
||||
result, err := client.GetProducerByID(context.Background(), 7)
|
||||
if err != nil {
|
||||
t.Fatalf("expected stale cache result, got error: %v", err)
|
||||
}
|
||||
|
||||
if result.Data.MalID != 7 {
|
||||
t.Fatalf("expected stale mal_id 7, got %d", result.Data.MalID)
|
||||
}
|
||||
|
||||
if result.Data.About != "stale about" {
|
||||
t.Fatalf("expected stale about field, got %q", result.Data.About)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetAnimeByProducer_UsesStaleCacheOnFetchFailure(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
q := &staleCacheQuerier{
|
||||
staleJSON: `{"Animes":[{"mal_id":42,"title":"Stale Anime"}],"HasNextPage":true,"StudioName":"Stale Studio"}`,
|
||||
}
|
||||
|
||||
client := NewClient(q)
|
||||
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}))
|
||||
defer testServer.Close()
|
||||
|
||||
client.baseURL = testServer.URL
|
||||
client.httpClient = testServer.Client()
|
||||
|
||||
result, err := client.GetAnimeByProducer(context.Background(), 9, 1)
|
||||
if err != nil {
|
||||
t.Fatalf("expected stale cache result, got error: %v", err)
|
||||
}
|
||||
|
||||
if len(result.Animes) != 1 {
|
||||
t.Fatalf("expected one stale anime, got %d", len(result.Animes))
|
||||
}
|
||||
|
||||
if result.Animes[0].MalID != 42 {
|
||||
t.Fatalf("expected stale anime mal_id 42, got %d", result.Animes[0].MalID)
|
||||
}
|
||||
|
||||
if !result.HasNextPage {
|
||||
t.Fatal("expected stale has_next_page=true")
|
||||
}
|
||||
|
||||
if result.StudioName != "Stale Studio" {
|
||||
t.Fatalf("expected stale studio name, got %q", result.StudioName)
|
||||
}
|
||||
}
|
||||
@@ -1,6 +1,7 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"net/http"
|
||||
|
||||
"mal/internal/database"
|
||||
@@ -14,8 +15,10 @@ import (
|
||||
|
||||
type Config struct {
|
||||
DB *database.Queries
|
||||
SQLDB *sql.DB
|
||||
JikanClient *jikan.Client
|
||||
AuthService *auth.Service
|
||||
PlaybackProxySecret string
|
||||
}
|
||||
|
||||
func NewRouter(cfg Config) http.Handler {
|
||||
@@ -23,12 +26,12 @@ func NewRouter(cfg Config) http.Handler {
|
||||
|
||||
authHandler := auth.NewHandler(cfg.AuthService)
|
||||
|
||||
watchlistSvc := watchlist.NewService(cfg.DB)
|
||||
watchlistSvc := watchlist.NewService(cfg.DB, cfg.SQLDB)
|
||||
watchlistHandler := watchlist.NewHandler(watchlistSvc)
|
||||
|
||||
animeSvc := anime.NewService(cfg.JikanClient, cfg.DB)
|
||||
animeHandler := anime.NewHandler(animeSvc)
|
||||
playbackSvc := playback.NewService(cfg.DB)
|
||||
playbackSvc := playback.NewService(cfg.DB, cfg.SQLDB, playback.Config{ProxyTokenSecret: cfg.PlaybackProxySecret})
|
||||
playbackHandler := playback.NewHandler(playbackSvc, cfg.JikanClient)
|
||||
|
||||
// Serve static files
|
||||
@@ -83,5 +86,6 @@ func NewRouter(cfg Config) http.Handler {
|
||||
// Wrap mux with global CSRF origin verification and auth checking,
|
||||
// THEN auth context parsing.
|
||||
protectedHandler := middleware.RequireGlobalAuth(middleware.VerifyOrigin(mux))
|
||||
return middleware.Auth(cfg.AuthService)(protectedHandler)
|
||||
authenticatedHandler := middleware.Auth(cfg.AuthService)(protectedHandler)
|
||||
return middleware.RequestLogger(authenticatedHandler)
|
||||
}
|
||||
|
||||
@@ -29,16 +29,14 @@ type WatchPageData struct {
|
||||
|
||||
// ModeSource represents a stream source for a specific mode (dub/sub)
|
||||
type ModeSource struct {
|
||||
URL string `json:"url"`
|
||||
Referer string `json:"referer"`
|
||||
Token string `json:"token"`
|
||||
Subtitles []SubtitleItem `json:"subtitles"`
|
||||
}
|
||||
|
||||
// SubtitleItem represents a subtitle track
|
||||
type SubtitleItem struct {
|
||||
Lang string `json:"lang"`
|
||||
URL string `json:"url"`
|
||||
Referer string `json:"referer"`
|
||||
Token string `json:"token"`
|
||||
}
|
||||
|
||||
// SkipSegment represents a skippable segment (intro/outro)
|
||||
@@ -199,7 +197,7 @@ templ EpisodeItem(episode jikan.Episode, currentEpisode string, animeID int) {
|
||||
}
|
||||
|
||||
templ VideoPlayer(data WatchPageData) {
|
||||
{{ streamURL := buildStreamURL(data.InitialMode, data.ModeSources) }}
|
||||
{{ streamToken := modeToken(data.InitialMode, data.ModeSources) }}
|
||||
{{ hasDub := modeAvailable(data.AvailableModes, "dub") }}
|
||||
{{ hasSub := modeAvailable(data.AvailableModes, "sub") }}
|
||||
<div
|
||||
@@ -216,6 +214,7 @@ templ VideoPlayer(data WatchPageData) {
|
||||
data-anime-airing={ fmt.Sprintf("%v", data.Airing) }
|
||||
data-start-time-seconds={ fmt.Sprintf("%.3f", data.StartTimeSeconds) }
|
||||
data-initial-mode={ data.InitialMode }
|
||||
data-stream-token={ streamToken }
|
||||
data-available-modes={ toJSON(data.AvailableModes) }
|
||||
data-mode-sources={ toJSON(data.ModeSources) }
|
||||
data-segments={ toJSON(data.Segments) }
|
||||
@@ -226,7 +225,7 @@ templ VideoPlayer(data WatchPageData) {
|
||||
preload="metadata"
|
||||
crossorigin="anonymous"
|
||||
playsinline
|
||||
src={ streamURL }
|
||||
src={ buildStreamURL(data.InitialMode, streamToken) }
|
||||
></video>
|
||||
<div data-loading class="absolute inset-0 flex items-center justify-center bg-black/50">
|
||||
<div class="h-8 w-8 animate-spin border-2 border-(--panel-soft) border-t-(--accent)"></div>
|
||||
@@ -339,9 +338,35 @@ templ VideoPlayer(data WatchPageData) {
|
||||
</div>
|
||||
}
|
||||
|
||||
func buildStreamURL(mode string, modeSources map[string]ModeSource) string {
|
||||
stateJSON, _ := json.Marshal(modeSources)
|
||||
return fmt.Sprintf("/watch/proxy/stream?mode=%s&state=%s", url.QueryEscape(mode), url.QueryEscape(string(stateJSON)))
|
||||
func buildStreamURL(mode string, token string) string {
|
||||
if token == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
return fmt.Sprintf("/watch/proxy/stream?mode=%s&token=%s", url.QueryEscape(mode), url.QueryEscape(token))
|
||||
}
|
||||
|
||||
func modeToken(mode string, modeSources map[string]ModeSource) string {
|
||||
normalizedMode := mode
|
||||
if _, ok := modeSources[normalizedMode]; !ok {
|
||||
if _, ok := modeSources["dub"]; ok {
|
||||
normalizedMode = "dub"
|
||||
} else if _, ok := modeSources["sub"]; ok {
|
||||
normalizedMode = "sub"
|
||||
} else {
|
||||
for key := range modeSources {
|
||||
normalizedMode = key
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
source, ok := modeSources[normalizedMode]
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
|
||||
return source.Token
|
||||
}
|
||||
|
||||
func toJSON(v interface{}) string {
|
||||
|
||||
@@ -1,15 +1,13 @@
|
||||
export {}
|
||||
|
||||
interface ModeSource {
|
||||
url: string
|
||||
referer: string
|
||||
token: string
|
||||
subtitles: SubtitleItem[]
|
||||
}
|
||||
|
||||
interface SubtitleItem {
|
||||
lang: string
|
||||
url: string
|
||||
referer: string
|
||||
token: string
|
||||
}
|
||||
|
||||
interface SkipSegment {
|
||||
@@ -46,6 +44,7 @@ const initPlayer = (): void => {
|
||||
const subtitleText = container.querySelector('[data-subtitle-text]') as HTMLElement
|
||||
|
||||
const streamURL = container.getAttribute('data-stream-url') || '/watch/proxy/stream'
|
||||
const initialStreamToken = container.getAttribute('data-stream-token') || ''
|
||||
const currentEpisode = container.getAttribute('data-current-episode') || '1'
|
||||
const malID = Number.parseInt(container.getAttribute('data-mal-id') || '', 10)
|
||||
const totalEpisodes = Number.parseInt(container.getAttribute('data-total-episodes') || '0', 10)
|
||||
@@ -79,6 +78,10 @@ const initPlayer = (): void => {
|
||||
.sort((a: { start: number }, b: { start: number }) => a.start - b.start)
|
||||
|
||||
let currentMode = availableModes.includes(initialMode) ? initialMode : (availableModes[0] || 'dub')
|
||||
const fallbackMode = Object.keys(modeSources).find((mode) => typeof modeSources[mode]?.token === 'string' && modeSources[mode].token !== '')
|
||||
if ((!modeSources[currentMode] || !modeSources[currentMode].token) && fallbackMode) {
|
||||
currentMode = fallbackMode
|
||||
}
|
||||
let controlsTimeout: number | undefined
|
||||
let isScrubbing = false
|
||||
let isHoveringVolume = false
|
||||
@@ -97,21 +100,18 @@ const initPlayer = (): void => {
|
||||
|
||||
const previewPopover = container.querySelector('[data-preview-popover]') as HTMLElement
|
||||
const previewTime = container.querySelector('[data-preview-time]') as HTMLElement
|
||||
const encodedModeState = encodeURIComponent(JSON.stringify(modeSources))
|
||||
|
||||
const streamUrlForMode = (mode: string): string => {
|
||||
const modeParam = encodeURIComponent(mode)
|
||||
const stateParam = encodedModeState
|
||||
return `${streamURL}?mode=${modeParam}&state=${stateParam}`
|
||||
const modeSource = modeSources[mode]
|
||||
const token = modeSource?.token
|
||||
if (!token) return ''
|
||||
const tokenParam = encodeURIComponent(token)
|
||||
return `${streamURL}?mode=${modeParam}&token=${tokenParam}`
|
||||
}
|
||||
|
||||
const subtitleProxyURL = (track: SubtitleItem): string => {
|
||||
if (!track || !track.url) return ''
|
||||
let proxied = `/watch/proxy/subtitle?u=${encodeURIComponent(track.url)}`
|
||||
if (track.referer) {
|
||||
proxied += `&r=${encodeURIComponent(track.referer)}`
|
||||
}
|
||||
return proxied
|
||||
if (!track || !track.token) return ''
|
||||
return `/watch/proxy/subtitle?token=${encodeURIComponent(track.token)}`
|
||||
}
|
||||
|
||||
const subtitlesForMode = (mode: string): Array<{ lang: string, label: string, url: string }> => {
|
||||
@@ -560,11 +560,13 @@ const initPlayer = (): void => {
|
||||
|
||||
const switchMode = (mode: string): void => {
|
||||
if (!availableModes.includes(mode) || mode === currentMode) return
|
||||
const nextURL = streamUrlForMode(mode)
|
||||
if (!nextURL) return
|
||||
const wasPlaying = !video.paused
|
||||
const previousTime = displayTimeFromAbsolute(video.currentTime)
|
||||
currentMode = mode
|
||||
hidePreviewPopover()
|
||||
video.src = streamUrlForMode(currentMode)
|
||||
video.src = nextURL
|
||||
video.load()
|
||||
pendingSeekTime = previousTime
|
||||
if (wasPlaying) video.play().catch(() => {})
|
||||
@@ -653,6 +655,13 @@ const initPlayer = (): void => {
|
||||
updateSubtitleOptions()
|
||||
updateModeButtons(currentMode)
|
||||
|
||||
const startingURL = streamUrlForMode(currentMode)
|
||||
if (startingURL) {
|
||||
video.src = startingURL
|
||||
} else if (initialStreamToken) {
|
||||
video.src = `${streamURL}?mode=${encodeURIComponent(currentMode)}&token=${encodeURIComponent(initialStreamToken)}`
|
||||
}
|
||||
|
||||
if (video) {
|
||||
video.addEventListener('loadedmetadata', () => {
|
||||
if (loading) loading.style.display = 'none'
|
||||
|
||||
Reference in New Issue
Block a user