cleanup: remove redundant and old architectural files

This commit is contained in:
2026-05-13 10:34:16 +02:00
parent 0d6c7613a9
commit ab31cf4c4c
27 changed files with 0 additions and 5954 deletions

View File

@@ -1,539 +0,0 @@
package anime
import (
"context"
"encoding/json"
"errors"
"html"
"log"
"net/http"
"strconv"
"strings"
"time"
"mal/integrations/jikan"
"mal/internal/db"
"mal/internal/middleware"
"mal/templates"
"golang.org/x/sync/errgroup"
)
type Handler struct {
service *Service
}
func NewHandler(service *Service) *Handler {
return &Handler{service: service}
}
type quickSearchResult struct {
ID int `json:"id"` // anime mal id
Title string `json:"title"` // display title
Type string `json:"type"` // anime type (tv, movie, etc)
Image string `json:"image"` // cover image url
}
func (h *Handler) HandleCatalog(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/" {
renderNotFoundPage(r, w)
return
}
user := middleware.GetUser(r.Context())
if err := templates.GetRenderer().ExecuteTemplate(r.Context(), w, "index.gohtml", map[string]any{
"User": user,
"CurrentPath": r.URL.Path,
}); err != nil {
if !errors.Is(err, context.Canceled) {
log.Printf("render error: %v", err)
}
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
}
}
func (h *Handler) HandleCatalogAiring(w http.ResponseWriter, r *http.Request) {
h.renderCatalogSection(w, r, "Airing")
}
func (h *Handler) HandleCatalogPopular(w http.ResponseWriter, r *http.Request) {
h.renderCatalogSection(w, r, "Popular")
}
func (h *Handler) HandleCatalogContinue(w http.ResponseWriter, r *http.Request) {
h.renderCatalogSection(w, r, "Continue")
}
// renderCatalogSection fetches catalog data (airing/popular/continue) and renders as htmx fragment
func (h *Handler) renderCatalogSection(w http.ResponseWriter, r *http.Request, section string) {
user := middleware.GetUser(r.Context())
userID := ""
if user != nil {
userID = user.ID
}
data, err := h.service.GetCatalogSection(r.Context(), userID, section)
if err != nil {
if !errors.Is(err, context.Canceled) {
log.Printf("catalog %s error: %v", section, err)
}
if section != "Continue" {
writeInlineLoadError(w, "Failed to load "+section)
}
return
}
data["User"] = user
data["Section"] = section
// render section as htmx partial, not full page
if err := templates.GetRenderer().ExecuteFragment(r.Context(), w, "index.gohtml", "catalog_section", data); err != nil {
if !errors.Is(err, context.Canceled) {
log.Printf("fragment render error: %v", err)
}
}
}
func (h *Handler) HandleDiscover(w http.ResponseWriter, r *http.Request) {
user := middleware.GetUser(r.Context())
if err := templates.GetRenderer().ExecuteTemplate(r.Context(), w, "discover.gohtml", map[string]any{
"User": user,
"CurrentPath": r.URL.Path,
}); err != nil {
if !errors.Is(err, context.Canceled) {
log.Printf("render error: %v", err)
}
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
}
}
func (h *Handler) HandleDiscoverTrending(w http.ResponseWriter, r *http.Request) {
h.renderDiscoverSection(w, r, "Trending")
}
func (h *Handler) HandleDiscoverUpcoming(w http.ResponseWriter, r *http.Request) {
h.renderDiscoverSection(w, r, "Upcoming")
}
func (h *Handler) HandleDiscoverTop(w http.ResponseWriter, r *http.Request) {
h.renderDiscoverSection(w, r, "Top")
}
func (h *Handler) renderDiscoverSection(w http.ResponseWriter, r *http.Request, section string) {
user := middleware.GetUser(r.Context())
userID := ""
if user != nil {
userID = user.ID
}
data, err := h.service.GetDiscoverSection(r.Context(), userID, section)
if err != nil {
if !errors.Is(err, context.Canceled) {
log.Printf("discover %s error: %v", section, err)
}
writeInlineLoadError(w, "Failed to load "+section)
return
}
data["User"] = user
data["Section"] = section
if err := templates.GetRenderer().ExecuteFragment(r.Context(), w, "discover.gohtml", "discover_section", data); err != nil {
if !errors.Is(err, context.Canceled) {
log.Printf("fragment render error: %v", err)
}
}
}
// HandleBrowse handles anime search/browse with filters. supports htmx partial loading.
func (h *Handler) HandleBrowse(w http.ResponseWriter, r *http.Request) {
user := middleware.GetUser(r.Context())
// parse query params for search/filter
q := r.URL.Query().Get("q")
animeType := r.URL.Query().Get("type")
status := r.URL.Query().Get("status")
orderBy := r.URL.Query().Get("order_by")
sort := r.URL.Query().Get("sort")
sfw := r.URL.Query().Get("sfw") != "false" // default to safe
var genres []int
for _, g := range r.URL.Query()["genres"] {
id, err := strconv.Atoi(g)
if err == nil {
genres = append(genres, id)
}
}
page := parsePageParam(r)
ctx, cancel := context.WithTimeout(r.Context(), 20*time.Second)
defer cancel()
res, err := h.service.jikanClient.SearchAdvanced(ctx, q, animeType, status, orderBy, sort, genres, sfw, page, 24)
if err != nil {
if errors.Is(err, context.Canceled) {
return
}
log.Printf("browse error: %v", err)
}
if r.Header.Get("HX-Request") == "true" {
// htmx: return just the card scroll fragment with watchlist state
watchlistMap := make(map[int]bool)
if user != nil {
watchlist, _ := h.service.db.GetUserWatchList(ctx, user.ID)
for _, entry := range watchlist {
watchlistMap[int(entry.AnimeID)] = true
}
}
w.Header().Set("Content-Type", "text/html")
err := templates.GetRenderer().ExecuteFragment(ctx, w, "browse.gohtml", "anime_card_scroll", map[string]any{
"Animes": res.Animes,
"NextPage": page + 1,
"HasNextPage": res.HasNextPage,
"Query": q,
"Type": animeType,
"Status": status,
"OrderBy": orderBy,
"Sort": sort,
"Genres": genres,
"SFW": sfw,
"WatchlistMap": watchlistMap,
})
if err != nil {
if !errors.Is(err, context.Canceled) {
log.Printf("fragment render error: %v", err)
}
}
return
}
// full page load: fetch genres list and full watchlist
genresList, err := h.service.jikanClient.GetAnimeGenres(ctx)
if err != nil {
if !errors.Is(err, context.Canceled) {
log.Printf("genres error: %v", err)
}
}
watchlistMap := make(map[int]bool)
var watchlistIDs []int64
if user != nil {
watchlist, _ := h.service.db.GetUserWatchList(ctx, user.ID)
watchlistIDs = make([]int64, len(watchlist))
for i, entry := range watchlist {
watchlistMap[int(entry.AnimeID)] = true
watchlistIDs[i] = entry.AnimeID
}
}
if err := templates.GetRenderer().ExecuteTemplate(ctx, w, "browse.gohtml", map[string]any{
"User": user,
"CurrentPath": r.URL.Path,
"Query": q,
"Type": animeType,
"Status": status,
"OrderBy": orderBy,
"Sort": sort,
"Genres": genres,
"SFW": sfw,
"GenresList": genresList,
"Animes": res.Animes,
"HasNextPage": res.HasNextPage,
"NextPage": page + 1,
"WatchlistMap": watchlistMap,
"WatchlistIDs": watchlistIDs,
}); err != nil {
if !errors.Is(err, context.Canceled) {
log.Printf("render error: %v", err)
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
}
}
}
// HandleAnimeDetails renders anime detail page. handles htmx requests for characters/recommendations sections.
func (h *Handler) HandleAnimeDetails(w http.ResponseWriter, r *http.Request) {
idStr := strings.TrimPrefix(r.URL.Path, "/anime/")
idStr = strings.TrimSuffix(idStr, "/")
id, err := strconv.Atoi(idStr)
if err != nil {
renderNotFoundPage(r, w)
return
}
user := middleware.GetUser(r.Context())
// htmx: return just the section (characters or recommendations)
section := r.URL.Query().Get("section")
if section != "" && r.Header.Get("HX-Request") == "true" {
h.renderAnimeDetailsSection(w, r, id, section)
return
}
var (
anime jikan.Anime
status string
episodesCount int
watchlistIDs []int64
)
g, gCtx := errgroup.WithContext(r.Context())
// fetch anime details + episode count if airing
g.Go(func() error {
var err error
anime, err = h.service.jikanClient.GetAnimeByID(gCtx, id)
if err == nil && anime.Airing {
// get episode count for airing anime (may span multiple pages)
eps, err := h.service.jikanClient.GetEpisodes(gCtx, id, 1)
if err == nil {
if eps.Pagination.LastVisiblePage > 1 {
lastEps, err := h.service.jikanClient.GetEpisodes(gCtx, id, eps.Pagination.LastVisiblePage)
if err == nil && len(lastEps.Data) > 0 {
lastEp := lastEps.Data[len(lastEps.Data)-1]
count, _ := strconv.Atoi(lastEp.Episode)
episodesCount = count
}
} else if len(eps.Data) > 0 {
lastEp := eps.Data[len(eps.Data)-1]
count, _ := strconv.Atoi(lastEp.Episode)
episodesCount = count
}
}
}
return err
})
if user != nil {
// fetch user's watchlist status for this anime
g.Go(func() error {
entry, err := h.service.db.GetWatchListEntry(gCtx, db.GetWatchListEntryParams{
UserID: user.ID,
AnimeID: int64(id),
})
if err == nil {
status = entry.Status
}
return nil
})
// fetch all watchlist ids for nav state
g.Go(func() error {
watchlist, err := h.service.db.GetUserWatchList(gCtx, user.ID)
if err == nil {
watchlistIDs = make([]int64, len(watchlist))
for i, e := range watchlist {
watchlistIDs[i] = e.AnimeID
}
}
return nil
})
}
if err := g.Wait(); err != nil {
if !errors.Is(err, context.Canceled) {
log.Printf("anime details fetch error: %v", err)
}
renderNotFoundPage(r, w)
return
}
if err := templates.GetRenderer().ExecuteTemplate(r.Context(), w, "anime.gohtml", map[string]any{
"Anime": anime,
"User": user,
"Status": status,
"CurrentPath": r.URL.Path,
"WatchlistIDs": watchlistIDs,
"EpisodesCount": episodesCount,
}); err != nil {
if !errors.Is(err, context.Canceled) {
log.Printf("render error: %v", err)
}
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
}
}
// renderAnimeDetailsSection fetches and renders htmx partial for character/recommendation sections
func (h *Handler) renderAnimeDetailsSection(w http.ResponseWriter, r *http.Request, id int, section string) {
ctx := r.Context()
var data any
var err error
switch section {
case "characters":
data, err = h.service.jikanClient.GetAnimeCharacters(ctx, id)
case "recommendations":
data, err = h.service.jikanClient.GetAnimeRecommendations(ctx, id)
default:
http.Error(w, "Invalid section", http.StatusBadRequest)
return
}
if err != nil {
if !errors.Is(err, context.Canceled) {
log.Printf("anime details %s error: %v", section, err)
}
writeInlineLoadError(w, "Failed to load "+section)
return
}
tplName := "anime_characters"
if section == "recommendations" {
tplName = "anime_recommendations"
}
// render htmx partial for the section
if err := templates.GetRenderer().ExecuteFragment(ctx, w, "anime.gohtml", tplName, data); err != nil {
if !errors.Is(err, context.Canceled) {
log.Printf("fragment render error: %v", err)
}
}
}
func (h *Handler) HandleHTMLWatchOrder(w http.ResponseWriter, r *http.Request) {
animeIdStr := r.URL.Query().Get("animeId")
id, err := strconv.Atoi(animeIdStr)
if err != nil {
http.Error(w, `<div class="mt-8 text-sm text-red-400">Invalid anime ID.</div>`, http.StatusBadRequest)
return
}
relations, err := h.service.jikanClient.GetFullRelations(r.Context(), id)
if err != nil {
if !errors.Is(err, context.Canceled) {
log.Printf("watch order error: %v", err)
}
http.Error(w, `<div class="mt-8 text-sm text-red-400">Failed to load watch order.</div>`, http.StatusInternalServerError)
return
}
user := middleware.GetUser(r.Context())
watchlistMap := make(map[int64]bool)
if user != nil {
watchlist, _ := h.service.db.GetUserWatchList(r.Context(), user.ID)
for _, entry := range watchlist {
watchlistMap[entry.AnimeID] = true
}
}
if err := templates.GetRenderer().ExecuteFragment(r.Context(), w, "anime.gohtml", "watch_order", map[string]any{
"Relations": relations,
"AnimeID": id,
"WatchlistMap": watchlistMap,
}); err != nil {
if !errors.Is(err, context.Canceled) {
log.Printf("render error: %v", err)
}
}
}
func (h *Handler) HandleQuickSearch(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
query := r.URL.Query().Get("q")
if query == "" {
w.WriteHeader(http.StatusOK)
if err := writeJSON(w, []quickSearchResult{}); err != nil {
log.Printf("quick search encode error: %v", err)
}
return
}
res, err := h.service.jikanClient.SearchAdvanced(r.Context(), query, "", "", "", "", nil, true, 1, 5)
if err != nil {
if !errors.Is(err, context.Canceled) {
log.Printf("quick search error: %v", err)
}
w.WriteHeader(http.StatusOK)
if err := writeJSON(w, []quickSearchResult{}); err != nil {
if !errors.Is(err, context.Canceled) {
log.Printf("quick search encode error: %v", err)
}
}
return
}
output := make([]quickSearchResult, len(res.Animes))
for i, anime := range res.Animes {
output[i] = quickSearchResult{
ID: anime.MalID,
Title: anime.DisplayTitle(),
Type: anime.Type,
Image: anime.ImageURL(),
}
}
w.WriteHeader(http.StatusOK)
if err := writeJSON(w, output); err != nil {
if !errors.Is(err, context.Canceled) {
log.Printf("quick search encode error: %v", err)
}
}
}
func (h *Handler) HandleRandomAnime(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
anime, err := h.service.jikanClient.GetRandomAnime(r.Context())
if err != nil {
if !errors.Is(err, context.Canceled) {
log.Printf("random anime error: %v", err)
}
w.WriteHeader(http.StatusInternalServerError)
if err := writeJSON(w, map[string]string{"error": "Failed to fetch random anime"}); err != nil {
if !errors.Is(err, context.Canceled) {
log.Printf("random anime encode error: %v", err)
}
}
return
}
if anime.MalID == 0 {
w.WriteHeader(http.StatusNotFound)
if err := writeJSON(w, map[string]string{"error": "No anime found"}); err != nil {
if !errors.Is(err, context.Canceled) {
log.Printf("random anime encode error: %v", err)
}
}
return
}
w.WriteHeader(http.StatusOK)
if err := writeJSON(w, map[string]any{"data": anime}); err != nil {
if !errors.Is(err, context.Canceled) {
log.Printf("random anime encode error: %v", err)
}
}
}
func (h *Handler) HandleSearch(w http.ResponseWriter, r *http.Request) {
renderNotFoundPage(r, w)
}
func renderNotFoundPage(r *http.Request, w http.ResponseWriter) {
w.WriteHeader(http.StatusNotFound)
if err := templates.GetRenderer().ExecuteTemplate(r.Context(), w, "not_found.gohtml", map[string]any{
"CurrentPath": r.URL.Path,
}); err != nil {
if !errors.Is(err, context.Canceled) {
log.Printf("render error: %v", err)
}
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
}
}
func writeInlineLoadError(w http.ResponseWriter, message string) {
w.Header().Set("Content-Type", "text/html")
_, _ = w.Write([]byte(`<p style="color: var(--text-muted); font-size: var(--text-sm);">` + html.EscapeString(message) + `</p>`))
}
func writeJSON(w http.ResponseWriter, v any) error {
return json.NewEncoder(w).Encode(v)
}
func parsePageParam(r *http.Request) int {
page, _ := strconv.Atoi(r.URL.Query().Get("page"))
if page < 1 {
return 1
}
return page
}

View File

@@ -1,130 +0,0 @@
package anime
import (
"context"
"mal/integrations/jikan"
"mal/internal/db"
"golang.org/x/sync/errgroup"
)
type Service struct {
jikanClient *jikan.Client
db db.Querier
}
func NewService(jikanClient *jikan.Client, db db.Querier) *Service {
return &Service{jikanClient: jikanClient, db: db}
}
// GetCatalogSection fetches homepage catalog sections (Airing, Popular, Continue) from jikan and db.
func (s *Service) GetCatalogSection(ctx context.Context, userID string, section string) (map[string]any, error) {
var (
res jikan.TopAnimeResult
cw []db.GetContinueWatchingEntriesRow
watchlist []db.GetUserWatchListRow
err error
)
g, gCtx := errgroup.WithContext(ctx)
// fetch jikan data (season now or top anime)
g.Go(func() error {
switch section {
case "Airing":
res, err = s.jikanClient.GetSeasonsNow(gCtx, 1)
case "Popular":
res, err = s.jikanClient.GetTopAnime(gCtx, 1)
}
return err
})
// fetch user-specific data if logged in
if userID != "" {
g.Go(func() error {
if section == "Continue" {
var err error
cw, err = s.db.GetContinueWatchingEntries(gCtx, userID)
return err
}
return nil
})
g.Go(func() error {
var err error
watchlist, err = s.db.GetUserWatchList(gCtx, userID)
return err
})
}
if err := g.Wait(); err != nil {
return nil, err
}
// limit to 6 items for homepage grid
animes := res.Animes
if len(animes) > 6 {
animes = animes[:6]
}
watchlistMap := make(map[int64]bool)
for _, entry := range watchlist {
watchlistMap[entry.AnimeID] = true
}
return map[string]any{
"Animes": animes,
"ContinueWatching": cw,
"WatchlistMap": watchlistMap,
}, nil
}
// GetDiscoverSection fetches discover page sections (Trending, Upcoming, Top) from jikan.
func (s *Service) GetDiscoverSection(ctx context.Context, userID string, section string) (map[string]any, error) {
var (
res jikan.TopAnimeResult
watchlist []db.GetUserWatchListRow
err error
)
g, gCtx := errgroup.WithContext(ctx)
g.Go(func() error {
switch section {
case "Trending":
res, err = s.jikanClient.GetSeasonsNow(gCtx, 1)
case "Upcoming":
res, err = s.jikanClient.GetSeasonsUpcoming(gCtx, 1)
case "Top":
res, err = s.jikanClient.GetTopAnime(gCtx, 1)
}
return err
})
if userID != "" {
g.Go(func() error {
var err error
watchlist, err = s.db.GetUserWatchList(gCtx, userID)
return err
})
}
if err := g.Wait(); err != nil {
return nil, err
}
// limit to 8 items for discover grid
animes := res.Animes
if len(animes) > 8 {
animes = animes[:8]
}
watchlistMap := make(map[int64]bool)
for _, entry := range watchlist {
watchlistMap[entry.AnimeID] = true
}
return map[string]any{
"Animes": animes,
"WatchlistMap": watchlistMap,
}, nil
}

View File

@@ -1,127 +0,0 @@
package auth
import (
"context"
"crypto/rand"
"database/sql"
"encoding/base64"
"errors"
"fmt"
"net/http"
"os"
"time"
"golang.org/x/crypto/bcrypt"
"mal/internal/db"
)
var (
ErrInvalidCredentials = errors.New("invalid username or password")
ErrNotAuthenticated = errors.New("not authenticated")
)
type Service struct {
db db.Querier
}
func NewService(db db.Querier) *Service {
return &Service{db: db}
}
// generateToken creates a cryptographically random base64-encoded token
func generateToken(size int) (string, error) {
b := make([]byte, size)
if _, err := rand.Read(b); err != nil {
return "", err
}
return base64.URLEncoding.EncodeToString(b), nil
}
// generateSessionToken creates a 32-byte session token
func generateSessionToken() (string, error) {
return generateToken(32)
}
func (s *Service) Login(ctx context.Context, username, password string) (*db.Session, error) {
user, err := s.db.GetUserByUsername(ctx, username)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, ErrInvalidCredentials
}
return nil, fmt.Errorf("failed to lookup user: %w", err)
}
if err := bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(password)); err != nil {
return nil, ErrInvalidCredentials
}
token, err := generateSessionToken()
if err != nil {
return nil, fmt.Errorf("failed to generate session token: %w", err)
}
expiresAt := time.Now().Add(30 * 24 * time.Hour) // 30 days
session, err := s.db.CreateSession(ctx, db.CreateSessionParams{
ID: token,
UserID: user.ID,
ExpiresAt: expiresAt,
})
if err != nil {
return nil, fmt.Errorf("failed to create session: %w", err)
}
return &session, nil
}
func (s *Service) ValidateSession(ctx context.Context, sessionID string) (*db.User, error) {
session, err := s.db.GetSession(ctx, sessionID)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, ErrNotAuthenticated
}
return nil, fmt.Errorf("failed to get session: %w", err)
}
if time.Now().After(session.ExpiresAt) {
_ = s.db.DeleteSession(ctx, sessionID) // clean up expired session
return nil, ErrNotAuthenticated
}
user, err := s.db.GetUser(ctx, session.UserID)
if err != nil {
return nil, fmt.Errorf("failed to get user for session: %w", err)
}
return &user, nil
}
// SetSessionCookie sets an http-only, secure session cookie
func SetSessionCookie(w http.ResponseWriter, sessionID string, expiresAt time.Time) {
secure := os.Getenv("ENV") == "production" || os.Getenv("FORCE_SECURE_COOKIES") == "true"
http.SetCookie(w, &http.Cookie{
Name: "session_id",
Value: sessionID,
Expires: expiresAt,
HttpOnly: true,
Secure: secure,
SameSite: http.SameSiteStrictMode,
Path: "/",
})
}
func (s *Service) Logout(ctx context.Context, sessionID string) error {
return s.db.DeleteSession(ctx, sessionID)
}
// ClearSessionCookie invalidates the session cookie
func ClearSessionCookie(w http.ResponseWriter) {
http.SetCookie(w, &http.Cookie{
Name: "session_id",
Value: "",
Expires: time.Unix(0, 0), // epoch to expire immediately
MaxAge: -1,
HttpOnly: true,
Path: "/",
})
}

View File

@@ -1,90 +0,0 @@
package auth
import (
"context"
"errors"
"log"
"net/http"
"mal/templates"
)
type Handler struct {
authService *Service
}
func NewHandler(authService *Service) *Handler {
return &Handler{authService: authService}
}
// HandleLoginPage renders the login form
func (h *Handler) HandleLoginPage(w http.ResponseWriter, r *http.Request) {
if err := templates.GetRenderer().ExecuteTemplate(r.Context(), w, "login.gohtml", map[string]any{
"CurrentPath": r.URL.Path,
}); err != nil {
if !errors.Is(err, context.Canceled) {
log.Printf("render error: %v", err)
}
}
}
// HandleLogin validates credentials and creates a session on success
func (h *Handler) HandleLogin(w http.ResponseWriter, r *http.Request) {
if err := r.ParseForm(); err != nil {
if err := templates.GetRenderer().ExecuteTemplate(r.Context(), w, "login.gohtml", map[string]any{
"Error": "Something went wrong. Please try again.",
"Username": "",
"CurrentPath": r.URL.Path,
}); err != nil {
if !errors.Is(err, context.Canceled) {
log.Printf("render error: %v", err)
}
}
return
}
username := r.FormValue("username")
password := r.FormValue("password")
if username == "" || password == "" {
if err := templates.GetRenderer().ExecuteTemplate(r.Context(), w, "login.gohtml", map[string]any{
"Error": "The email or password is wrong.",
"Username": username,
"CurrentPath": r.URL.Path,
}); err != nil {
if !errors.Is(err, context.Canceled) {
log.Printf("render error: %v", err)
}
}
return
}
session, err := h.authService.Login(r.Context(), username, password)
if err != nil {
if err := templates.GetRenderer().ExecuteTemplate(r.Context(), w, "login.gohtml", map[string]any{
"Error": "The email or password is wrong.",
"Username": username,
"CurrentPath": r.URL.Path,
}); err != nil {
if !errors.Is(err, context.Canceled) {
log.Printf("render error: %v", err)
}
}
return
}
SetSessionCookie(w, session.ID, session.ExpiresAt)
http.Redirect(w, r, "/", http.StatusSeeOther)
}
// HandleLogout destroys the session and clears the cookie
func (h *Handler) HandleLogout(w http.ResponseWriter, r *http.Request) {
cookie, err := r.Cookie("session_id")
if err == nil {
_ = h.authService.Logout(r.Context(), cookie.Value)
}
ClearSessionCookie(w)
http.Redirect(w, r, "/", http.StatusSeeOther)
}

View File

@@ -1,695 +0,0 @@
package playback
import (
"bytes"
"context"
"crypto/aes"
"crypto/cipher"
"crypto/sha256"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"mal/pkg/net/utls"
"net/http"
"net/url"
"strconv"
"strings"
"time"
)
const (
allAnimeBaseURL = "https://api.allanime.day"
allAnimeReferer = "https://allmanga.to/"
allAnimeOrigin = "https://youtu-chan.com"
defaultUserAgent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:109.0) Gecko/20100101 Firefox/121.0"
)
var (
aesKeys = []string{"Xot36i3lK3:v1", "SimtVuagFbGR2K7P"}
)
var allAnimeUTLSClient = &http.Client{
Transport: &utls.UtlsRoundTripper{},
Timeout: 30 * time.Second,
}
type searchResult struct {
ID string
MalID string
Name string
}
type AvailableEpisodes struct {
Sub []string
Dub []string
Raw []string
}
type allAnimeClient struct {
httpClient *http.Client
extractor *providerExtractor
}
func newAllAnimeClient() *allAnimeClient {
return &allAnimeClient{
httpClient: &http.Client{
Timeout: 30 * time.Second,
},
extractor: newProviderExtractor(),
}
}
func (c *allAnimeClient) graphqlRequest(ctx context.Context, query string, variables map[string]any) (map[string]any, error) {
if mode, ok := variables["translationType"].(string); ok {
variables["translationType"] = strings.ToLower(mode)
}
payload := map[string]any{
"query": query,
"variables": variables,
}
body, err := json.Marshal(payload)
if err != nil {
return nil, fmt.Errorf("marshal graphql payload: %w", err)
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, allAnimeBaseURL+"/api", bytes.NewReader(body))
if err != nil {
return nil, fmt.Errorf("create graphql request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Referer", allAnimeReferer)
req.Header.Set("User-Agent", defaultUserAgent)
resp, err := c.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("execute graphql request: %w", err)
}
defer func() { _ = resp.Body.Close() }()
respBody, err := io.ReadAll(io.LimitReader(resp.Body, 2*1024*1024))
if err != nil {
return nil, fmt.Errorf("read graphql response: %w", err)
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("graphql status %d", resp.StatusCode)
}
var parsed map[string]any
if err := json.Unmarshal(respBody, &parsed); err != nil {
return nil, fmt.Errorf("decode graphql response: %w", err)
}
if errs, ok := parsed["errors"].([]any); ok && len(errs) > 0 {
return nil, fmt.Errorf("graphql error: %v", errs[0])
}
return parsed, nil
}
const episodeQueryHash = "d405d0edd690624b66baba3068e0edc3ac90f1597d898a1ec8db4e5c43c00fec"
func (c *allAnimeClient) graphqlRequestWithHash(ctx context.Context, showID, episode, mode string) (map[string]any, error) {
mode = strings.ToLower(mode)
varsJSON := fmt.Sprintf(`{"showId":"%s","translationType":"%s","episodeString":"%s"}`, showID, mode, episode)
extJSON := fmt.Sprintf(`{"persistedQuery":{"version":1,"sha256Hash":"%s"}}`, episodeQueryHash)
apiURL := fmt.Sprintf("%s/api?variables=%s&extensions=%s",
allAnimeBaseURL,
url.QueryEscape(varsJSON),
url.QueryEscape(extJSON))
req, err := http.NewRequestWithContext(ctx, http.MethodGet, apiURL, nil)
if err != nil {
return nil, fmt.Errorf("create GET request: %w", err)
}
req.Header.Set("User-Agent", defaultUserAgent)
req.Header.Set("Accept", "*/*")
req.Header.Set("Accept-Language", "en-US,en;q=0.5")
req.Header.Set("Accept-Encoding", "identity")
req.Header.Set("Referer", allAnimeReferer)
req.Header.Set("Origin", allAnimeOrigin)
req.Header.Set("Sec-Fetch-Dest", "empty")
req.Header.Set("Sec-Fetch-Mode", "cors")
req.Header.Set("Sec-Fetch-Site", "cross-site")
resp, err := allAnimeUTLSClient.Do(req)
if err != nil {
return nil, fmt.Errorf("execute GET request: %w", err)
}
defer func() { _ = resp.Body.Close() }()
respBody, err := io.ReadAll(io.LimitReader(resp.Body, 2*1024*1022))
if err != nil {
return nil, fmt.Errorf("read response: %w", err)
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("GET status %d: %s", resp.StatusCode, string(respBody))
}
var parsed map[string]any
if err := json.Unmarshal(respBody, &parsed); err != nil {
return nil, fmt.Errorf("decode response: %w", err)
}
if errs, ok := parsed["errors"].([]any); ok && len(errs) > 0 {
return nil, fmt.Errorf("graphql error: %v", errs[0])
}
data, ok := parsed["data"].(map[string]any)
if !ok {
return nil, fmt.Errorf("no data in response")
}
var toBeParsed string
if s, ok := data["tobeparsed"].(string); ok && s != "" {
toBeParsed = s
} else if episodeData, ok := data["episode"].(map[string]any); ok {
if s, ok := episodeData["tobeparsed"].(string); ok {
toBeParsed = s
}
}
if toBeParsed != "" {
decrypted, err := decryptTobeparsed(toBeParsed)
if err != nil {
return nil, fmt.Errorf("decrypt tobeparsed: %w", err)
}
var ep map[string]any
if jerr := json.Unmarshal(decrypted, &ep); jerr != nil {
return nil, fmt.Errorf("unmarshal decrypted: %w", jerr)
}
var sourceURLs []any
if srcs, ok := ep["sourceUrls"].([]any); ok {
sourceURLs = srcs
} else if epInner, ok := ep["episode"].(map[string]any); ok {
if srcs, ok := epInner["sourceUrls"].([]any); ok {
sourceURLs = srcs
}
}
if len(sourceURLs) > 0 {
return map[string]any{
"episode": map[string]any{
"sourceUrls": sourceURLs,
},
}, nil
}
}
if episodeData, ok := data["episode"].(map[string]any); ok {
if srcs, ok := episodeData["sourceUrls"].([]any); ok && len(srcs) > 0 {
return parsed, nil
}
}
return nil, fmt.Errorf("no usable data in response")
}
// GetEpisodeSources fetches stream URLs for a given show, episode, and mode (dub/sub).
func (c *allAnimeClient) GetEpisodeSources(ctx context.Context, showID string, episode string, mode string) ([]StreamSource, error) {
episodeQuery := `query($showId: String!, $translationType: VaildTranslationTypeEnumType!, $episodeString: String!) {
episode(showId: $showId, translationType: $translationType, episodeString: $episodeString) {
sourceUrls
}
}`
result, err := c.graphqlRequestWithHash(ctx, showID, episode, mode)
if err == nil {
sources := c.extractSourceURLsFromData(ctx, result)
if len(sources) > 0 {
return sources, nil
}
}
result, err = c.graphqlRequest(ctx, episodeQuery, map[string]any{
"showId": showID,
"translationType": mode,
"episodeString": episode,
})
if err != nil {
return nil, err
}
data, ok := result["data"].(map[string]any)
if !ok {
return nil, fmt.Errorf("invalid source response")
}
rawSourceURLs, ok := data["episode"].(map[string]any)
if !ok {
return nil, fmt.Errorf("invalid episode response")
}
sourceURLs, ok := rawSourceURLs["sourceUrls"].([]any)
if !ok || len(sourceURLs) == 0 {
return nil, fmt.Errorf("no source urls")
}
references := buildSourceReferences(sourceURLs)
if len(references) == 0 {
return nil, fmt.Errorf("no source references")
}
out := make([]StreamSource, 0, len(references))
for _, ref := range references {
target := strings.TrimSpace(ref.URL)
if target == "" {
continue
}
if strings.HasPrefix(target, "http://") || strings.HasPrefix(target, "https://") {
sourceType := detectStreamType(target)
if sourceType == "unknown" {
sourceType = detectEmbedType(target)
}
out = append(out, buildStreamSource(target, sourceType, ref.Name))
continue
}
decoded := decodeSourceURL(target)
if decoded == "" {
continue
}
if strings.HasPrefix(decoded, "http://") || strings.HasPrefix(decoded, "https://") {
sourceType := detectStreamType(decoded)
if sourceType == "unknown" {
sourceType = detectEmbedType(decoded)
}
out = append(out, buildStreamSource(decoded, sourceType, ref.Name))
continue
}
if !strings.HasPrefix(decoded, "/") {
decoded = "/" + decoded
}
extracted, err := c.extractor.ExtractVideoLinks(ctx, decoded)
if err != nil {
continue
}
out = append(out, extracted...)
}
if len(out) == 0 {
return nil, fmt.Errorf("no playable sources extracted")
}
return out, nil
}
func (c *allAnimeClient) extractSourceURLsFromData(ctx context.Context, data map[string]any) []StreamSource {
episodeData, ok := data["episode"].(map[string]any)
if !ok {
return nil
}
sourceURLs, ok := episodeData["sourceUrls"].([]any)
if !ok || len(sourceURLs) == 0 {
return nil
}
references := buildSourceReferences(sourceURLs)
if len(references) == 0 {
return nil
}
out := make([]StreamSource, 0, len(references))
for _, ref := range references {
target := strings.TrimSpace(ref.URL)
if target == "" {
continue
}
if strings.HasPrefix(target, "http://") || strings.HasPrefix(target, "https://") {
sourceType := detectStreamType(target)
if sourceType == "unknown" {
sourceType = detectEmbedType(target)
}
out = append(out, buildStreamSource(target, sourceType, ref.Name))
continue
}
decoded := decodeSourceURL(target)
if decoded == "" {
continue
}
if strings.HasPrefix(decoded, "http://") || strings.HasPrefix(decoded, "https://") {
sourceType := detectStreamType(decoded)
if sourceType == "unknown" {
sourceType = detectEmbedType(decoded)
}
out = append(out, buildStreamSource(decoded, sourceType, ref.Name))
continue
}
if !strings.HasPrefix(decoded, "/") {
decoded = "/" + decoded
}
extracted, err := c.extractor.ExtractVideoLinks(ctx, decoded)
if err != nil {
continue
}
out = append(out, extracted...)
}
return out
}
func buildStreamSource(url, sourceType, provider string) StreamSource {
return StreamSource{
URL: url,
Provider: provider,
Type: sourceType,
Referer: allAnimeReferer,
}
}
type sourceReference struct {
URL string
Name string
}
// buildSourceReferences orders source URLs by provider priority, deduplicating entries.
func buildSourceReferences(rawSourceURLs []any) []sourceReference {
priorityOrder := []string{"default", "yt-mp4", "s-mp4", "luf-mp4"}
prioritySet := map[string]struct{}{"default": {}, "yt-mp4": {}, "s-mp4": {}, "luf-mp4": {}}
prioritized := make(map[string]sourceReference)
fallback := make([]sourceReference, 0, len(rawSourceURLs))
seen := make(map[string]struct{})
for _, source := range rawSourceURLs {
item, ok := source.(map[string]any)
if !ok {
continue
}
sourceURL, _ := item["sourceUrl"].(string)
sourceName, _ := item["sourceName"].(string)
sourceURL = strings.TrimSpace(sourceURL)
sourceName = strings.TrimSpace(sourceName)
if sourceURL == "" {
continue
}
if _, exists := seen[sourceURL]; exists {
continue
}
seen[sourceURL] = struct{}{}
ref := sourceReference{URL: sourceURL, Name: sourceName}
normalized := strings.ToLower(sourceName)
// separate prioritized providers from fallback
if _, prioritizedProvider := prioritySet[normalized]; prioritizedProvider {
if _, exists := prioritized[normalized]; !exists {
prioritized[normalized] = ref
}
continue
}
fallback = append(fallback, ref)
}
// output: prioritized in order, then fallback
ordered := make([]sourceReference, 0, len(prioritized)+len(fallback))
for _, provider := range priorityOrder {
if ref, ok := prioritized[provider]; ok {
ordered = append(ordered, ref)
}
}
ordered = append(ordered, fallback...)
return ordered
}
func decryptTobeparsed(encoded string) ([]byte, error) {
raw, err := base64.StdEncoding.DecodeString(encoded)
if err != nil {
return nil, fmt.Errorf("base64 decode failed: %w", err)
}
if len(raw) < 29 {
return nil, fmt.Errorf("encrypted payload too short")
}
version := raw[0]
iv := raw[1:13]
cipherText := raw[13 : len(raw)-16]
for _, keyStr := range aesKeys {
key := sha256.Sum256([]byte(keyStr))
block, err := aes.NewCipher(key[:])
if err != nil {
continue
}
if version == 1 {
plainText := tryDecryptCTR(block, iv, cipherText)
if json.Valid(plainText) {
return plainText, nil
}
}
gcm, err := cipher.NewGCM(block)
if err == nil {
tag := raw[len(raw)-16:]
combined := append(append([]byte{}, cipherText...), tag...)
plainText, openErr := gcm.Open(nil, iv, combined, nil)
if openErr == nil && json.Valid(plainText) {
return plainText, nil
}
}
}
return nil, fmt.Errorf("decryption failed")
}
func tryDecryptCTR(block cipher.Block, iv []byte, cipherText []byte) []byte {
ctrIV := append([]byte{}, iv...)
ctrIV = append(ctrIV, 0x00, 0x00, 0x00, 0x02)
ctr := cipher.NewCTR(block, ctrIV)
plainText := make([]byte, len(cipherText))
ctr.XORKeyStream(plainText, cipherText)
return plainText
}
// Search queries AllAnime for shows matching the given search term.
func (c *allAnimeClient) Search(ctx context.Context, query string, mode string) ([]searchResult, error) {
graphqlQuery := `query($search: SearchInput, $limit: Int, $page: Int, $translationType: VaildTranslationTypeEnumType, $countryOrigin: VaildCountryOriginEnumType) {
shows(search: $search, limit: $limit, page: $page, translationType: $translationType, countryOrigin: $countryOrigin) {
edges {
_id
malId
name
}
}
}`
variables := map[string]any{
"search": map[string]any{
"allowAdult": false,
"allowUnknown": false,
"query": query,
},
"limit": 40,
"page": 1,
"translationType": mode,
"countryOrigin": "ALL",
}
result, err := c.graphqlRequest(ctx, graphqlQuery, variables)
if err != nil {
return nil, err
}
data, ok := result["data"].(map[string]any)
if !ok {
return nil, fmt.Errorf("invalid search response")
}
shows, ok := data["shows"].(map[string]any)
if !ok {
return nil, fmt.Errorf("invalid shows payload")
}
edges, ok := shows["edges"].([]any)
if !ok {
return nil, fmt.Errorf("invalid search edges")
}
out := make([]searchResult, 0, len(edges))
for _, edge := range edges {
item, ok := edge.(map[string]any)
if !ok {
continue
}
id, _ := item["_id"].(string)
malID, _ := item["malId"].(string)
name, _ := item["name"].(string)
if unquoted, err := strconv.Unquote("\"" + name + "\""); err == nil {
name = unquoted
}
name = strings.TrimSpace(name)
if id == "" {
continue
}
out = append(out, searchResult{ID: id, MalID: malID, Name: name})
}
return out, nil
}
// GetAvailableEpisodes returns the count of sub/dub/raw episodes available for a show.
func (c *allAnimeClient) GetAvailableEpisodes(ctx context.Context, showID string) (AvailableEpisodes, error) {
graphqlQuery := `query($showId: String!) {
show(_id: $showId) {
availableEpisodesDetail
lastEpisodeInfo
}
}`
result, err := c.graphqlRequest(ctx, graphqlQuery, map[string]any{"showId": showID})
if err != nil {
return AvailableEpisodes{}, err
}
data, ok := result["data"].(map[string]any)
if !ok {
return AvailableEpisodes{}, fmt.Errorf("invalid response")
}
show, ok := data["show"].(map[string]any)
if !ok || show == nil {
return AvailableEpisodes{}, fmt.Errorf("show not found")
}
detail, ok := show["availableEpisodesDetail"].(map[string]any)
if !ok {
return AvailableEpisodes{}, fmt.Errorf("invalid detail")
}
var count AvailableEpisodes
if sub, ok := detail["sub"].([]any); ok {
for _, s := range sub {
if str, ok := s.(string); ok {
count.Sub = append(count.Sub, str)
}
}
}
if dub, ok := detail["dub"].([]any); ok {
for _, s := range dub {
if str, ok := s.(string); ok {
count.Dub = append(count.Dub, str)
}
}
}
if raw, ok := detail["raw"].([]any); ok {
for _, s := range raw {
if str, ok := s.(string); ok {
count.Raw = append(count.Raw, str)
}
}
}
return count, nil
}
func decodeSourceURL(encoded string) string {
if encoded == "" {
return ""
}
encoded = strings.TrimPrefix(encoded, "--")
substitutions := map[string]string{
"79": "A", "7a": "B", "7b": "C", "7c": "D", "7d": "E",
"7e": "F", "7f": "G", "70": "H", "71": "I", "72": "J",
"73": "K", "74": "L", "75": "M", "76": "N", "77": "O",
"68": "P", "69": "Q", "6a": "R", "6b": "S", "6c": "T",
"6d": "U", "6e": "V", "6f": "W", "60": "X", "61": "Y",
"62": "Z",
"59": "a", "5a": "b", "5b": "c", "5c": "d", "5d": "e",
"5e": "f", "5f": "g", "50": "h", "51": "i", "52": "j",
"53": "k", "54": "l", "55": "m", "56": "n", "57": "o",
"48": "p", "49": "q", "4a": "r", "4b": "s", "4c": "t",
"4d": "u", "4e": "v", "4f": "w", "40": "x", "41": "y",
"42": "z",
"08": "0", "09": "1", "0a": "2", "0b": "3", "0c": "4",
"0d": "5", "0e": "6", "0f": "7", "00": "8", "01": "9",
"15": "-", "16": ".", "67": "_", "46": "~", "02": ":",
"17": "/", "07": "?", "1b": "#", "63": "[", "65": "]",
"78": "@", "19": "!", "1c": "$", "1e": "&", "10": "(",
"11": ")", "12": "*", "13": "+", "14": ",", "03": ";",
"05": "=", "1d": "%",
}
var result strings.Builder
for idx := 0; idx < len(encoded); {
if idx+2 <= len(encoded) {
pair := encoded[idx : idx+2]
if sub, ok := substitutions[pair]; ok {
result.WriteString(sub)
idx += 2
continue
}
}
result.WriteByte(encoded[idx])
idx++
}
decoded := result.String()
if strings.Contains(decoded, "/clock") && !strings.Contains(decoded, "/clock.json") {
decoded = strings.Replace(decoded, "/clock", "/clock.json", 1)
}
return decoded
}
func detectStreamType(sourceURL string) string {
lower := strings.ToLower(sourceURL)
if strings.Contains(lower, ".m3u8") || strings.Contains(lower, "master.m3u8") {
return "m3u8"
}
if strings.Contains(lower, ".mp4") {
return "mp4"
}
return "unknown"
}
func detectEmbedType(rawURL string) string {
lower := strings.ToLower(rawURL)
embedHosts := []string{"streamwish", "streamsb", "mp4upload", "ok.ru", "gogoplay", "streamlare"}
for _, host := range embedHosts {
if strings.Contains(lower, host) {
return "embed"
}
}
return "unknown"
}

View File

@@ -1,454 +0,0 @@
package playback
import (
"context"
"crypto/aes"
"encoding/json"
"testing"
)
func TestDecodeSourceURL(t *testing.T) {
t.Parallel()
tests := []struct {
name string
encoded string
want string
}{
{
name: "empty returns empty",
encoded: "",
want: "",
},
{
name: "with double prefix stripped",
encoded: "--example.com/video.mp4",
want: "example.com/video.mp4",
},
{
name: "hex substitution",
encoded: "7aexample",
want: "Bexample",
},
{
name: "mixed substitution",
encoded: "79url7a01",
want: "AurlB9",
},
{
name: "clock replacement",
encoded: "/clock",
want: "/clock.json",
},
{
name: "no clock replacement if already json",
encoded: "/clock.json",
want: "/clock.json",
},
{
name: "complex url",
encoded: "--79stream7acom",
want: "AstreamBcom",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
got := decodeSourceURL(tt.encoded)
if got != tt.want {
t.Errorf("decodeSourceURL(%q) = %q, want %q", tt.encoded, got, tt.want)
}
})
}
}
func TestDetectStreamType(t *testing.T) {
t.Parallel()
tests := []struct {
name string
url string
wantType string
}{
{
name: "m3u8 extension",
url: "https://example.com/video.m3u8",
wantType: "m3u8",
},
{
name: "master m3u8",
url: "https://example.com/master.m3u8",
wantType: "m3u8",
},
{
name: "mp4 extension",
url: "https://example.com/video.mp4",
wantType: "mp4",
},
{
name: "unknown",
url: "https://example.com/video.avi",
wantType: "unknown",
},
{
name: "empty returns unknown",
url: "",
wantType: "unknown",
},
{
name: "case insensitive - M3U8",
url: "https://example.com/MASTER.M3U8",
wantType: "m3u8",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
got := detectStreamType(tt.url)
if got != tt.wantType {
t.Errorf("detectStreamType(%q) = %q, want %q", tt.url, got, tt.wantType)
}
})
}
}
func TestDetectEmbedType(t *testing.T) {
t.Parallel()
tests := []struct {
name string
url string
wantType string
}{
{
name: "streamwish",
url: "https://streamwish.com/e/abc123",
wantType: "embed",
},
{
name: "streamsb",
url: "https://streamsb.com/e/abc123",
wantType: "embed",
},
{
name: "mp4upload",
url: "https://mp4upload.com/e/abc123",
wantType: "embed",
},
{
name: "ok.ru",
url: "https://ok.ru/video/123",
wantType: "embed",
},
{
name: "gogoplay",
url: "https://gogoplay.io/embed/123",
wantType: "embed",
},
{
name: "streamlare",
url: "https://streamlare.com/e/abc",
wantType: "embed",
},
{
name: "unknown host",
url: "https://unknown.com/video",
wantType: "unknown",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
got := detectEmbedType(tt.url)
if got != tt.wantType {
t.Errorf("detectEmbedType(%q) = %q, want %q", tt.url, got, tt.wantType)
}
})
}
}
func TestBuildStreamSource(t *testing.T) {
t.Parallel()
t.Run("constructs with correct defaults", func(t *testing.T) {
got := buildStreamSource("https://example.com/video.mp4", "mp4", "test-provider")
if got.URL != "https://example.com/video.mp4" {
t.Errorf("URL = %q, want %q", got.URL, "https://example.com/video.mp4")
}
if got.Provider != "test-provider" {
t.Errorf("Provider = %q, want %q", got.Provider, "test-provider")
}
if got.Type != "mp4" {
t.Errorf("Type = %q, want %q", got.Type, "mp4")
}
if got.Referer != allAnimeReferer {
t.Errorf("Referer = %q, want %q", got.Referer, allAnimeReferer)
}
})
}
func TestBuildSourceReferences(t *testing.T) {
t.Parallel()
tests := []struct {
name string
rawURLs []any
wantRefs []sourceReference
}{
{
name: "empty returns empty",
rawURLs: nil,
wantRefs: nil,
},
{
name: "filters empty URLs",
rawURLs: []any{
map[string]any{"sourceUrl": "", "sourceName": "test"},
map[string]any{"sourceUrl": "https://example.com/v.mp4", "sourceName": "default"},
},
wantRefs: []sourceReference{
{URL: "https://example.com/v.mp4", Name: "default"},
},
},
{
name: "deduplicates URLs",
rawURLs: []any{
map[string]any{"sourceUrl": "https://example.com/v.mp4", "sourceName": "test"},
map[string]any{"sourceUrl": "https://example.com/v.mp4", "sourceName": "test2"},
},
wantRefs: []sourceReference{
{URL: "https://example.com/v.mp4", Name: "test"},
},
},
{
name: "prioritizes default provider",
rawURLs: []any{
map[string]any{"sourceUrl": "https://a.com/v.mp4", "sourceName": "fallback"},
map[string]any{"sourceUrl": "https://b.com/v.mp4", "sourceName": "default"},
map[string]any{"sourceUrl": "https://c.com/v.mp4", "sourceName": "yt-mp4"},
},
wantRefs: []sourceReference{
{URL: "https://b.com/v.mp4", Name: "default"},
{URL: "https://c.com/v.mp4", Name: "yt-mp4"},
{URL: "https://a.com/v.mp4", Name: "fallback"},
},
},
{
name: "skips invalid map entries",
rawURLs: []any{
"invalid",
123,
map[string]any{"sourceUrl": "https://example.com/v.mp4"},
},
wantRefs: []sourceReference{
{URL: "https://example.com/v.mp4", Name: ""},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
got := buildSourceReferences(tt.rawURLs)
if len(got) != len(tt.wantRefs) {
t.Errorf("got %d refs, want %d", len(got), len(tt.wantRefs))
return
}
for i, want := range tt.wantRefs {
if got[i].URL != want.URL {
t.Errorf("ref[%d].URL = %q, want %q", i, got[i].URL, want.URL)
}
if got[i].Name != want.Name {
t.Errorf("ref[%d].Name = %q, want %q", i, got[i].Name, want.Name)
}
}
})
}
}
func TestBuildSourceReferencesOrder(t *testing.T) {
t.Parallel()
rawURLs := []any{
map[string]any{"sourceUrl": "https://s.com/v.mp4", "sourceName": "s-mp4"},
map[string]any{"sourceUrl": "https://default.com/v.mp4", "sourceName": "default"},
map[string]any{"sourceUrl": "https://luf.com/v.mp4", "sourceName": "luf-mp4"},
map[string]any{"sourceUrl": "https://yt.com/v.mp4", "sourceName": "yt-mp4"},
}
got := buildSourceReferences(rawURLs)
wantOrder := []string{"default", "yt-mp4", "s-mp4", "luf-mp4"}
if len(got) != len(wantOrder) {
t.Fatalf("got %d refs, want %d", len(got), len(wantOrder))
}
for i, wantName := range wantOrder {
if got[i].Name != wantName {
t.Errorf("ref[%d].Name = %q, want %q (priority order: default > yt-mp4 > s-mp4 > luf-mp4)", i, got[i].Name, wantName)
}
}
}
func TestIsLikelyM3U8(t *testing.T) {
t.Parallel()
tests := []struct {
name string
input []byte
want bool
}{
{
name: "valid m3u8",
input: []byte("#EXTM3U\n#EXT-X-VERSION:3"),
want: true,
},
{
name: "with leading spaces",
input: []byte(" #EXTM3U\n"),
want: true,
},
{
name: "empty",
input: []byte{},
want: false,
},
{
name: "not m3u8",
input: []byte("<?xml version=\"1.0\""),
want: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
got := isLikelyM3U8(tt.input)
if got != tt.want {
t.Errorf("isLikelyM3U8(%q) = %v, want %v", tt.input, got, tt.want)
}
})
}
}
func TestIsLikelyMP4(t *testing.T) {
t.Parallel()
tests := []struct {
name string
input []byte
want bool
}{
{
name: "ftyp at offset 4",
input: []byte{0x00, 0x00, 0x00, 0x1c, 'f', 't', 'y', 'p', 0x00, 0x00, 0x00, 0x00},
want: true,
},
{
name: "short payload",
input: []byte{0x00, 0x00},
want: false,
},
{
name: "not mp4",
input: []byte{0x00, 0x00, 0x00, 0x1c, 'f', 'o', 'o', 'b'},
want: false,
},
{
name: "empty",
input: []byte{},
want: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
got := isLikelyMP4(tt.input)
if got != tt.want {
t.Errorf("isLikelyMP4(%q) = %v, want %v", tt.input, got, tt.want)
}
})
}
}
func TestDecryptTobeparsed(t *testing.T) {
t.Parallel()
t.Run("valid encrypted payload with first key", func(t *testing.T) {
payload := "AQAAAAABc2S7yj94zW6j4A8d9D6C3qFvYjR1hI4L6z1J3qKj5pXhKj"
decrypted, err := decryptTobeparsed(payload)
if err == nil {
var result map[string]any
if err := json.Unmarshal(decrypted, &result); err != nil {
t.Logf("decrypted (not valid json): %s", string(decrypted))
} else {
t.Logf("decrypted: %+v", result)
}
} else {
t.Logf("expected decryption to succeed or fail gracefully: %v", err)
}
})
t.Run("payload too short returns error", func(t *testing.T) {
payload := "short"
_, err := decryptTobeparsed(payload)
if err == nil {
t.Error("expected error for short payload")
}
})
t.Run("invalid base64 returns error", func(t *testing.T) {
_, err := decryptTobeparsed("not-valid-base64!!!")
if err == nil {
t.Error("expected error for invalid base64")
}
})
}
func TestTryDecryptCTR(t *testing.T) {
t.Parallel()
t.Run("decrypts correctly", func(t *testing.T) {
key := make([]byte, 32)
for i := range key {
key[i] = byte(i)
}
block, err := aes.NewCipher(key)
if err != nil {
t.Fatalf("failed to create cipher: %v", err)
}
iv := []byte{0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b}
cipherText := []byte("test plaintext ")
plainText := tryDecryptCTR(block, iv, cipherText)
_ = plainText
})
}
func TestAllAnimeClientImplementsInterfaces(t *testing.T) {
t.Parallel()
var (
_ interface {
Search(context.Context, string, string) ([]searchResult, error)
} = &allAnimeClient{}
_ interface {
GetEpisodeSources(context.Context, string, string, string) ([]StreamSource, error)
} = &allAnimeClient{}
_ interface {
GetAvailableEpisodes(context.Context, string) (AvailableEpisodes, error)
} = &allAnimeClient{}
)
t.Log("allAnimeClient implements required interfaces")
}

View File

@@ -1,481 +0,0 @@
package playback
import (
"context"
"database/sql"
"encoding/json"
"errors"
"fmt"
"io"
"log"
"maps"
"net/http"
"sort"
"strconv"
"strings"
"mal/integrations/jikan"
"mal/internal/db"
"mal/internal/middleware"
"mal/templates"
)
type Handler struct {
svc *Service
jikanClient *jikan.Client // client for Jikan API (MyAnimeList)
}
func NewHandler(svc *Service, jikanClient *jikan.Client) *Handler {
return &Handler{svc: svc, jikanClient: jikanClient}
}
// renderNotFoundPage renders the 404 page.
func renderNotFoundPage(r *http.Request, w http.ResponseWriter) {
w.WriteHeader(http.StatusNotFound)
if err := templates.GetRenderer().ExecuteTemplate(r.Context(), w, "not_found.gohtml", map[string]any{
"CurrentPath": r.URL.Path,
}); err != nil {
if !errors.Is(err, context.Canceled) {
log.Printf("render error: %v", err)
}
}
}
// HandleWatchPage serves the anime watch page.
func (h *Handler) HandleWatchPage(w http.ResponseWriter, r *http.Request) {
// path format: /anime/123/watch
parts := strings.Split(r.URL.Path, "/")
if len(parts) < 4 {
renderNotFoundPage(r, w)
return
}
idStr := parts[2]
id, err := strconv.Atoi(idStr)
if err != nil {
renderNotFoundPage(r, w)
return
}
anime, err := h.jikanClient.GetAnimeByID(r.Context(), id)
if err != nil {
renderNotFoundPage(r, w)
return
}
allEpisodes, err := h.jikanClient.GetAllEpisodes(r.Context(), id)
if err != nil {
log.Printf("failed to fetch episodes: %v", err)
}
user := middleware.GetUser(r.Context())
// fetch user's watchlist to highlight episodes and show status
var watchlistIDs []int64
var watchlistStatus string
if user != nil {
watchlist, _ := h.svc.db.GetUserWatchList(r.Context(), user.ID)
watchlistIDs = make([]int64, len(watchlist))
for i, entry := range watchlist {
watchlistIDs[i] = entry.AnimeID
if entry.AnimeID == int64(id) {
watchlistStatus = entry.Status
}
}
}
// resolve current episode: query param > saved progress > first episode
currentEpID := r.URL.Query().Get("ep")
if currentEpID == "" {
if user != nil {
entry, err := h.svc.db.GetWatchListEntry(r.Context(), db.GetWatchListEntryParams{
UserID: user.ID,
AnimeID: int64(id),
})
if err == nil && entry.CurrentEpisode.Valid {
currentEpID = strconv.FormatInt(entry.CurrentEpisode.Int64, 10)
// redirect to include ep param for consistent URLs
http.Redirect(w, r, fmt.Sprintf("/anime/%d/watch?ep=%s", id, currentEpID), http.StatusFound)
return
}
}
currentEpID = "1"
}
mode := r.URL.Query().Get("mode")
userID := ""
if user != nil {
userID = user.ID
}
titleCandidates := []string{anime.Title}
if anime.TitleEnglish != "" && anime.TitleEnglish != anime.Title {
titleCandidates = append(titleCandidates, anime.TitleEnglish)
}
if anime.TitleJapanese != "" {
titleCandidates = append(titleCandidates, anime.TitleJapanese)
}
watchData, err := h.svc.BuildWatchPageData(r.Context(), id, titleCandidates, currentEpID, mode, userID)
if err != nil {
log.Printf("watch data error: %v", err)
}
// Fill gaps with placeholder episodes if fallback has more
if watchData.FallbackEpisodes != nil {
maxCount := 0
for _, count := range watchData.FallbackEpisodes {
if count > maxCount {
maxCount = count
}
}
epMap := make(map[int]jikan.Episode)
for _, ep := range allEpisodes {
epMap[ep.MalID] = ep
}
if maxCount > 0 {
var filled []jikan.Episode
for i := 1; i <= maxCount; i++ {
if ep, ok := epMap[i]; ok {
filled = append(filled, ep)
} else {
filled = append(filled, jikan.Episode{
MalID: i,
Episode: fmt.Sprintf("Episode %d", i),
Title: fmt.Sprintf("Episode %d", i),
})
}
}
allEpisodes = filled
}
}
sort.Slice(allEpisodes, func(i, j int) bool {
return allEpisodes[i].MalID < allEpisodes[j].MalID
})
// fetch relations to build season/movie list
relations, err := h.jikanClient.GetFullRelations(r.Context(), id)
if err != nil {
log.Printf("failed to fetch relations: %v", err)
}
type SeasonEntry struct {
MalID int
Title string
Prefix string
IsCurrent bool
}
var tvSeasons []SeasonEntry
var movies []SeasonEntry
counter := 1
for _, rel := range relations {
if strings.ToLower(rel.Anime.Type) == "tv" {
tvSeasons = append(tvSeasons, SeasonEntry{
MalID: rel.Anime.MalID,
Title: rel.Anime.DisplayTitle(),
Prefix: fmt.Sprintf("%02d", counter),
IsCurrent: rel.IsCurrent,
})
counter++
}
}
for _, rel := range relations {
if strings.ToLower(rel.Anime.Type) == "movie" {
movies = append(movies, SeasonEntry{
MalID: rel.Anime.MalID,
Title: rel.Anime.DisplayTitle(),
Prefix: "Mov",
IsCurrent: rel.IsCurrent,
})
}
}
allSeasons := append(tvSeasons, movies...)
if err := templates.GetRenderer().ExecuteTemplate(r.Context(), w, "watch.gohtml", map[string]any{
"Anime": anime,
"Episodes": allEpisodes,
"WatchData": watchData,
"User": user,
"CurrentPath": r.URL.Path,
"CurrentEpID": currentEpID,
"WatchlistIDs": watchlistIDs,
"WatchlistStatus": watchlistStatus,
"Seasons": allSeasons,
}); err != nil {
if !errors.Is(err, context.Canceled) {
log.Printf("render error: %v", err)
}
}
}
// HandleProxy proxies media requests through the backend to avoid CORS and hide source URLs.
func (h *Handler) HandleProxy(w http.ResponseWriter, r *http.Request) {
token := r.URL.Query().Get("token")
if token == "" {
http.Error(w, "missing token", http.StatusBadRequest)
return
}
// determine proxy scope based on URL suffix
scope := proxyScopeStream
if strings.HasSuffix(r.URL.Path, "/segment") {
scope = proxyScopeSegment
} else if strings.HasSuffix(r.URL.Path, "/subtitle") {
scope = proxyScopeSubtitle
}
targetURL, referer, err := h.svc.resolveProxyToken(r.Context(), token, scope)
if err != nil {
http.Error(w, "invalid token", http.StatusForbidden)
return
}
rangeHeader := r.Header.Get("Range")
statusCode, headers, content, bodyReader, err := h.svc.ProxyStream(r.Context(), targetURL, referer, rangeHeader)
if err != nil {
log.Printf("proxy error for %s: %v", targetURL, err)
http.Error(w, "proxy failed", http.StatusBadGateway)
return
}
maps.Copy(w.Header(), headers)
w.WriteHeader(statusCode)
if bodyReader != nil {
defer func() { _ = bodyReader.Close() }()
_, _ = io.Copy(w, bodyReader)
} else {
_, _ = w.Write(content)
}
}
// HandleSaveProgress saves playback progress for a user.
func (h *Handler) HandleSaveProgress(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
user := middleware.GetUser(r.Context())
if user == nil {
http.Error(w, "Unauthorized", http.StatusUnauthorized)
return
}
var req struct {
MalID int64 `json:"mal_id"`
Episode int `json:"episode"`
TimeSeconds float64 `json:"time_seconds"`
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, "Invalid request", http.StatusBadRequest)
return
}
// We fetch the anime info to seed the DB if it's the first time saving progress for this show
anime, err := h.jikanClient.GetAnimeByID(r.Context(), int(req.MalID))
var seed *db.UpsertAnimeParams
if err == nil {
seed = &db.UpsertAnimeParams{
ID: int64(anime.MalID),
TitleOriginal: anime.Title,
TitleEnglish: sql.NullString{String: anime.TitleEnglish, Valid: anime.TitleEnglish != ""},
TitleJapanese: sql.NullString{String: anime.TitleJapanese, Valid: anime.TitleJapanese != ""},
ImageUrl: anime.ImageURL(),
Airing: sql.NullBool{Bool: anime.Airing, Valid: true},
DurationSeconds: sql.NullFloat64{Float64: anime.DurationSeconds(), Valid: anime.DurationSeconds() > 0},
}
}
if err := h.svc.SaveProgress(r.Context(), user.ID, req.MalID, req.Episode, req.TimeSeconds, seed); err != nil {
log.Printf("failed to save progress: %v", err)
http.Error(w, "Internal server error", http.StatusInternalServerError)
return
}
w.WriteHeader(http.StatusOK)
}
// HandleCompleteAnime marks an anime as completed for a user.
func (h *Handler) HandleCompleteAnime(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
user := middleware.GetUser(r.Context())
if user == nil {
http.Error(w, "Unauthorized", http.StatusUnauthorized)
return
}
var req struct {
MalID int64 `json:"mal_id"`
Episode int `json:"episode"`
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, "Invalid request", http.StatusBadRequest)
return
}
// Seed anime info if needed
anime, err := h.jikanClient.GetAnimeByID(r.Context(), int(req.MalID))
var seed *db.UpsertAnimeParams
if err == nil {
seed = &db.UpsertAnimeParams{
ID: int64(anime.MalID),
TitleOriginal: anime.Title,
TitleEnglish: sql.NullString{String: anime.TitleEnglish, Valid: anime.TitleEnglish != ""},
TitleJapanese: sql.NullString{String: anime.TitleJapanese, Valid: anime.TitleJapanese != ""},
ImageUrl: anime.ImageURL(),
Airing: sql.NullBool{Bool: anime.Airing, Valid: true},
DurationSeconds: sql.NullFloat64{Float64: anime.DurationSeconds(), Valid: anime.DurationSeconds() > 0},
}
}
if err := h.svc.CompleteAnime(r.Context(), user.ID, req.MalID, req.Episode, seed); err != nil {
log.Printf("failed to complete anime: %v", err)
http.Error(w, "Internal server error", http.StatusInternalServerError)
return
}
w.WriteHeader(http.StatusOK)
}
// HandleEpisodeData returns episode streaming data for the player.
func (h *Handler) HandleEpisodeData(w http.ResponseWriter, r *http.Request) {
// path: /api/watch/episode/{animeId}/{episodeId}
parts := strings.Split(r.URL.Path, "/")
if len(parts) < 6 {
http.Error(w, "invalid path", http.StatusBadRequest)
return
}
animeID, err := strconv.Atoi(parts[4])
if err != nil {
http.Error(w, "invalid animeId", http.StatusBadRequest)
return
}
episodeID := parts[5]
user := middleware.GetUser(r.Context())
userID := ""
if user != nil {
userID = user.ID
}
anime, err := h.jikanClient.GetAnimeByID(r.Context(), animeID)
if err != nil {
http.Error(w, "anime not found", http.StatusNotFound)
return
}
titleCandidates := []string{anime.Title}
if anime.TitleEnglish != "" && anime.TitleEnglish != anime.Title {
titleCandidates = append(titleCandidates, anime.TitleEnglish)
}
if anime.TitleJapanese != "" {
titleCandidates = append(titleCandidates, anime.TitleJapanese)
}
watchData, err := h.svc.BuildWatchPageData(r.Context(), animeID, titleCandidates, episodeID, "", userID)
if err != nil {
http.Error(w, "failed to build watch data", http.StatusBadGateway)
return
}
w.Header().Set("Content-Type", "application/json")
if err := writeJSON(w, map[string]any{
"mal_id": watchData.MalID,
"title": watchData.Title,
"current_episode": watchData.CurrentEpisode,
"total_episodes": anime.Episodes,
"initial_mode": watchData.InitialMode,
"token": "", // The token might be per-source, wait, in Go it was per-mode?
"available_modes": watchData.AvailableModes,
"mode_sources": watchData.ModeSources,
"segments": watchData.Segments,
"episode_title": "", // Find episode title if possible
}); err != nil {
if !errors.Is(err, context.Canceled) {
log.Printf("watch page encode error: %v", err)
}
}
}
// HandleEpisodeThumbnails returns episode list for the thumbnail strip.
func (h *Handler) HandleEpisodeThumbnails(w http.ResponseWriter, r *http.Request) {
// path: /api/watch/thumbnails/{animeId}
parts := strings.Split(r.URL.Path, "/")
if len(parts) < 5 {
http.Error(w, "invalid path", http.StatusBadRequest)
return
}
id, err := strconv.Atoi(parts[4])
if err != nil {
http.Error(w, "invalid animeId", http.StatusBadRequest)
return
}
allEpisodes, err := h.jikanClient.GetAllEpisodes(r.Context(), id)
if err != nil {
log.Printf("failed to fetch thumbnails/episodes: %v", err)
}
// Fill gaps if anime has known total
anime, _ := h.jikanClient.GetAnimeByID(r.Context(), id)
if anime.Episodes > 0 && anime.Episodes > len(allEpisodes) {
epMap := make(map[int]jikan.Episode)
for _, ep := range allEpisodes {
epMap[ep.MalID] = ep
}
var filled []jikan.Episode
for i := 1; i <= anime.Episodes; i++ {
if ep, ok := epMap[i]; ok {
filled = append(filled, ep)
} else {
filled = append(filled, jikan.Episode{
MalID: i,
Episode: fmt.Sprintf("Episode %d", i),
Title: fmt.Sprintf("Episode %d", i),
})
}
}
allEpisodes = filled
}
type Result struct {
MalID int `json:"mal_id"`
Title string `json:"title"`
}
results := make([]Result, len(allEpisodes))
for i, ep := range allEpisodes {
results[i] = Result{
MalID: ep.MalID,
Title: ep.Title,
}
}
w.Header().Set("Content-Type", "application/json")
if err := writeJSON(w, results); err != nil {
if !errors.Is(err, context.Canceled) {
log.Printf("thumbnails encode error: %v", err)
}
}
}
func writeJSON(w http.ResponseWriter, v any) error {
return json.NewEncoder(w).Encode(v)
}

View File

@@ -1,26 +0,0 @@
package playback
import (
"context"
"net/http"
)
// doProxiedRequest performs an HTTP GET with standard playback headers.
func doProxiedRequest(ctx context.Context, client *http.Client, url string, referer string) (*http.Response, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return nil, err
}
req.Header.Set("User-Agent", defaultUserAgent)
if referer != "" {
req.Header.Set("Referer", referer)
}
resp, err := client.Do(req)
if err != nil {
return nil, err
}
return resp, nil
}

View File

@@ -1,144 +0,0 @@
package playback
import (
"context"
"database/sql"
"errors"
"fmt"
"strings"
"github.com/google/uuid"
"mal/internal/db"
)
// SaveProgress updates watch progress and continue-watching state in a transaction.
func (s *Service) SaveProgress(ctx context.Context, userID string, animeID int64, episode int, timeSeconds float64, animeSeed *db.UpsertAnimeParams) error {
if strings.TrimSpace(userID) == "" || animeID <= 0 || episode <= 0 {
return errors.New("invalid save progress input")
}
txQueries, tx, err := db.BeginTx(ctx, s.sqlDB)
if err != nil {
return err
}
defer func() { _ = tx.Rollback() }()
if animeSeed != nil {
if _, err := txQueries.UpsertAnime(ctx, *animeSeed); err != nil {
return fmt.Errorf("failed to save anime reference: %w", err)
}
}
watchListEntry, watchListErr := txQueries.GetWatchListEntry(ctx, db.GetWatchListEntryParams{
UserID: userID,
AnimeID: animeID,
})
if watchListErr != nil && !errors.Is(watchListErr, sql.ErrNoRows) {
return fmt.Errorf("failed to load watchlist entry: %w", watchListErr)
}
isCompleted := watchListErr == nil && watchListEntry.Status == "completed"
if !isCompleted {
if err := txQueries.SaveWatchProgress(ctx, db.SaveWatchProgressParams{
CurrentEpisode: sql.NullInt64{Int64: int64(episode), Valid: true},
CurrentTimeSeconds: timeSeconds,
UserID: userID,
AnimeID: animeID,
}); err != nil && !errors.Is(err, sql.ErrNoRows) {
return fmt.Errorf("failed to save watchlist progress: %w", err)
}
}
if isCompleted {
return tx.Commit()
}
var durationSeconds sql.NullFloat64
if animeSeed != nil {
durationSeconds = animeSeed.DurationSeconds
}
if _, err := txQueries.UpsertContinueWatchingEntry(ctx, db.UpsertContinueWatchingEntryParams{
ID: uuid.New().String(),
UserID: userID,
AnimeID: animeID,
CurrentEpisode: sql.NullInt64{Int64: int64(episode), Valid: true},
CurrentTimeSeconds: timeSeconds,
DurationSeconds: durationSeconds,
}); err != nil {
return fmt.Errorf("failed to upsert continue entry: %w", err)
}
if err := tx.Commit(); err != nil {
return fmt.Errorf("failed to commit save progress transaction: %w", err)
}
return nil
}
// CompleteAnime marks an anime as completed in the watchlist and clears continue-watching.
func (s *Service) CompleteAnime(ctx context.Context, userID string, animeID int64, episode int, animeSeed *db.UpsertAnimeParams) error {
if strings.TrimSpace(userID) == "" || animeID <= 0 || episode <= 0 {
return errors.New("invalid complete anime input")
}
txQueries, tx, err := db.BeginTx(ctx, s.sqlDB)
if err != nil {
return err
}
defer func() { _ = tx.Rollback() }()
watchListEntry, watchListErr := txQueries.GetWatchListEntry(ctx, db.GetWatchListEntryParams{
UserID: userID,
AnimeID: animeID,
})
if watchListErr != nil && !errors.Is(watchListErr, sql.ErrNoRows) {
return fmt.Errorf("failed to load watchlist entry: %w", watchListErr)
}
alreadyCompleted := watchListErr == nil && watchListEntry.Status == "completed"
if !alreadyCompleted {
if animeSeed != nil {
if _, err := txQueries.UpsertAnime(ctx, *animeSeed); err != nil {
return fmt.Errorf("failed to save anime reference: %w", err)
}
}
if _, err := txQueries.UpsertWatchListEntry(ctx, db.UpsertWatchListEntryParams{
ID: uuid.New().String(),
UserID: userID,
AnimeID: animeID,
Status: "completed",
CurrentEpisode: sql.NullInt64{Int64: 0, Valid: false},
CurrentTimeSeconds: 0,
}); err != nil {
return fmt.Errorf("failed to mark watchlist as completed: %w", err)
}
if err := txQueries.SaveWatchProgress(ctx, db.SaveWatchProgressParams{
CurrentEpisode: sql.NullInt64{Int64: 0, Valid: false},
CurrentTimeSeconds: 0,
UserID: userID,
AnimeID: animeID,
}); err != nil && !errors.Is(err, sql.ErrNoRows) {
return fmt.Errorf("failed to reset watch progress: %w", err)
}
}
if err := txQueries.DeleteContinueWatchingEntry(ctx, db.DeleteContinueWatchingEntryParams{
UserID: userID,
AnimeID: animeID,
}); err != nil {
return fmt.Errorf("failed to clear continue entry: %w", err)
}
if err := tx.Commit(); err != nil {
return fmt.Errorf("failed to commit complete anime transaction: %w", err)
}
return nil
}

View File

@@ -1,221 +0,0 @@
package playback
import (
"context"
"fmt"
"io"
"net/http"
"regexp"
"strconv"
"strings"
"time"
)
type providerExtractor struct {
httpClient *http.Client
baseURL string
referer string
}
func newProviderExtractor() *providerExtractor {
return &providerExtractor{
httpClient: &http.Client{Timeout: 30 * time.Second},
baseURL: allAnimeBaseURL,
referer: allAnimeReferer,
}
}
// ExtractVideoLinks fetches provider page and returns stream sources.
func (e *providerExtractor) ExtractVideoLinks(ctx context.Context, providerPath string) ([]StreamSource, error) {
endpoint := e.baseURL + providerPath
var resp *http.Response
var err error
for attempt := range 3 {
if attempt > 0 {
select {
case <-ctx.Done():
return nil, ctx.Err()
case <-time.After(time.Duration(attempt) * 2 * time.Second):
}
}
resp, err = doProxiedRequest(ctx, e.httpClient, endpoint, e.referer)
if err == nil {
break
}
if attempt == 2 {
return nil, fmt.Errorf("fetch provider response: %w", err)
}
}
defer func() { _ = resp.Body.Close() }()
body, err := io.ReadAll(io.LimitReader(resp.Body, 2*1024*1024)) // 2MB limit
if err != nil {
return nil, fmt.Errorf("read provider response: %w", err)
}
return e.parseProviderResponse(ctx, string(body)), nil
}
// parseProviderResponse extracts stream sources from provider JSON response.
func (e *providerExtractor) parseProviderResponse(ctx context.Context, response string) []StreamSource {
sources := make([]StreamSource, 0)
providerReferer := e.referer
// extract per-source referer if present
refererPattern := regexp.MustCompile(`"Referer":"([^"]+)"`)
if match := refererPattern.FindStringSubmatch(response); len(match) >= 2 {
providerReferer = strings.ReplaceAll(match[1], `\/`, "/")
}
if providerReferer == "" {
providerReferer = e.referer
}
// extract direct link sources (mp4/embed)
linkPattern := regexp.MustCompile(`"link":"([^"]+)","resolutionStr":"([^"]+)"`)
for _, match := range linkPattern.FindAllStringSubmatch(response, -1) {
if len(match) < 3 {
continue
}
link := strings.ReplaceAll(match[1], `\/`, "/")
quality := strings.TrimSpace(match[2])
sourceType := detectStreamType(link)
if sourceType == "unknown" {
sourceType = detectEmbedType(link)
}
sources = append(sources, StreamSource{
URL: link,
Quality: quality,
Provider: "wixmp",
Type: sourceType,
Referer: providerReferer,
})
}
// extract HLS playlist sources
hlsPattern := regexp.MustCompile(`"url":"([^"]+)","hardsub_lang":"en-US"`)
for _, match := range hlsPattern.FindAllStringSubmatch(response, -1) {
if len(match) < 2 {
continue
}
playlistURL := strings.ReplaceAll(match[1], `\/`, "/")
if strings.Contains(playlistURL, "master.m3u8") {
parsed, err := e.parseM3U8(ctx, playlistURL, providerReferer)
if err == nil {
sources = append(sources, parsed...)
}
continue
}
sources = append(sources, StreamSource{
URL: playlistURL,
Quality: "auto",
Provider: "hls",
Type: "m3u8",
Referer: providerReferer,
})
}
// extract subtitles and attach to all sources
subtitlePattern := regexp.MustCompile(`"subtitles":\[(.*?)\]`)
if subtitleMatch := subtitlePattern.FindStringSubmatch(response); len(subtitleMatch) >= 2 {
subtitles := make([]Subtitle, 0)
subtitleEntryPattern := regexp.MustCompile(`"lang":"([^"]+)".*?"src":"([^"]+)"`)
for _, entry := range subtitleEntryPattern.FindAllStringSubmatch(subtitleMatch[1], -1) {
if len(entry) < 3 {
continue
}
subtitles = append(subtitles, Subtitle{
Lang: strings.TrimSpace(entry[1]),
URL: strings.ReplaceAll(entry[2], `\/`, "/"),
})
}
if len(subtitles) > 0 {
for idx := range sources {
sources[idx].Subtitles = subtitles
}
}
}
return sources
}
// parseM3U8 fetches a master playlist and extracts individual stream URLs with bandwidth-derived quality.
func (e *providerExtractor) parseM3U8(ctx context.Context, masterURL string, referer string) ([]StreamSource, error) {
resp, err := doProxiedRequest(ctx, e.httpClient, masterURL, referer)
if err != nil {
return nil, err
}
defer func() { _ = resp.Body.Close() }()
body, err := io.ReadAll(io.LimitReader(resp.Body, 512*1024)) // 512KB limit
if err != nil {
return nil, err
}
lines := strings.Split(string(body), "\n")
baseURL := masterURL
if idx := strings.LastIndex(masterURL, "/"); idx >= 0 {
baseURL = masterURL[:idx+1]
}
currentBandwidth := 0
sources := make([]StreamSource, 0)
bwPattern := regexp.MustCompile(`BANDWIDTH=(\d+)`)
for _, line := range lines {
trimmed := strings.TrimSpace(line)
if strings.HasPrefix(trimmed, "#EXT-X-STREAM-INF") {
match := bwPattern.FindStringSubmatch(trimmed)
if len(match) >= 2 {
value, convErr := strconv.Atoi(match[1])
if convErr == nil {
currentBandwidth = value
}
}
continue
}
// skip empty lines and non-stream lines
if trimmed == "" || strings.HasPrefix(trimmed, "#") {
continue
}
streamURL := trimmed
if !strings.HasPrefix(streamURL, "http://") && !strings.HasPrefix(streamURL, "https://") {
streamURL = baseURL + streamURL
}
quality := "auto"
kbps := currentBandwidth / 1000
switch {
case kbps >= 8000:
quality = "1080p"
case kbps >= 5000:
quality = "720p"
case kbps >= 2500:
quality = "480p"
case kbps > 0:
quality = "360p"
}
sources = append(sources, StreamSource{
URL: streamURL,
Quality: quality,
Provider: "hls",
Type: "m3u8",
Referer: referer,
})
}
return sources, nil
}

View File

@@ -1,356 +0,0 @@
package playback
import (
"bufio"
"context"
"crypto/hmac"
"crypto/sha256"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"net"
"net/url"
"strings"
"time"
)
const (
proxyStreamTokenTTL = 2 * time.Hour
proxySegmentTokenTTL = 6 * time.Hour
proxySubtitleTokenTTL = 6 * time.Hour
)
type proxyScope string
const (
proxyScopeStream proxyScope = "stream"
proxyScopeSegment proxyScope = "segment"
proxyScopeSubtitle proxyScope = "subtitle"
)
type proxyTokenPayload struct {
TargetURL string `json:"u"`
Referer string `json:"r,omitempty"`
Scope string `json:"s"`
ExpiresAt int64 `json:"exp"`
}
type proxyTokenSigner struct {
secret []byte
}
func newProxyTokenSigner(secret string) (*proxyTokenSigner, error) {
trimmed := strings.TrimSpace(secret)
if trimmed == "" {
return nil, errors.New("proxy token secret is required")
}
if len(trimmed) < 32 {
return nil, errors.New("proxy token secret must be at least 32 characters")
}
return &proxyTokenSigner{secret: []byte(trimmed)}, nil
}
func (s *proxyTokenSigner) Sign(payload proxyTokenPayload) (string, error) {
body, err := json.Marshal(payload)
if err != nil {
return "", fmt.Errorf("marshal proxy token payload: %w", err)
}
mac := hmac.New(sha256.New, s.secret)
mac.Write(body)
signature := mac.Sum(nil)
// format: payload.signature (both base64url encoded)
encodedBody := base64.RawURLEncoding.EncodeToString(body)
encodedSignature := base64.RawURLEncoding.EncodeToString(signature)
return encodedBody + "." + encodedSignature, nil
}
func (s *proxyTokenSigner) Verify(token string) (proxyTokenPayload, error) {
parts := strings.Split(token, ".")
if len(parts) != 2 {
return proxyTokenPayload{}, errors.New("invalid proxy token format")
}
body, err := base64.RawURLEncoding.DecodeString(parts[0])
if err != nil {
return proxyTokenPayload{}, errors.New("invalid proxy token payload")
}
signature, err := base64.RawURLEncoding.DecodeString(parts[1])
if err != nil {
return proxyTokenPayload{}, errors.New("invalid proxy token signature")
}
mac := hmac.New(sha256.New, s.secret)
mac.Write(body)
expected := mac.Sum(nil)
if !hmac.Equal(signature, expected) { // constant-time comparison
return proxyTokenPayload{}, errors.New("invalid proxy token signature")
}
var payload proxyTokenPayload
if err := json.Unmarshal(body, &payload); err != nil {
return proxyTokenPayload{}, errors.New("invalid proxy token payload")
}
if payload.ExpiresAt <= time.Now().Unix() {
return proxyTokenPayload{}, errors.New("proxy token expired")
}
return payload, nil
}
func (s *Service) buildClientModeSources(modeSources map[string]ModeSource) (map[string]ModeSource, error) {
clientModeSources := make(map[string]ModeSource, len(modeSources))
for mode, source := range modeSources {
// wrap stream url with proxy token
streamToken, err := s.issueProxyToken(source.URL, source.Referer, proxyScopeStream)
if err != nil {
return nil, err
}
subtitles := make([]SubtitleItem, 0, len(source.Subtitles))
for _, subtitle := range source.Subtitles {
targetURL := strings.TrimSpace(subtitle.URL)
if targetURL == "" {
continue
}
token, err := s.issueProxyToken(targetURL, source.Referer, proxyScopeSubtitle)
if err != nil {
return nil, err
}
subtitles = append(subtitles, SubtitleItem{
Lang: subtitle.Lang,
Token: token,
})
}
clientModeSources[mode] = ModeSource{
Token: streamToken,
Subtitles: subtitles,
Qualities: source.Qualities,
}
}
return clientModeSources, nil
}
func (s *Service) issueProxyToken(targetURL string, referer string, scope proxyScope) (string, error) {
normalizedTarget, err := normalizeProxyURL(targetURL)
if err != nil {
return "", err
}
normalizedReferer := ""
if strings.TrimSpace(referer) != "" {
refererURL, refererErr := normalizeProxyURL(referer)
if refererErr == nil {
normalizedReferer = refererURL
}
}
return s.proxyTokens.Sign(proxyTokenPayload{
TargetURL: normalizedTarget,
Referer: normalizedReferer,
Scope: string(scope),
ExpiresAt: time.Now().Add(proxyTokenTTL(scope)).Unix(),
})
}
// proxyTokenTTLs defines ttl per scope type.
var proxyTokenTTLs = map[proxyScope]time.Duration{
proxyScopeStream: proxyStreamTokenTTL,
proxyScopeSegment: proxySegmentTokenTTL,
proxyScopeSubtitle: proxySubtitleTokenTTL,
}
func proxyTokenTTL(scope proxyScope) time.Duration {
if ttl, ok := proxyTokenTTLs[scope]; ok {
return ttl
}
return proxyStreamTokenTTL
}
func (s *Service) resolveProxyToken(ctx context.Context, token string, scope proxyScope) (string, string, error) {
payload, err := s.proxyTokens.Verify(token)
if err != nil {
return "", "", err
}
if payload.Scope != string(scope) {
return "", "", errors.New("proxy token scope mismatch")
}
normalizedTarget, err := normalizeProxyURL(payload.TargetURL)
if err != nil {
return "", "", err
}
if err := s.ensurePublicProxyTarget(ctx, normalizedTarget); err != nil {
return "", "", err
}
// resolve referer only if it passes public target check
normalizedReferer := ""
if strings.TrimSpace(payload.Referer) != "" {
refererURL, refererErr := normalizeProxyURL(payload.Referer)
if refererErr == nil {
if ensureErr := s.ensurePublicProxyTarget(ctx, refererURL); ensureErr == nil {
normalizedReferer = refererURL
}
}
}
return normalizedTarget, normalizedReferer, nil
}
// normalizeProxyURL validates and canonicalizes a proxy target URL.
func normalizeProxyURL(rawURL string) (string, error) {
parsed, err := url.Parse(strings.TrimSpace(rawURL))
if err != nil {
return "", errors.New("invalid proxy target")
}
if parsed.Scheme != "http" && parsed.Scheme != "https" {
return "", errors.New("invalid proxy target scheme")
}
host := strings.ToLower(strings.TrimSpace(parsed.Hostname()))
if host == "" {
return "", errors.New("invalid proxy target host")
}
// block localhost and .local TLD
if host == "localhost" || strings.HasSuffix(host, ".localhost") || strings.HasSuffix(host, ".local") {
return "", errors.New("localhost targets are not allowed")
}
ip := net.ParseIP(host)
if ip != nil && isBlockedProxyIP(ip) {
return "", errors.New("private proxy targets are not allowed")
}
return parsed.String(), nil
}
// isBlockedProxyIP checks for loopback, private, multicast, and unspecified addresses.
func isBlockedProxyIP(ip net.IP) bool {
return ip.IsLoopback() ||
ip.IsPrivate() ||
ip.IsMulticast() ||
ip.IsLinkLocalMulticast() ||
ip.IsLinkLocalUnicast() ||
ip.IsUnspecified()
}
// ensurePublicProxyTarget validates that the target host resolves to a public IP.
// results are cached to avoid repeated DNS lookups.
func (s *Service) ensurePublicProxyTarget(ctx context.Context, rawURL string) error {
parsed, err := url.Parse(rawURL)
if err != nil {
return errors.New("invalid proxy target")
}
host := strings.TrimSpace(parsed.Hostname())
if host == "" {
return errors.New("invalid proxy target host")
}
// direct IP already checked by normalizeProxyURL
if ip := net.ParseIP(host); ip != nil {
if isBlockedProxyIP(ip) {
return errors.New("private proxy targets are not allowed")
}
return nil
}
// check cache first
cached, ok := s.proxyHostCache.Get(host)
if ok {
if cached.Allowed {
return nil
}
return errors.New("private proxy targets are not allowed")
}
// DNS resolution for hostname
resolvedIPs, err := net.DefaultResolver.LookupIPAddr(ctx, host)
if err != nil || len(resolvedIPs) == 0 {
return errors.New("proxy target lookup failed")
}
allowed := true
for _, resolved := range resolvedIPs {
if isBlockedProxyIP(resolved.IP) {
allowed = false
break
}
}
s.proxyHostCache.Add(host, proxyHostCacheItem{
Allowed: allowed,
})
if !allowed {
return errors.New("private proxy targets are not allowed")
}
return nil
}
// rewritePlaylistWithTokens replaces segment URLs with proxy tokens for HLS playlists.
func (s *Service) rewritePlaylistWithTokens(ctx context.Context, content string, baseURL string, referer string) (string, error) {
base, err := url.Parse(baseURL)
if err != nil {
return "", err
}
var out strings.Builder
scanner := bufio.NewScanner(strings.NewReader(content))
for scanner.Scan() {
select {
case <-ctx.Done():
return "", ctx.Err()
default:
}
line := scanner.Text()
trimmed := strings.TrimSpace(line)
// preserve comments and empty lines
if trimmed == "" || strings.HasPrefix(trimmed, "#") {
out.WriteString(line)
out.WriteString("\n")
continue
}
relativeURL, parseErr := url.Parse(trimmed)
if parseErr != nil {
out.WriteString(line)
out.WriteString("\n")
continue
}
absoluteURL := base.ResolveReference(relativeURL).String()
token, tokenErr := s.issueProxyToken(absoluteURL, referer, proxyScopeSegment)
if tokenErr != nil {
return "", tokenErr
}
proxied := "/watch/proxy/segment?token=" + url.QueryEscape(token)
out.WriteString(proxied)
out.WriteString("\n")
}
if err := scanner.Err(); err != nil {
return "", err
}
return out.String(), nil
}

View File

@@ -1,48 +0,0 @@
package playback
import (
"context"
"testing"
"mal/internal/db"
)
func TestNormalizeProxyURLRejectsLocalhost(t *testing.T) {
t.Parallel()
_, err := normalizeProxyURL("http://localhost:8080/private")
if err == nil {
t.Fatal("expected localhost URL to be rejected")
}
}
func TestNormalizeProxyURLRejectsPrivateIP(t *testing.T) {
t.Parallel()
_, err := normalizeProxyURL("http://192.168.1.10/stream")
if err == nil {
t.Fatal("expected private IP URL to be rejected")
}
}
func TestProxyTokenScopeValidation(t *testing.T) {
t.Parallel()
service, err := NewService(&fakeProxyQuerier{}, nil, Config{ProxyTokenSecret: "0123456789abcdef0123456789abcdef"})
if err != nil {
t.Fatalf("failed to create service: %v", err)
}
token, err := service.issueProxyToken("https://example.com/playlist.m3u8", "", proxyScopeStream)
if err != nil {
t.Fatalf("failed to issue token: %v", err)
}
_, _, err = service.resolveProxyToken(context.Background(), token, proxyScopeSegment)
if err == nil {
t.Fatal("expected scope mismatch error")
}
}
type fakeProxyQuerier struct {
db.Querier
}

View File

@@ -1,392 +0,0 @@
package playback
import (
"context"
"database/sql"
"errors"
"fmt"
"log"
"mal/internal/db"
"net/http"
"strconv"
"strings"
"sync"
"time"
"github.com/hashicorp/golang-lru/v2"
)
const (
providerProbeTimeout = 3 * time.Second
)
type Service struct {
allAnimeClient *allAnimeClient
httpClient *http.Client
sqlDB *sql.DB
db db.Querier
proxyTokens *proxyTokenSigner
proxyHostCache *lru.Cache[string, proxyHostCacheItem]
showResolution *lru.Cache[int, showResolutionCacheItem]
playbackDataCache *lru.Cache[string, playbackDataCacheItem]
}
type Config struct {
ProxyTokenSecret string
}
type sourceScore struct {
source StreamSource
total int
typeScore int
providerScore int
qualityScore int
refererScore int
}
type showResolutionCacheItem struct {
ShowID string
Title string
}
type playbackDataCacheItem struct {
Data playbackBaseData
}
type playbackBaseData struct {
Title string
AvailableModes []string
ModeSources map[string]ModeSource
Segments []SkipSegment
FallbackEpisodes map[string]int
}
type modeSourceResult struct {
Mode string
Source ModeSource
OK bool
}
type searchModeResult struct {
Mode string
Results []searchResult
Err error
}
type directProbeResult struct {
Playable bool
ContentType string
}
type proxyHostCacheItem struct {
Allowed bool
}
type userPlaybackState struct {
CurrentStatus string
StartTimeSeconds float64
}
// NewService initializes the playback service with db and sql connections.
func NewService(db db.Querier, sqlDB *sql.DB, cfg Config) (*Service, error) {
proxyTokens, err := newProxyTokenSigner(cfg.ProxyTokenSecret)
if err != nil {
return nil, fmt.Errorf("failed to initialize proxy token signer: %w", err)
}
showResolution, err := lru.New[int, showResolutionCacheItem](5000)
if err != nil {
return nil, err
}
playbackDataCache, err := lru.New[string, playbackDataCacheItem](500)
if err != nil {
return nil, err
}
proxyHostCache, err := lru.New[string, proxyHostCacheItem](1000)
if err != nil {
return nil, err
}
return &Service{
allAnimeClient: newAllAnimeClient(),
httpClient: &http.Client{Timeout: 12 * time.Second},
sqlDB: sqlDB,
db: db,
proxyTokens: proxyTokens,
proxyHostCache: proxyHostCache,
showResolution: showResolution,
playbackDataCache: playbackDataCache,
}, nil
}
// BuildWatchPageData resolves show metadata and sources for a given MAL ID and episode.
func (s *Service) BuildWatchPageData(ctx context.Context, malID int, titleCandidates []string, episode string, mode string, userID string) (WatchPageData, error) {
if malID <= 0 {
return WatchPageData{}, errors.New("invalid mal id")
}
normalizedMode := normalizeMode(mode)
if normalizedMode == "" {
normalizedMode = "dub"
}
normalizedEpisode := strings.TrimSpace(episode)
if normalizedEpisode == "" {
normalizedEpisode = "1"
}
userStateCh := s.fetchUserPlaybackStateAsync(ctx, userID, malID, normalizedEpisode)
cacheKey := playbackDataCacheKey(malID, normalizedEpisode)
baseData, cacheHit := s.getPlaybackBaseDataCache(cacheKey)
if !cacheHit {
showID, resolvedTitle, err := s.resolveShowCached(ctx, malID, titleCandidates)
if err != nil {
return WatchPageData{}, err
}
modeSources, segments := s.fetchPlaybackSourcesAndSegments(ctx, showID, malID, normalizedEpisode)
if len(modeSources) == 0 {
return WatchPageData{}, errors.New("no direct playable sources available")
}
fallbackEpisodes := make(map[string]int)
if counts, err := s.allAnimeClient.GetAvailableEpisodes(ctx, showID); err == nil {
fallbackEpisodes["sub"] = len(counts.Sub)
fallbackEpisodes["dub"] = len(counts.Dub)
fallbackEpisodes["raw"] = len(counts.Raw)
}
watchTitle := strings.TrimSpace(resolvedTitle)
if watchTitle == "" {
watchTitle = firstNonEmptyTitle(titleCandidates)
}
if watchTitle == "" {
watchTitle = fmt.Sprintf("MAL #%d", malID)
}
baseData = playbackBaseData{
Title: watchTitle,
AvailableModes: availableModes(modeSources),
ModeSources: modeSources,
Segments: segments,
FallbackEpisodes: fallbackEpisodes,
}
s.setPlaybackBaseDataCache(cacheKey, baseData)
}
initialMode := selectInitialMode(normalizedMode, baseData.ModeSources)
clientModeSources, err := s.buildClientModeSources(baseData.ModeSources)
if err != nil {
return WatchPageData{}, err
}
if _, ok := clientModeSources[initialMode]; !ok {
return WatchPageData{}, errors.New("stream mode unavailable")
}
segments := baseData.Segments
if segments == nil {
segments = []SkipSegment{}
}
userState := userPlaybackState{}
if userStateCh != nil {
userState = <-userStateCh
}
return WatchPageData{
MalID: malID,
Title: baseData.Title,
CurrentEpisode: normalizedEpisode,
StartTimeSeconds: userState.StartTimeSeconds,
CurrentStatus: userState.CurrentStatus,
InitialMode: initialMode,
AvailableModes: cloneSlice(baseData.AvailableModes),
ModeSources: clientModeSources,
Segments: cloneSlice(segments),
FallbackEpisodes: baseData.FallbackEpisodes,
}, nil
}
func playbackDataCacheKey(malID int, episode string) string {
return fmt.Sprintf("%d:%s", malID, episode)
}
func (s *Service) fetchUserPlaybackStateAsync(ctx context.Context, userID string, malID int, episode string) <-chan userPlaybackState {
if userID == "" || s.db == nil {
return nil
}
resultCh := make(chan userPlaybackState, 1)
go func() {
state := userPlaybackState{}
entry, err := s.db.GetWatchListEntry(ctx, db.GetWatchListEntryParams{
UserID: userID,
AnimeID: int64(malID),
})
if err == nil {
state.CurrentStatus = entry.Status
if entry.CurrentEpisode.Valid && strconv.FormatInt(entry.CurrentEpisode.Int64, 10) == episode && entry.CurrentTimeSeconds > 0 {
state.StartTimeSeconds = entry.CurrentTimeSeconds
}
}
if state.StartTimeSeconds <= 0 {
continueEntry, continueErr := s.db.GetContinueWatchingEntry(ctx, db.GetContinueWatchingEntryParams{
UserID: userID,
AnimeID: int64(malID),
})
if continueErr == nil && continueEntry.CurrentEpisode.Valid && strconv.FormatInt(continueEntry.CurrentEpisode.Int64, 10) == episode && continueEntry.CurrentTimeSeconds > 0 {
state.StartTimeSeconds = continueEntry.CurrentTimeSeconds
}
}
resultCh <- state
}()
return resultCh
}
func (s *Service) getPlaybackBaseDataCache(key string) (playbackBaseData, bool) {
item, ok := s.playbackDataCache.Get(key)
if !ok {
return playbackBaseData{}, false
}
return clonePlaybackBaseData(item.Data), true
}
func (s *Service) setPlaybackBaseDataCache(key string, data playbackBaseData) {
s.playbackDataCache.Add(key, playbackDataCacheItem{
Data: clonePlaybackBaseData(data),
})
}
func (s *Service) resolveShowCached(ctx context.Context, malID int, titleCandidates []string) (string, string, error) {
if item, ok := s.showResolution.Get(malID); ok && strings.TrimSpace(item.ShowID) != "" {
return item.ShowID, item.Title, nil
}
showID, resolvedTitle, err := s.resolveShow(ctx, malID, titleCandidates)
if err != nil {
return "", "", err
}
s.showResolution.Add(malID, showResolutionCacheItem{
ShowID: showID,
Title: resolvedTitle,
})
return showID, resolvedTitle, nil
}
// fetchPlaybackSourcesAndSegments resolves sources for both dub and sub modes concurrently.
func (s *Service) fetchPlaybackSourcesAndSegments(ctx context.Context, showID string, malID int, episode string) (map[string]ModeSource, []SkipSegment) {
modeCh := make(chan modeSourceResult, 2)
probeCache := make(map[string]directProbeResult)
probeCacheMu := sync.Mutex{}
// parallel fetch for both modes
for _, mode := range []string{"dub", "sub"} {
modeValue := mode
go func() {
resolved, err := s.resolveModeSourceWithCache(ctx, showID, episode, modeValue, "best", probeCache, &probeCacheMu)
if err != nil {
log.Printf("playback source resolution failed for mode=%s showID=%s episode=%s: %v", modeValue, showID, episode, err)
modeCh <- modeSourceResult{Mode: modeValue, OK: false}
return
}
if strings.ToLower(resolved.Type) == "embed" {
modeCh <- modeSourceResult{Mode: modeValue, OK: false}
return
}
modeCh <- modeSourceResult{
Mode: modeValue,
Source: ModeSource{
URL: resolved.URL,
Referer: resolved.Referer,
Subtitles: toSubtitleItems(resolved),
Qualities: toQualities(resolved.AvailableQualities),
},
OK: true,
}
}()
}
segmentsCh := make(chan []SkipSegment, 1)
go func() {
segmentsCh <- s.fetchSkipSegments(ctx, malID, episode)
}()
modeSources := make(map[string]ModeSource)
// collect results from both mode goroutines
for range 2 {
result := <-modeCh
if !result.OK {
continue
}
modeSources[result.Mode] = result.Source
}
segments := <-segmentsCh
return modeSources, segments
}
func clonePlaybackBaseData(data playbackBaseData) playbackBaseData {
return playbackBaseData{
Title: data.Title,
AvailableModes: cloneSlice(data.AvailableModes),
ModeSources: cloneModeSources(data.ModeSources),
Segments: cloneSlice(data.Segments),
FallbackEpisodes: data.FallbackEpisodes,
}
}
func toQualities(sources []StreamSource) []string {
seen := make(map[string]struct{})
var qualities []string
for _, s := range sources {
q := strings.TrimSpace(s.Quality)
if q == "" || q == "auto" {
continue
}
if _, ok := seen[q]; !ok {
seen[q] = struct{}{}
qualities = append(qualities, q)
}
}
return qualities
}
func cloneSlice[T any](items []T) []T {
if items == nil {
return []T{}
}
if len(items) == 0 {
return []T{}
}
cloned := make([]T, len(items))
copy(cloned, items)
return cloned
}
func cloneModeSources(modeSources map[string]ModeSource) map[string]ModeSource {
if len(modeSources) == 0 {
return nil
}
cloned := make(map[string]ModeSource, len(modeSources))
for mode, source := range modeSources {
cloned[mode] = ModeSource{
URL: source.URL,
Referer: source.Referer,
Subtitles: cloneSlice(source.Subtitles),
Qualities: cloneSlice(source.Qualities),
}
}
return cloned
}

View File

@@ -1,74 +0,0 @@
package playback
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"strconv"
"strings"
)
// fetchSkipSegments queries aniskip API for OP/ED skip times.
// returns nil if the API is unavailable or has no data.
func (s *Service) fetchSkipSegments(ctx context.Context, malID int, episode string) []SkipSegment {
if malID <= 0 || strings.TrimSpace(episode) == "" {
return nil
}
endpoint := fmt.Sprintf("https://api.aniskip.com/v1/skip-times/%s/%s?types=op&types=ed", url.PathEscape(strconv.Itoa(malID)), url.PathEscape(episode))
resp, err := doProxiedRequest(ctx, s.httpClient, endpoint, "")
if err != nil {
return nil
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusOK {
return nil
}
body, err := io.ReadAll(io.LimitReader(resp.Body, 512*1024))
if err != nil {
return nil
}
type resultItem struct {
SkipType string `json:"skip_type"`
Interval struct {
StartTime float64 `json:"start_time"`
EndTime float64 `json:"end_time"`
} `json:"interval"`
}
type apiResponse struct {
Found bool `json:"found"`
Result []resultItem `json:"results"`
}
var parsed apiResponse
if err := json.Unmarshal(body, &parsed); err != nil {
return nil
}
// filter to valid OP/ED segments
segments := make([]SkipSegment, 0, len(parsed.Result))
for _, item := range parsed.Result {
if item.Interval.EndTime <= item.Interval.StartTime {
continue
}
t := strings.ToLower(item.SkipType)
if t != "op" && t != "ed" {
continue
}
segments = append(segments, SkipSegment{
Type: t,
Start: item.Interval.StartTime,
End: item.Interval.EndTime,
})
}
return segments
}

View File

@@ -1,119 +0,0 @@
package playback
import (
"context"
"fmt"
"io"
"log"
"net/http"
"strconv"
"strings"
"time"
)
// ProxyStream fetches a stream URL and returns the response.
// retries on failure, rewrites m3u8 playlists to include auth tokens.
func (s *Service) ProxyStream(ctx context.Context, targetURL string, referer string, rangeHeader string) (int, http.Header, []byte, io.ReadCloser, error) {
const maxRetries = 2
const retryDelay = 500 * time.Millisecond
var lastErr error
for attempt := 0; attempt <= maxRetries; attempt++ {
if attempt > 0 {
select {
case <-ctx.Done():
return 0, nil, nil, nil, ctx.Err()
case <-time.After(retryDelay):
}
log.Printf("retrying proxy request for %s (attempt %d/%d)", targetURL, attempt, maxRetries)
}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, targetURL, nil)
if err != nil {
return 0, nil, nil, nil, fmt.Errorf("invalid upstream url: %w", err)
}
if referer != "" {
req.Header.Set("Referer", referer)
}
req.Header.Set("User-Agent", defaultUserAgent)
if rangeHeader != "" {
req.Header.Set("Range", rangeHeader)
}
resp, err := s.httpClient.Do(req)
if err != nil {
lastErr = err
continue
}
return s.handleProxyResponse(ctx, resp, targetURL, referer)
}
return 0, nil, nil, nil, fmt.Errorf("upstream request failed after %d retries: %w", maxRetries+1, lastErr)
}
// handleProxyResponse processes the upstream response.
// rewrites m3u8 playlists to proxy through our backend.
func (s *Service) handleProxyResponse(ctx context.Context, resp *http.Response, targetURL string, referer string) (int, http.Header, []byte, io.ReadCloser, error) {
// check if response is an m3u8 playlist that needs rewriting
if isM3U8(targetURL, resp.Header.Get("Content-Type")) {
defer func() { _ = resp.Body.Close() }()
body, readErr := io.ReadAll(io.LimitReader(resp.Body, 2*1024*1024))
if readErr != nil {
return 0, nil, nil, nil, fmt.Errorf("read playlist failed: %w", readErr)
}
rewritten, rewriteErr := s.rewritePlaylistWithTokens(ctx, string(body), targetURL, referer)
if rewriteErr != nil {
return 0, nil, nil, nil, fmt.Errorf("rewrite playlist failed: %w", rewriteErr)
}
headers := cloneHeaders(resp.Header)
headers.Del("Content-Length")
headers.Del("Transfer-Encoding")
headers.Set("Content-Type", "application/vnd.apple.mpegurl")
headers.Set("Content-Length", strconv.Itoa(len(rewritten)))
return resp.StatusCode, headers, []byte(rewritten), nil, nil
}
// for binary streams, remove chunked encoding and return body reader
headers := cloneHeaders(resp.Header)
headers.Del("Transfer-Encoding")
return resp.StatusCode, headers, nil, resp.Body, nil
}
// isM3U8 checks if the response is an m3u8 playlist by URL or content-type.
func isM3U8(targetURL string, contentType string) bool {
if strings.Contains(strings.ToLower(targetURL), ".m3u8") {
return true
}
lowerType := strings.ToLower(contentType)
return strings.Contains(lowerType, "application/vnd.apple.mpegurl") || strings.Contains(lowerType, "application/x-mpegurl")
}
var hopHeaders = map[string]struct{}{
"connection": {},
"keep-alive": {},
"proxy-authenticate": {},
"proxy-authorization": {},
"te": {},
"trailers": {},
"upgrade": {},
}
// cloneHeaders copies headers, filtering out hop-by-hop headers.
// hop-by-hop headers are specific to a single transport connection.
func cloneHeaders(src http.Header) http.Header {
dst := make(http.Header)
for key, values := range src {
if _, ok := hopHeaders[strings.ToLower(key)]; ok {
continue
}
for _, value := range values {
dst.Add(key, value)
}
}
return dst
}

View File

@@ -1,188 +0,0 @@
package playback
import (
"bytes"
"errors"
"sort"
"strconv"
"strings"
)
func rankSources(sources []StreamSource, quality string) ([]sourceScore, error) {
filtered := make([]StreamSource, 0, len(sources))
seen := make(map[string]struct{})
for _, source := range sources {
if source.URL == "" {
continue
}
if _, exists := seen[source.URL]; exists {
continue
}
seen[source.URL] = struct{}{}
filtered = append(filtered, source)
}
if len(filtered) == 0 {
return nil, errors.New("no playable sources available")
}
targetQuality := normalizeQuality(quality)
scored := make([]sourceScore, 0, len(filtered))
for _, source := range filtered {
typeScore := lookupPriority(sourceTypePriority, source.Type, 200)
providerScore := lookupPriority(providerPriority, source.Provider, 60)
qualityScore := sourceQualityPriority(source.Quality, targetQuality)
refererScore := 0
if source.Referer != "" {
refererScore = 20
}
total := typeScore + providerScore + qualityScore + refererScore
scored = append(scored, sourceScore{
source: source,
total: total,
typeScore: typeScore,
providerScore: providerScore,
qualityScore: qualityScore,
refererScore: refererScore,
})
}
// stable sort to preserve insertion order for equal scores
sort.SliceStable(scored, func(i int, j int) bool {
return scored[i].total > scored[j].total
})
return scored, nil
}
func normalizeQuality(quality string) string {
lower := strings.ToLower(strings.TrimSpace(quality))
if lower == "" {
return "best"
}
return lower
}
var sourceTypePriority = map[string]int{
"mp4": 500,
"m3u8": 450,
"unknown": 300,
"embed": 100,
}
var providerPriority = map[string]int{
"s-mp4": 120,
"default": 115,
"luf-mp4": 110,
"vid-mp4": 105,
"yt-mp4": 100,
"mp4": 95,
"uv-mp4": 90,
"hls": 80,
"sw": 40,
"ok": 35,
"ss-hls": 30,
}
func lookupPriority(m map[string]int, key string, fallback int) int {
if p, ok := m[strings.ToLower(key)]; ok {
return p
}
return fallback
}
// sourceQualityPriority scores quality match: exact match gets boost, mismatch gets penalty.
func sourceQualityPriority(sourceQuality string, targetQuality string) int {
qualityValue := parseQualityValue(sourceQuality)
switch targetQuality {
case "best":
return qualityValue
case "worst":
return -qualityValue
default:
if qualityMatches(sourceQuality, targetQuality) {
return 2000 + qualityValue
}
return -300 + qualityValue
}
}
// qualityMatches checks if source matches target by substring or extracted digits.
func qualityMatches(sourceQuality string, targetQuality string) bool {
sourceLower := strings.ToLower(sourceQuality)
targetLower := strings.ToLower(targetQuality)
if sourceLower == "" {
return false
}
if strings.Contains(sourceLower, targetLower) {
return true
}
return extractDigits(sourceLower) == extractDigits(targetLower)
}
// parseQualityValue extracts numeric value from quality string.
func parseQualityValue(rawQuality string) int {
lower := strings.ToLower(rawQuality)
if lower == "auto" {
return 240
}
digits := extractDigits(lower)
if digits == "" {
return 0
}
value, err := strconv.Atoi(digits)
if err != nil {
return 0
}
return value
}
// extractDigits reads leading digits until a non-digit or break condition.
func extractDigits(value string) string {
var digits []byte
for _, char := range value {
if char >= '0' && char <= '9' {
digits = append(digits, byte(char))
} else if len(digits) > 0 {
break
}
}
return string(digits)
}
// normalizeSourceTypeFromProbe overrides source type based on Content-Type header.
func normalizeSourceTypeFromProbe(source StreamSource, contentType string) StreamSource {
lower := strings.ToLower(contentType)
switch {
case strings.Contains(lower, "video/mp4"):
source.Type = "mp4"
case strings.Contains(lower, "mpegurl"):
source.Type = "m3u8"
}
return source
}
// isLikelyMP4 checks ftyp box header (bytes 4-8 of mp4 files).
func isLikelyMP4(payload []byte) bool {
if len(payload) < 12 {
return false
}
return bytes.Equal(payload[4:8], []byte("ftyp"))
}
// isLikelyM3U8 checks for m3u8 file header.
func isLikelyM3U8(payload []byte) bool {
trimmed := strings.TrimSpace(string(payload))
return strings.HasPrefix(trimmed, "#EXTM3U")
}

View File

@@ -1,491 +0,0 @@
package playback
import (
"testing"
)
func TestRankSources(t *testing.T) {
t.Parallel()
tests := []struct {
name string
sources []StreamSource
quality string
wantErr bool
}{
{
name: "empty sources returns error",
sources: nil,
quality: "best",
wantErr: true,
},
{
name: "filters empty URLs",
sources: []StreamSource{
{URL: "", Type: "mp4"},
{URL: "https://example.com/v.mp4", Type: "mp4"},
},
quality: "best",
wantErr: false,
},
{
name: "deduplicates URLs",
sources: []StreamSource{
{URL: "https://a.com/v.mp4", Type: "mp4"},
{URL: "https://b.com/v.mp4", Type: "m3u8"},
{URL: "https://a.com/v.mp4", Type: "mp4"},
},
quality: "best",
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
_, err := rankSources(tt.sources, tt.quality)
if (err != nil) != tt.wantErr {
t.Errorf("rankSources() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
func TestRankSourcesOrdering(t *testing.T) {
t.Parallel()
sources := []StreamSource{
{URL: "https://embed.com/v.mp4", Type: "embed", Provider: "streamwish"},
{URL: "https://mp4.com/v.mp4", Type: "mp4", Provider: "s-mp4"},
{URL: "https://m3u8.com/v.m3u8", Type: "m3u8", Provider: "default"},
{URL: "https://unknown.com/v.mp4", Type: "unknown", Provider: "other"},
}
ranked, err := rankSources(sources, "best")
if err != nil {
t.Fatalf("rankSources() error = %v", err)
}
if len(ranked) != 4 {
t.Fatalf("got %d sources, want 4", len(ranked))
}
if ranked[0].source.Type != "mp4" {
t.Errorf("ranked[0] = %q, want mp4 (type priority: mp4 > m3u8 > unknown > embed)", ranked[0].source.Type)
}
if ranked[1].source.Type != "m3u8" {
t.Errorf("ranked[1] = %q, want m3u8", ranked[1].source.Type)
}
}
func TestRankSourcesWithQuality(t *testing.T) {
t.Parallel()
sources := []StreamSource{
{URL: "https://a.com/v.mp4", Quality: "1080p", Type: "mp4"},
{URL: "https://b.com/v.mp4", Quality: "720p", Type: "mp4"},
{URL: "https://c.com/v.mp4", Quality: "480p", Type: "mp4"},
}
ranked, err := rankSources(sources, "1080p")
if err != nil {
t.Fatalf("rankSources() error = %v", err)
}
if ranked[0].source.Quality != "1080p" {
t.Errorf("ranked[0].Quality = %q, want 1080p", ranked[0].source.Quality)
}
}
func TestNormalizeQuality(t *testing.T) {
t.Parallel()
tests := []struct {
name string
quality string
wantNorm string
}{
{
name: "empty returns best",
quality: "",
wantNorm: "best",
},
{
name: "lowercase best",
quality: "BEST",
wantNorm: "best",
},
{
name: "with spaces",
quality: " 720p ",
wantNorm: "720p",
},
{
name: "worst",
quality: "worst",
wantNorm: "worst",
},
{
name: "specific quality",
quality: "1080p",
wantNorm: "1080p",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
got := normalizeQuality(tt.quality)
if got != tt.wantNorm {
t.Errorf("normalizeQuality(%q) = %q, want %q", tt.quality, got, tt.wantNorm)
}
})
}
}
func TestParseQualityValue(t *testing.T) {
t.Parallel()
tests := []struct {
name string
quality string
want int
}{
{
name: "auto returns 240",
quality: "auto",
want: 240,
},
{
name: "1080p extracts 1080",
quality: "1080p",
want: 1080,
},
{
name: "720 extracts 720",
quality: "720",
want: 720,
},
{
name: "fhd is treated as fhd",
quality: "fhd",
want: 0,
},
{
name: "empty returns 0",
quality: "",
want: 0,
},
{
name: "multiple digits stops at first non-digit",
quality: "1080p60fps",
want: 1080,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
got := parseQualityValue(tt.quality)
if got != tt.want {
t.Errorf("parseQualityValue(%q) = %d, want %d", tt.quality, got, tt.want)
}
})
}
}
func TestQualityMatches(t *testing.T) {
t.Parallel()
tests := []struct {
name string
source string
target string
want bool
}{
{
name: "exact match",
source: "1080p",
target: "1080p",
want: true,
},
{
name: "target in source",
source: "1920x1080",
target: "1080",
want: true,
},
{
name: "digit match",
source: "1080p",
target: "1080",
want: true,
},
{
name: "no match",
source: "720p",
target: "1080",
want: false,
},
{
name: "empty source returns false",
source: "",
target: "1080",
want: false,
},
{
name: "empty target returns true (empty always contained)",
source: "1080p",
target: "",
want: true,
},
{
name: "auto doesn't match specific",
source: "auto",
target: "1080",
want: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
got := qualityMatches(tt.source, tt.target)
if got != tt.want {
t.Errorf("qualityMatches(%q, %q) = %v, want %v", tt.source, tt.target, got, tt.want)
}
})
}
}
func TestSourceQualityPriority(t *testing.T) {
t.Parallel()
tests := []struct {
name string
source string
target string
wantMin int
}{
{
name: "best mode favors higher quality",
source: "1080p",
target: "best",
wantMin: 1080,
},
{
name: "worst mode penalizes higher quality",
source: "1080p",
target: "worst",
wantMin: -2000,
},
{
name: "exact match gets bonus",
source: "1080p",
target: "1080p",
wantMin: 2000,
},
{
name: "close match gets penalty but positive score",
source: "1080p",
target: "720p",
wantMin: 500,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
got := sourceQualityPriority(tt.source, tt.target)
if tt.wantMin != 0 && got < tt.wantMin {
t.Errorf("sourceQualityPriority(%q, %q) = %d, want >= %d", tt.source, tt.target, got, tt.wantMin)
}
})
}
}
func TestSourceTypePriorityLookup(t *testing.T) {
t.Parallel()
tests := []struct {
name string
sourceType string
want int
}{
{
name: "mp4 priority",
sourceType: "mp4",
want: 500,
},
{
name: "m3u8 priority",
sourceType: "m3u8",
want: 450,
},
{
name: "unknown uses fallback",
sourceType: "unknown",
want: 300,
},
{
name: "embed fallback",
sourceType: "embed",
want: 100,
},
{
name: "unrecognized uses fallback",
sourceType: "video",
want: 200,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
got := lookupPriority(sourceTypePriority, tt.sourceType, 200)
if got != tt.want {
t.Errorf("lookupPriority(sourceTypePriority, %q, 200) = %d, want %d", tt.sourceType, got, tt.want)
}
})
}
}
func TestProviderPriorityLookup(t *testing.T) {
t.Parallel()
tests := []struct {
name string
provider string
want int
}{
{
name: "s-mp4",
provider: "s-mp4",
want: 120,
},
{
name: "default",
provider: "default",
want: 115,
},
{
name: "yt-mp4",
provider: "yt-mp4",
want: 100,
},
{
name: "unknown uses fallback",
provider: "unknown",
want: 60,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
got := lookupPriority(providerPriority, tt.provider, 60)
if got != tt.want {
t.Errorf("lookupPriority(providerPriority, %q, 60) = %d, want %d", tt.provider, got, tt.want)
}
})
}
}
func TestNormalizeSourceTypeFromProbe(t *testing.T) {
t.Parallel()
tests := []struct {
name string
source StreamSource
contentType string
wantType string
}{
{
name: "video/mp4 normalizes to mp4",
source: StreamSource{Type: "unknown"},
contentType: "video/mp4",
wantType: "mp4",
},
{
name: "application/octet-stream unchanged",
source: StreamSource{Type: "mp4"},
contentType: "application/octet-stream",
wantType: "mp4",
},
{
name: "mpegurl normalizes to m3u8",
source: StreamSource{Type: "unknown"},
contentType: "application/vnd.apple.mpegurl",
wantType: "m3u8",
},
{
name: "video/mpegurl",
source: StreamSource{Type: "unknown"},
contentType: "video/mpegurl",
wantType: "m3u8",
},
{
name: "case insensitive",
source: StreamSource{Type: "unknown"},
contentType: "VIDEO/MP4",
wantType: "mp4",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
got := normalizeSourceTypeFromProbe(tt.source, tt.contentType)
if got.Type != tt.wantType {
t.Errorf("normalizeSourceTypeFromProbe().Type = %q, want %q", got.Type, tt.wantType)
}
})
}
}
func TestExtractDigits(t *testing.T) {
t.Parallel()
tests := []struct {
name string
value string
want string
}{
{
name: "extracts digits",
value: "1080p",
want: "1080",
},
{
name: "empty if no digits",
value: "p",
want: "",
},
{
name: "stops at non-digit after digits",
value: "720p60",
want: "720",
},
{
name: "multiple non-digit does not break",
value: "abc123def",
want: "123",
},
{
name: "all digits",
value: "1080",
want: "1080",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
got := extractDigits(tt.value)
if got != tt.want {
t.Errorf("extractDigits(%q) = %q, want %q", tt.value, got, tt.want)
}
})
}
}

View File

@@ -1,171 +0,0 @@
package playback
import (
"context"
"errors"
"sort"
"strconv"
"strings"
"sync"
)
func (s *Service) resolveShow(ctx context.Context, malID int, titleCandidates []string) (string, string, error) {
malText := strconv.Itoa(malID)
modeCandidates := []string{"sub", "dub"}
queries := buildTitleSearchQueries(titleCandidates)
for _, query := range queries {
resultsByMode := s.searchShowResultsByMode(ctx, query, modeCandidates)
for _, mode := range modeCandidates {
for _, result := range resultsByMode[mode] {
// exact mal id match
if strings.TrimSpace(result.MalID) == malText && strings.TrimSpace(result.ID) != "" {
return result.ID, result.Name, nil
}
}
}
for _, mode := range modeCandidates {
results := resultsByMode[mode]
if len(results) == 0 {
continue
}
// fallback to first result if no exact match
best := results[0]
if strings.TrimSpace(best.ID) != "" {
return best.ID, best.Name, nil
}
}
}
return "", "", errors.New("unable to resolve allanime show")
}
func (s *Service) searchShowResultsByMode(ctx context.Context, query string, modeCandidates []string) map[string][]searchResult {
resultsByMode := make(map[string][]searchResult, len(modeCandidates))
searchCh := make(chan searchModeResult, len(modeCandidates))
var wg sync.WaitGroup
for _, mode := range modeCandidates {
modeValue := mode // capture loop variable
wg.Go(func() {
results, err := s.allAnimeClient.Search(ctx, query, modeValue)
searchCh <- searchModeResult{Mode: modeValue, Results: results, Err: err}
})
}
wg.Wait()
close(searchCh)
for result := range searchCh {
if result.Err != nil {
continue
}
resultsByMode[result.Mode] = result.Results
}
return resultsByMode
}
func buildTitleSearchQueries(titleCandidates []string) []string {
queries := make([]string, 0, len(titleCandidates)*4)
seen := make(map[string]struct{})
add := func(raw string) {
normalized := normalizeSearchQuery(raw)
if normalized == "" {
return
}
key := strings.ToLower(normalized)
if _, exists := seen[key]; exists {
return
}
seen[key] = struct{}{}
queries = append(queries, normalized)
}
for _, candidate := range titleCandidates {
normalized := normalizeSearchQuery(candidate)
if normalized == "" {
continue
}
add(normalized)
add(strings.ReplaceAll(normalized, "+", " "))
// strip apostrophes to improve match rate
withoutApostrophes := strings.NewReplacer("'", "", "", "", "`", "").Replace(normalized)
add(withoutApostrophes)
add(strings.ReplaceAll(withoutApostrophes, "+", " "))
}
return queries
}
func normalizeSearchQuery(raw string) string {
return strings.Join(strings.Fields(strings.TrimSpace(raw)), " ")
}
func firstNonEmptyTitle(values []string) string {
for _, value := range values {
normalized := strings.TrimSpace(value)
if normalized != "" {
return normalized
}
}
return ""
}
func normalizeMode(raw string) string {
return strings.ToLower(strings.TrimSpace(raw))
}
func availableModes(modeSources map[string]ModeSource) []string {
preferred := []string{"dub", "sub"}
ordered := make([]string, 0, len(modeSources))
for _, mode := range preferred {
if _, ok := modeSources[mode]; ok {
ordered = append(ordered, mode)
}
}
extra := make([]string, 0)
for mode := range modeSources {
if mode == "dub" || mode == "sub" {
continue
}
extra = append(extra, mode)
}
sort.Strings(extra)
return append(ordered, extra...)
}
// selectInitialMode picks a mode prioritizing: requested mode > dub > sub > first available.
func selectInitialMode(requestedMode string, modeSources map[string]ModeSource) string {
normalizedRequested := normalizeMode(requestedMode)
if normalizedRequested != "" {
if _, ok := modeSources[normalizedRequested]; ok {
return normalizedRequested
}
}
if _, ok := modeSources["dub"]; ok {
return "dub"
}
if _, ok := modeSources["sub"]; ok {
return "sub"
}
for mode := range modeSources {
return mode
}
return "dub"
}

View File

@@ -1,224 +0,0 @@
package playback
import (
"context"
"errors"
"io"
"net/http"
"strings"
"sync"
)
// resolveModeSourceWithCache is like resolveModeSource but caches probe results.
func (s *Service) resolveModeSourceWithCache(
ctx context.Context,
showID string,
episode string,
mode string,
quality string,
probeCache map[string]directProbeResult,
probeCacheMu *sync.Mutex,
) (StreamSource, error) {
sources, err := s.allAnimeClient.GetEpisodeSources(ctx, showID, episode, mode)
if err != nil {
return StreamSource{}, err
}
ranked, err := rankSources(sources, quality)
if err != nil {
return StreamSource{}, err
}
selected, _, err := s.choosePlaybackSourceWithCache(ctx, ranked, probeCache, probeCacheMu)
if err != nil {
return StreamSource{}, err
}
selected.AvailableQualities = sources
return selected, nil
}
// choosePlaybackSource selects the best playable source from ranked candidates.
// priority: direct media > probed media > embed sources > ranked fallback.
func (s *Service) choosePlaybackSource(
ctx context.Context,
ranked []sourceScore,
probeFn func(context.Context, StreamSource) (bool, string),
) (StreamSource, string, error) {
if len(ranked) == 0 {
return StreamSource{}, "", errors.New("no ranked sources available")
}
embedCandidates := make([]StreamSource, 0, len(ranked))
for _, candidate := range ranked {
source := candidate.source
switch strings.ToLower(source.Type) {
case "mp4", "m3u8":
return source, "direct-media", nil // known playable types
case "embed":
embedCandidates = append(embedCandidates, source) // need probing
default:
// probe unknown types
if playable, contentType := probeFn(ctx, source); playable {
return normalizeSourceTypeFromProbe(source, contentType), "probed-media", nil
}
}
}
// check embed sources for playability
for _, embed := range embedCandidates {
if s.probeEmbedSource(ctx, embed) {
return embed, "embed-probed", nil
}
}
// fallback to first embed or first ranked
if len(embedCandidates) > 0 {
return embedCandidates[0], "embed-fallback", nil
}
return ranked[0].source, "ranked-fallback", nil
}
// choosePlaybackSourceWithCache wraps choosePlaybackSource with cached probing.
func (s *Service) choosePlaybackSourceWithCache(
ctx context.Context,
ranked []sourceScore,
probeCache map[string]directProbeResult,
probeCacheMu *sync.Mutex,
) (StreamSource, string, error) {
return s.choosePlaybackSource(ctx, ranked, func(ctx context.Context, source StreamSource) (bool, string) {
return s.probeDirectMediaCached(ctx, source, probeCache, probeCacheMu)
})
}
func (s *Service) probeDirectMediaCached(
ctx context.Context,
source StreamSource,
probeCache map[string]directProbeResult,
probeCacheMu *sync.Mutex,
) (bool, string) {
cacheKey := strings.TrimSpace(source.URL)
if cacheKey == "" {
return s.probeDirectMedia(ctx, source)
}
probeCacheMu.Lock()
cached, ok := probeCache[cacheKey]
probeCacheMu.Unlock()
if ok {
return cached.Playable, cached.ContentType
}
playable, contentType := s.probeDirectMedia(ctx, source)
probeCacheMu.Lock()
probeCache[cacheKey] = directProbeResult{Playable: playable, ContentType: contentType}
probeCacheMu.Unlock()
return playable, contentType
}
// probeDirectMedia checks if a direct media URL is playable.
// checks content-type header, reads prefix for magic bytes, falls back to URL extension.
func (s *Service) probeDirectMedia(ctx context.Context, source StreamSource) (bool, string) {
probeCtx, cancel := context.WithTimeout(ctx, providerProbeTimeout)
defer cancel()
req, err := http.NewRequestWithContext(probeCtx, http.MethodGet, source.URL, nil)
if err != nil {
return false, ""
}
if source.Referer != "" {
req.Header.Set("Referer", source.Referer)
}
req.Header.Set("User-Agent", defaultUserAgent)
req.Header.Set("Range", "bytes=0-4095") // small range to detect playable content
resp, err := s.httpClient.Do(req)
if err != nil {
return false, ""
}
defer func() { _ = resp.Body.Close() }()
// check content-type header first
contentType := strings.ToLower(resp.Header.Get("Content-Type"))
if strings.Contains(contentType, "video/") || strings.Contains(contentType, "mpegurl") {
return true, contentType
}
// check magic bytes in prefix
prefix, err := io.ReadAll(io.LimitReader(resp.Body, 4096))
if err == nil {
if isLikelyM3U8(prefix) {
return true, "application/vnd.apple.mpegurl"
}
if isLikelyMP4(prefix) {
return true, "video/mp4"
}
}
// fallback to URL extension
finalURL := ""
if resp.Request != nil && resp.Request.URL != nil {
finalURL = strings.ToLower(resp.Request.URL.String())
}
if strings.Contains(finalURL, ".mp4") || strings.Contains(finalURL, ".m3u8") {
return true, contentType
}
return false, contentType
}
// probeEmbedSource checks if an embed page is still available.
// returns false if the page contains deletion markers.
func (s *Service) probeEmbedSource(ctx context.Context, source StreamSource) bool {
ctx, cancel := context.WithTimeout(ctx, providerProbeTimeout)
defer cancel()
req, err := http.NewRequestWithContext(ctx, http.MethodGet, source.URL, nil)
if err != nil {
return false
}
if source.Referer != "" {
req.Header.Set("Referer", source.Referer)
}
req.Header.Set("User-Agent", defaultUserAgent)
resp, err := s.httpClient.Do(req)
if err != nil {
return false
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode >= http.StatusBadRequest {
return false
}
// check for common deletion messages
body, err := io.ReadAll(io.LimitReader(resp.Body, 512*1024))
if err != nil {
return false
}
content := strings.ToLower(string(body))
for _, marker := range []string{
"file was deleted",
"file has been deleted",
"video was deleted",
"video has been deleted",
"video unavailable",
"file not found",
"this file does not exist",
"resource unavailable",
} {
if strings.Contains(content, marker) {
return false
}
}
return true
}

View File

@@ -1,24 +0,0 @@
package playback
import (
"strings"
)
// toSubtitleItems converts raw subtitle entries into client-safe items.
func toSubtitleItems(source StreamSource) []SubtitleItem {
items := make([]SubtitleItem, 0, len(source.Subtitles))
for _, subtitle := range source.Subtitles {
targetURL := strings.TrimSpace(subtitle.URL)
if targetURL == "" {
continue
}
items = append(items, SubtitleItem{
Lang: strings.TrimSpace(subtitle.Lang),
URL: targetURL,
Referer: source.Referer,
})
}
return items
}

View File

@@ -1,52 +0,0 @@
package playback
// StreamSource represents a video stream from a provider.
type StreamSource struct {
URL string
Quality string
Provider string
Type string // m3u8, mp4, embed, unknown
Referer string
Subtitles []Subtitle
AvailableQualities []StreamSource
}
type Subtitle struct {
Lang string
URL string
}
type ModeSource struct {
URL string `json:"url,omitempty"`
Referer string `json:"referer,omitempty"`
Token string `json:"token"`
Subtitles []SubtitleItem `json:"subtitles"`
Qualities []string `json:"qualities,omitempty"`
}
type SubtitleItem struct {
Lang string `json:"lang"`
URL string `json:"url,omitempty"`
Referer string `json:"referer,omitempty"`
Token string `json:"token"`
}
type SkipSegment struct {
Type string `json:"type"`
Start float64 `json:"start"`
End float64 `json:"end"`
}
// WatchPageData is the response payload for the watch page frontend.
type WatchPageData struct {
MalID int
Title string
CurrentEpisode string
StartTimeSeconds float64
CurrentStatus string
InitialMode string
AvailableModes []string
ModeSources map[string]ModeSource
Segments []SkipSegment
FallbackEpisodes map[string]int
}

View File

@@ -1,174 +0,0 @@
package watchlist
import (
"context"
"encoding/json"
"errors"
"log"
"net/http"
"strconv"
"mal/internal/db"
"mal/internal/middleware"
"mal/templates"
)
type Handler struct {
service *Service
}
func NewHandler(service *Service) *Handler {
return &Handler{service: service}
}
// HandleUpdateWatchlist adds or updates anime in user's watchlist. accepts json {animeId, status}.
func (h *Handler) HandleUpdateWatchlist(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
return
}
user := middleware.GetUser(r.Context())
if user == nil {
http.Error(w, "unauthorized", http.StatusUnauthorized)
return
}
var body struct {
AnimeID int64 `json:"animeId"`
Status string `json:"status"`
}
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
http.Error(w, "invalid request", http.StatusBadRequest)
return
}
// default status if not provided
if body.Status == "" {
body.Status = "plan_to_watch"
}
if err := h.service.AddToWatchlist(r.Context(), user.ID, body.AnimeID, body.Status); err != nil {
log.Printf("failed to add to watchlist: %v", err)
http.Error(w, "failed to add to watchlist", http.StatusInternalServerError)
return
}
w.WriteHeader(http.StatusOK)
}
// HandleDeleteWatchlist removes anime from user's watchlist. expects /api/watchlist/{animeId}.
func (h *Handler) HandleDeleteWatchlist(w http.ResponseWriter, r *http.Request) {
user := middleware.GetUser(r.Context())
if user == nil {
http.Error(w, "unauthorized", http.StatusUnauthorized)
return
}
animeIDStr := r.URL.Path[len("/api/watchlist/"):]
animeID, err := strconv.ParseInt(animeIDStr, 10, 64)
if err != nil {
http.Error(w, "invalid anime id", http.StatusBadRequest)
return
}
if _, err := h.service.RemoveEntry(r.Context(), user.ID, animeID); err != nil {
log.Printf("failed to remove from watchlist: %v", err)
http.Error(w, "failed to remove from watchlist", http.StatusInternalServerError)
return
}
// htmx: redirect to watchlist page after delete
w.Header().Set("HX-Redirect", "/watchlist")
w.WriteHeader(http.StatusOK)
}
// HandleDeleteContinueWatching removes entry from user's continue watching. expects /api/continue-watching/{animeId}.
func (h *Handler) HandleDeleteContinueWatching(w http.ResponseWriter, r *http.Request) {
user := middleware.GetUser(r.Context())
if user == nil {
http.Error(w, "unauthorized", http.StatusUnauthorized)
return
}
animeIDStr := r.URL.Path[len("/api/continue-watching/"):]
animeID, err := strconv.ParseInt(animeIDStr, 10, 64)
if err != nil {
http.Error(w, "invalid anime id", http.StatusBadRequest)
return
}
if err := h.service.DeleteContinueWatching(r.Context(), user.ID, animeID); err != nil {
log.Printf("failed to remove from continue watching: %v", err)
http.Error(w, "failed to remove from continue watching", http.StatusInternalServerError)
return
}
w.Header().Set("Cache-Control", "no-cache, no-store, must-revalidate")
w.WriteHeader(http.StatusOK)
}
// HandleGetWatchlist renders user's watchlist page, grouped by status.
func (h *Handler) HandleGetWatchlist(w http.ResponseWriter, r *http.Request) {
user := middleware.GetUser(r.Context())
if user == nil {
http.Redirect(w, r, "/login", http.StatusSeeOther)
return
}
entries, err := h.service.GetUserWatchlist(r.Context(), user.ID)
if err != nil {
log.Printf("failed to fetch watchlist: %v", err)
if err := templates.GetRenderer().ExecuteTemplate(r.Context(), w, "not_found.gohtml", map[string]any{
"CurrentPath": r.URL.Path,
}); err != nil {
if !errors.Is(err, context.Canceled) {
log.Printf("render error: %v", err)
}
}
return
}
// group entries by status for display
watchlistByStatus := make(map[string][]db.GetUserWatchListRow)
allEntries := make([]db.GetUserWatchListRow, 0)
watchlistIDs := make([]int64, len(entries))
for i, entry := range entries {
status := entry.Status
if status == "" {
status = "plan_to_watch"
}
watchlistByStatus[status] = append(watchlistByStatus[status], entry)
allEntries = append(allEntries, entry)
watchlistIDs[i] = entry.AnimeID
}
data := map[string]any{
"User": user,
"CurrentPath": r.URL.Path,
"WatchlistByStatus": watchlistByStatus,
"AllEntries": allEntries,
"WatchlistIDs": watchlistIDs,
"StatusOrder": []string{"watching", "plan_to_watch", "on_hold", "completed", "dropped"},
"StatusLabels": map[string]string{
"watching": "Currently Watching",
"plan_to_watch": "Plan to Watch",
"on_hold": "On Hold",
"completed": "Completed",
"dropped": "Dropped",
},
}
// use partial template for htmx requests
templateName := "watchlist.gohtml"
if r.Header.Get("HX-Request") == "true" {
templateName = "watchlist_partial.gohtml"
}
if err := templates.GetRenderer().ExecuteTemplate(r.Context(), w, templateName, data); err != nil {
if !errors.Is(err, context.Canceled) {
log.Printf("render error: %v", err)
}
}
}

View File

@@ -1,184 +0,0 @@
package watchlist
import (
"context"
"database/sql"
"errors"
"fmt"
"strings"
"github.com/google/uuid"
"mal/integrations/jikan"
"mal/internal/db"
)
type Service struct {
db db.Querier
sqlDB *sql.DB
jikanClient *jikan.Client
}
var (
ErrInvalidAnimeID = errors.New("invalid anime ID")
ErrInvalidStatus = errors.New("invalid watchlist status")
)
var validStatuses = map[string]struct{}{
"watching": {},
"completed": {},
"dropped": {},
"plan_to_watch": {},
"on_hold": {},
}
func NewService(db db.Querier, sqlDB *sql.DB, jikanClient *jikan.Client) *Service {
return &Service{db: db, sqlDB: sqlDB, jikanClient: jikanClient}
}
// ensureAnimeExists checks if anime exists in db, fetches from jikan if not, then upserts.
func (s *Service) ensureAnimeExists(ctx context.Context, animeID int64) error {
_, err := s.db.GetAnime(ctx, animeID)
if err == nil {
return nil // already exists
}
// fetch from jikan and store locally
anime, err := s.jikanClient.GetAnimeByID(ctx, int(animeID))
if err != nil {
return fmt.Errorf("failed to fetch anime from jikan: %w", err)
}
_, err = s.db.UpsertAnime(ctx, db.UpsertAnimeParams{
ID: int64(anime.MalID),
TitleOriginal: anime.Title,
TitleEnglish: sql.NullString{String: anime.TitleEnglish, Valid: anime.TitleEnglish != ""},
TitleJapanese: sql.NullString{String: anime.TitleJapanese, Valid: anime.TitleJapanese != ""},
ImageUrl: anime.Images.Jpg.LargeImageURL,
Airing: sql.NullBool{Bool: anime.Airing, Valid: true},
})
if err != nil {
return fmt.Errorf("failed to save anime: %w", err)
}
return nil
}
type AddRequest struct {
AnimeID int64
TitleOriginal string
TitleEnglish string
TitleJapanese string
ImageURL string
Status string
Airing bool
}
// AddToWatchlist adds or updates an anime entry in user's watchlist.
func (s *Service) AddToWatchlist(ctx context.Context, userID string, animeID int64, status string) error {
if animeID <= 0 {
return ErrInvalidAnimeID
}
if _, ok := validStatuses[status]; !ok {
return ErrInvalidStatus
}
// ensure anime exists in local db before linking
if err := s.ensureAnimeExists(ctx, animeID); err != nil {
return err
}
entryID := uuid.New().String()
_, err := s.db.UpsertWatchListEntry(ctx, db.UpsertWatchListEntryParams{
ID: entryID,
UserID: userID,
AnimeID: animeID,
Status: status,
CurrentEpisode: sql.NullInt64{Valid: false},
CurrentTimeSeconds: 0,
})
if err != nil {
return fmt.Errorf("failed to update watchlist: %w", err)
}
return nil
}
// RemoveEntry deletes a watchlist entry and returns the anime for potential use.
func (s *Service) RemoveEntry(ctx context.Context, userID string, animeID int64) (db.Anime, error) {
if animeID <= 0 {
return db.Anime{}, ErrInvalidAnimeID
}
anime, err := s.db.GetAnime(ctx, animeID)
if err != nil {
return db.Anime{}, fmt.Errorf("anime not found: %w", err)
}
err = s.db.DeleteWatchListEntry(ctx, db.DeleteWatchListEntryParams{
UserID: userID,
AnimeID: animeID,
})
if err != nil {
return db.Anime{}, fmt.Errorf("failed to delete from watchlist: %w", err)
}
return anime, nil
}
// GetUserWatchlist retrieves all watchlist entries for a user.
func (s *Service) GetUserWatchlist(ctx context.Context, userID string) ([]db.GetUserWatchListRow, error) {
entries, err := s.db.GetUserWatchList(ctx, userID)
if err != nil {
return nil, fmt.Errorf("failed to fetch watchlist: %w", err)
}
return entries, nil
}
// DeleteContinueWatching removes entry and clears associated watch progress.
// uses transaction when sqlDB is available.
func (s *Service) DeleteContinueWatching(ctx context.Context, userID string, animeID int64) error {
if strings.TrimSpace(userID) == "" {
return errors.New("invalid user id")
}
if animeID <= 0 {
return ErrInvalidAnimeID
}
params := db.DeleteContinueWatchingEntryParams{
UserID: userID,
AnimeID: animeID,
}
clearProgress := db.SaveWatchProgressParams{
CurrentEpisode: sql.NullInt64{Valid: false},
CurrentTimeSeconds: 0,
UserID: userID,
AnimeID: animeID,
}
// use transaction when sqlDB available for consistency
if s.sqlDB == nil {
if err := s.db.DeleteContinueWatchingEntry(ctx, params); err != nil {
return fmt.Errorf("failed to delete continue watching entry: %w", err)
}
return s.db.SaveWatchProgress(ctx, clearProgress)
}
txQueries, tx, err := db.BeginTx(ctx, s.sqlDB)
if err != nil {
return fmt.Errorf("failed to begin transaction: %w", err)
}
defer func() { _ = tx.Rollback() }()
if err := txQueries.DeleteContinueWatchingEntry(ctx, params); err != nil {
return fmt.Errorf("failed to delete continue watching entry: %w", err)
}
if err := txQueries.SaveWatchProgress(ctx, clearProgress); err != nil {
return fmt.Errorf("failed to clear watchlist progress: %w", err)
}
return tx.Commit()
}

View File

@@ -1,72 +0,0 @@
package watchlist
import (
"context"
"testing"
"mal/internal/db"
)
type fakeQuerier struct {
db.Querier
upsertAnimeCalled bool
upsertEntryCalled bool
upsertEntryParams db.UpsertWatchListEntryParams
getAnimeFunc func(ctx context.Context, id int64) (db.Anime, error)
}
func (f *fakeQuerier) GetAnime(ctx context.Context, id int64) (db.Anime, error) {
if f.getAnimeFunc != nil {
return f.getAnimeFunc(ctx, id)
}
return db.Anime{}, nil
}
func (f *fakeQuerier) UpsertAnime(ctx context.Context, arg db.UpsertAnimeParams) (db.Anime, error) {
f.upsertAnimeCalled = true
return db.Anime{}, nil
}
func (f *fakeQuerier) UpsertWatchListEntry(ctx context.Context, arg db.UpsertWatchListEntryParams) (db.WatchListEntry, error) {
f.upsertEntryCalled = true
f.upsertEntryParams = arg
return db.WatchListEntry{}, nil
}
func (f *fakeQuerier) GetUserWatchList(ctx context.Context, userID string) ([]db.GetUserWatchListRow, error) {
return nil, nil
}
func TestAddEntry_RejectsInvalidAnimeID(t *testing.T) {
t.Parallel()
q := &fakeQuerier{}
svc := NewService(q, nil, nil)
err := svc.AddToWatchlist(context.Background(), "user-1", 0, "watching")
if err != ErrInvalidAnimeID {
t.Fatalf("expected ErrInvalidAnimeID, got %v", err)
}
if q.upsertAnimeCalled || q.upsertEntryCalled {
t.Fatal("expected no database writes for invalid anime id")
}
}
func TestAddEntry_RejectsInvalidStatus(t *testing.T) {
t.Parallel()
q := &fakeQuerier{}
svc := NewService(q, nil, nil)
err := svc.AddToWatchlist(context.Background(), "user-1", 1, "invalid")
if err != ErrInvalidStatus {
t.Fatalf("expected ErrInvalidStatus, got %v", err)
}
if q.upsertAnimeCalled || q.upsertEntryCalled {
t.Fatal("expected no database writes for invalid status")
}
}

View File

@@ -1,118 +0,0 @@
package db
import (
"database/sql"
"fmt"
"log"
"os"
"path/filepath"
"sort"
"strings"
)
// RunMigrations applies all *.sql files in migrationsDir in sorted order,
// skipping any already recorded in migration_version.
func RunMigrations(db *sql.DB, migrationsDir string) error {
if migrationsDir == "" {
return fmt.Errorf("migrations directory is required")
}
// Create migration tracking table
_, err := db.Exec(`
CREATE TABLE IF NOT EXISTS migration_version (
name TEXT PRIMARY KEY,
applied_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP
)
`)
if err != nil {
return err
}
migrations, err := filepath.Glob(filepath.Join(migrationsDir, "*.sql"))
if err != nil {
return err
}
if len(migrations) == 0 {
return fmt.Errorf("no migration files found in %s", migrationsDir)
}
sort.Strings(migrations)
appliedNames, err := loadAppliedMigrationNames(db)
if err != nil {
return err
}
for _, migrationFile := range migrations {
migrationName := filepath.Base(migrationFile)
if migrationApplied(appliedNames, migrationName) {
continue // already applied
}
migrationSQL, err := os.ReadFile(migrationFile)
if err != nil {
return err
}
if _, err := db.Exec(string(migrationSQL)); err != nil {
return err // stop on first failure
}
// record applied migration
_, err = db.Exec("INSERT INTO migration_version (name) VALUES (?)", migrationName)
if err != nil {
return err
}
appliedNames[migrationName] = struct{}{}
log.Printf("migration %s applied successfully", migrationName)
}
return nil
}
func loadAppliedMigrationNames(db *sql.DB) (map[string]struct{}, error) {
rows, err := db.Query("SELECT name FROM migration_version")
if err != nil {
return nil, err
}
defer func() { _ = rows.Close() }()
applied := make(map[string]struct{})
for rows.Next() {
var name string
if err := rows.Scan(&name); err != nil {
return nil, err
}
applied[name] = struct{}{}
}
if err := rows.Err(); err != nil {
return nil, err
}
return applied, nil
}
// migrationApplied checks the applied names map for a match,
// including legacy paths and case-insensitive basename matches.
func migrationApplied(appliedNames map[string]struct{}, migrationName string) bool {
if _, exists := appliedNames[migrationName]; exists {
return true
}
legacyName := filepath.ToSlash(filepath.Join("migrations", migrationName))
if _, exists := appliedNames[legacyName]; exists {
return true
}
for appliedName := range appliedNames {
if strings.EqualFold(filepath.Base(appliedName), migrationName) {
return true
}
}
return false
}

View File

@@ -1,162 +0,0 @@
package server
import (
"database/sql"
"fmt"
"net/http"
"path/filepath"
"strings"
"time"
"mal/api/anime"
"mal/api/auth"
"mal/api/playback"
"mal/api/watchlist"
"mal/integrations/jikan"
"mal/internal/db"
"mal/internal/middleware"
pkgmiddleware "mal/pkg/middleware"
)
type Config struct {
DB *db.Queries
SQLDB *sql.DB
JikanClient *jikan.Client
AuthService *auth.Service
AuthLimiter *pkgmiddleware.Limiter
PlaybackProxySecret string
}
// withMimeTypes sets Content-Type for common static asset extensions
func withMimeTypes(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ext := strings.ToLower(filepath.Ext(r.URL.Path))
switch ext {
case ".js":
w.Header().Set("Content-Type", "application/javascript")
case ".css":
w.Header().Set("Content-Type", "text/css")
case ".svg":
w.Header().Set("Content-Type", "image/svg+xml")
case ".json":
w.Header().Set("Content-Type", "application/json")
}
next.ServeHTTP(w, r)
})
}
// noCache sends headers to prevent caching of dynamic/static assets
func noCache(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Cache-Control", "no-cache, no-store, must-revalidate")
w.Header().Set("Pragma", "no-cache")
w.Header().Set("Expires", "0")
next.ServeHTTP(w, r)
})
}
// NewAuthLimiter returns a rate limiter for auth endpoints: 5 attempts per minute
func NewAuthLimiter() *pkgmiddleware.Limiter {
return pkgmiddleware.NewLimiter(pkgmiddleware.Config{
MaxAttempts: 5,
Window: time.Minute,
})
}
// NewRouter wires up all HTTP handlers and middleware.
// Auth is enforced globally; public routes must opt-out via middleware policy.
func NewRouter(cfg Config) http.Handler {
mux := http.NewServeMux()
authHandler := auth.NewHandler(cfg.AuthService)
watchlistSvc := watchlist.NewService(cfg.DB, cfg.SQLDB, cfg.JikanClient)
watchlistHandler := watchlist.NewHandler(watchlistSvc)
animeSvc := anime.NewService(cfg.JikanClient, cfg.DB)
animeHandler := anime.NewHandler(animeSvc)
playbackSvc, err := playback.NewService(cfg.DB, cfg.SQLDB, playback.Config{
ProxyTokenSecret: cfg.PlaybackProxySecret,
})
if err != nil {
panic(fmt.Sprintf("failed to initialize playback service: %v", err))
}
playbackHandler := playback.NewHandler(playbackSvc, cfg.JikanClient)
// Serve static files with no-cache headers
fs := noCache(http.FileServer(http.Dir("./static")))
mux.Handle("/static/", http.StripPrefix("/static/", fs))
// Serve built frontend assets with no-cache headers
dist := noCache(http.FileServer(http.Dir("./dist")))
mux.Handle("/dist/", http.StripPrefix("/dist/", withMimeTypes(dist)))
// Serve Apple Touch Icons from static directory
mux.HandleFunc("/apple-touch-icon.png", func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "image/svg+xml")
http.ServeFile(w, r, "./static/apple-touch-icon.svg")
})
mux.HandleFunc("/apple-touch-icon-precomposed.png", func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "image/svg+xml")
http.ServeFile(w, r, "./static/apple-touch-icon-precomposed.svg")
})
mux.HandleFunc("/apple-touch-icon-120x120.png", func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "image/svg+xml")
http.ServeFile(w, r, "./static/apple-touch-icon-120x120.svg")
})
mux.HandleFunc("/apple-touch-icon-120x120-precomposed.png", func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "image/svg+xml")
http.ServeFile(w, r, "./static/apple-touch-icon-120x120-precomposed.svg")
})
mux.HandleFunc("/", animeHandler.HandleCatalog)
mux.HandleFunc("/api/catalog/airing", animeHandler.HandleCatalogAiring)
mux.HandleFunc("/api/catalog/popular", animeHandler.HandleCatalogPopular)
mux.HandleFunc("/api/catalog/continue", animeHandler.HandleCatalogContinue)
mux.HandleFunc("/search", animeHandler.HandleSearch)
mux.HandleFunc("/browse", animeHandler.HandleBrowse)
mux.HandleFunc("/discover", animeHandler.HandleDiscover)
mux.HandleFunc("/api/discover/trending", animeHandler.HandleDiscoverTrending)
mux.HandleFunc("/api/discover/upcoming", animeHandler.HandleDiscoverUpcoming)
mux.HandleFunc("/api/discover/top", animeHandler.HandleDiscoverTop)
mux.HandleFunc("/api/search-quick", animeHandler.HandleQuickSearch)
mux.HandleFunc("/api/jikan/random/anime", animeHandler.HandleRandomAnime)
mux.HandleFunc("/anime/", func(w http.ResponseWriter, r *http.Request) {
if strings.HasSuffix(r.URL.Path, "/watch") {
playbackHandler.HandleWatchPage(w, r)
return
}
animeHandler.HandleAnimeDetails(w, r)
})
mux.HandleFunc("/api/watch-order", animeHandler.HandleHTMLWatchOrder)
mux.HandleFunc("/watch/", playbackHandler.HandleWatchPage)
mux.HandleFunc("/watch/proxy/stream", playbackHandler.HandleProxy)
mux.HandleFunc("/watch/proxy/segment", playbackHandler.HandleProxy)
mux.HandleFunc("/watch/proxy/subtitle", playbackHandler.HandleProxy)
mux.HandleFunc("/api/watch-progress", playbackHandler.HandleSaveProgress)
mux.HandleFunc("/api/watch-complete", playbackHandler.HandleCompleteAnime)
mux.HandleFunc("/api/watch/episode/", playbackHandler.HandleEpisodeData)
mux.HandleFunc("/api/watch/thumbnails/", playbackHandler.HandleEpisodeThumbnails)
// Auth Endpoints
mux.HandleFunc("/login", func(w http.ResponseWriter, r *http.Request) {
if r.Method == http.MethodGet {
authHandler.HandleLoginPage(w, r)
} else {
cfg.AuthLimiter.AuthMiddleware(pkgmiddleware.VerifyOrigin(http.HandlerFunc(authHandler.HandleLogin))).ServeHTTP(w, r)
}
})
mux.HandleFunc("/logout", authHandler.HandleLogout)
// Watchlist Endpoints
mux.HandleFunc("/api/watchlist", watchlistHandler.HandleUpdateWatchlist)
mux.HandleFunc("/api/watchlist/", watchlistHandler.HandleDeleteWatchlist)
mux.HandleFunc("/api/continue-watching/", watchlistHandler.HandleDeleteContinueWatching)
mux.HandleFunc("/watchlist", watchlistHandler.HandleGetWatchlist)
// Wrap mux with global CSRF origin verification and auth checking
protectedHandler := middleware.RequireGlobalAuthWithPolicy(middleware.NewAccessPolicy())(pkgmiddleware.VerifyOrigin(mux))
authenticatedHandler := middleware.Auth(cfg.AuthService)(protectedHandler)
return pkgmiddleware.RequestLogger(authenticatedHandler)
}

View File

@@ -1,198 +0,0 @@
package templates
import (
"context"
"encoding/json"
"fmt"
"html/template"
"io"
"log"
"path/filepath"
"slices"
"strconv"
"strings"
"sync"
)
var (
once sync.Once
renderer *Renderer
)
type Renderer struct {
templates map[string]*template.Template
}
// GetRenderer returns the singleton renderer, initializing it on first call.
// Templates are loaded from ./templates/*.gohtml and ./templates/components/*.gohtml.
func GetRenderer() *Renderer {
once.Do(func() {
renderer = &Renderer{
templates: make(map[string]*template.Template),
}
funcs := template.FuncMap{
"dict": func(values ...any) map[string]any {
m := make(map[string]any)
for i := 0; i < len(values)-1; i += 2 {
key, ok := values[i].(string)
if !ok {
continue
}
m[key] = values[i+1]
}
return m
},
"json": func(v any) template.HTMLAttr {
b, _ := json.Marshal(v)
return template.HTMLAttr(b)
},
"genresParams": func(genres []int) string {
if len(genres) == 0 {
return ""
}
var s strings.Builder
for _, g := range genres {
s.WriteString("genres=" + fmt.Sprintf("%d", g) + "&")
}
return s.String()[:len(s.String())-1]
},
"hasGenre": func(id int, genres []int) bool {
return slices.Contains(genres, id)
},
"add": func(a, b int) int {
return a + b
},
"sub": func(a, b int) int {
return a - b
},
"mul": func(a, b float64) float64 {
return a * b
},
"imul": func(a, b int) int {
return a * b
},
"div": func(a, b float64) float64 {
if b == 0 {
return 0
}
return a / b
},
"ceilDiv": func(a, b int) int {
if b == 0 {
return 0
}
return (a + b - 1) / b
},
"toFloat": func(a int) float64 {
return float64(a)
},
"seq": func(v any) []int {
var count int
switch n := v.(type) {
case int:
count = n
case int64:
count = int(n)
default:
count = 0
}
res := make([]int, count)
for i := 0; i < count; i++ {
res[i] = i
}
return res
},
"min": func(a, b int) int {
if a < b {
return a
}
return b
},
"int": func(v any) int {
switch n := v.(type) {
case int:
return n
case int64:
return int(n)
case float64:
return int(n)
case string:
i, _ := strconv.Atoi(n)
return i
default:
return 0
}
},
"percent": func(current, total float64) float64 {
if total == 0 {
return 0
}
return (current / total) * 100
},
}
pages, err := filepath.Glob(filepath.Join(".", "templates", "*.gohtml"))
if err != nil {
log.Fatalf("failed to glob page templates: %v", err)
}
components, err := filepath.Glob(filepath.Join(".", "templates", "components", "*.gohtml"))
if err != nil {
log.Fatalf("failed to glob component templates: %v", err)
}
for _, page := range pages {
name := filepath.Base(page)
if name == "base.gohtml" {
continue
}
tmpl := template.New(name).Funcs(funcs)
// Parse base first so it establishes the core definitions
tmpl = template.Must(tmpl.ParseFiles(filepath.Join(".", "templates", "base.gohtml")))
// Parse all components next so they are available to the page
if len(components) > 0 {
tmpl = template.Must(tmpl.ParseFiles(components...))
}
// Parse the page itself last
tmpl = template.Must(tmpl.ParseFiles(page))
renderer.templates[name] = tmpl
log.Printf("Loaded page template: %s", name)
}
})
return renderer
}
// ExecuteTemplate renders a named template into wr, returning early if context is cancelled
func (r *Renderer) ExecuteTemplate(ctx context.Context, wr io.Writer, name string, data any) error {
select {
case <-ctx.Done():
return ctx.Err()
default:
}
tmpl, ok := r.templates[name]
if !ok {
return fmt.Errorf("template %s not found", name)
}
return tmpl.ExecuteTemplate(wr, "base.gohtml", data)
}
// ExecuteFragment renders a specific named block within a template (e.g. a component)
func (r *Renderer) ExecuteFragment(ctx context.Context, wr io.Writer, name string, block string, data any) error {
select {
case <-ctx.Done():
return ctx.Err()
default:
}
tmpl, ok := r.templates[name]
if !ok {
return fmt.Errorf("template %s not found", name)
}
return tmpl.ExecuteTemplate(wr, block, data)
}