diff --git a/internal/domain/playback.go b/internal/domain/playback.go index bbf37ad..97360a8 100644 --- a/internal/domain/playback.go +++ b/internal/domain/playback.go @@ -89,6 +89,8 @@ type EpisodeData struct { type PlaybackRepository interface { InTx(ctx context.Context, fn func(ctx context.Context, repo PlaybackRepository) error) error + UpsertAnime(ctx context.Context, params db.UpsertAnimeParams) (db.Anime, error) + GetAnime(ctx context.Context, id int64) (db.Anime, error) GetWatchListEntry(ctx context.Context, params db.GetWatchListEntryParams) (db.WatchListEntry, error) GetContinueWatchingEntry(ctx context.Context, params db.GetContinueWatchingEntryParams) (db.ContinueWatchingEntry, error) SaveWatchProgress(ctx context.Context, params db.SaveWatchProgressParams) error diff --git a/internal/playback/progress.go b/internal/playback/progress.go index e9750de..8dab835 100644 --- a/internal/playback/progress.go +++ b/internal/playback/progress.go @@ -4,6 +4,7 @@ import ( "context" "database/sql" "encoding/json" + "fmt" "strconv" "github.com/google/uuid" @@ -108,13 +109,22 @@ func (s *playbackService) CompleteAnime(ctx context.Context, userID string, anim } func (s *playbackService) SaveProgress(ctx context.Context, userID string, animeID int64, episode int, timeSeconds float64) error { - _, err := s.repo.UpsertContinueWatchingEntry(ctx, db.UpsertContinueWatchingEntryParams{ - ID: uuid.New().String(), - UserID: userID, - AnimeID: animeID, - CurrentEpisode: sql.NullInt64{Int64: int64(episode), Valid: true}, - CurrentTimeSeconds: timeSeconds, - DurationSeconds: sql.NullFloat64{Valid: false}, + err := s.repo.InTx(ctx, func(txCtx context.Context, repo domain.PlaybackRepository) error { + if _, err := repo.GetAnime(txCtx, animeID); err != nil { + if _, err := repo.UpsertAnime(txCtx, minimalAnimeParams(animeID)); err != nil { + return err + } + } + + _, err := repo.UpsertContinueWatchingEntry(txCtx, db.UpsertContinueWatchingEntryParams{ + ID: uuid.New().String(), + UserID: userID, + AnimeID: animeID, + CurrentEpisode: sql.NullInt64{Int64: int64(episode), Valid: true}, + CurrentTimeSeconds: timeSeconds, + DurationSeconds: sql.NullFloat64{Valid: false}, + }) + return err }) if err != nil { return err @@ -148,3 +158,36 @@ func (s *playbackService) SaveProgress(ctx context.Context, userID string, anime }) return nil } + +func (s *playbackService) ensureAnimeRow(ctx context.Context, anime domain.Anime) { + if _, err := s.repo.GetAnime(ctx, int64(anime.MalID)); err == nil { + return + } + _, _ = s.repo.UpsertAnime(ctx, animeParams(anime)) +} + +func animeParams(anime domain.Anime) db.UpsertAnimeParams { + durationSeconds := anime.DurationSeconds() + duration := sql.NullFloat64{Valid: durationSeconds > 0} + if duration.Valid { + duration.Float64 = durationSeconds + } + + return db.UpsertAnimeParams{ + ID: int64(anime.MalID), + TitleOriginal: anime.Title, + TitleEnglish: sql.NullString{String: anime.TitleEnglish, Valid: anime.TitleEnglish != ""}, + TitleJapanese: sql.NullString{String: anime.TitleJapanese, Valid: anime.TitleJapanese != ""}, + ImageUrl: anime.ImageURL(), + Airing: sql.NullBool{Bool: anime.Airing, Valid: true}, + DurationSeconds: duration, + } +} + +func minimalAnimeParams(animeID int64) db.UpsertAnimeParams { + return db.UpsertAnimeParams{ + ID: animeID, + TitleOriginal: fmt.Sprintf("Anime %d", animeID), + Airing: sql.NullBool{Valid: false}, + } +} diff --git a/internal/playback/repository.go b/internal/playback/repository.go index 02ecf69..a13e588 100644 --- a/internal/playback/repository.go +++ b/internal/playback/repository.go @@ -23,6 +23,14 @@ func (r *playbackRepository) InTx(ctx context.Context, fn func(ctx context.Conte }, fn) } +func (r *playbackRepository) UpsertAnime(ctx context.Context, params db.UpsertAnimeParams) (db.Anime, error) { + return r.queries.UpsertAnime(ctx, params) +} + +func (r *playbackRepository) GetAnime(ctx context.Context, id int64) (db.Anime, error) { + return r.queries.GetAnime(ctx, id) +} + func (r *playbackRepository) GetWatchListEntry(ctx context.Context, params db.GetWatchListEntryParams) (db.WatchListEntry, error) { return r.queries.GetWatchListEntry(ctx, params) } diff --git a/internal/playback/watch_data.go b/internal/playback/watch_data.go index 2b7836a..54b8c90 100644 --- a/internal/playback/watch_data.go +++ b/internal/playback/watch_data.go @@ -18,6 +18,7 @@ func (s *playbackService) BuildWatchData(ctx context.Context, animeID int, title } animeData := domain.Anime{Anime: anime} + s.ensureAnimeRow(ctx, animeData) searchTitles := buildSearchTitles(animeData, titleCandidates) canonicalEpisodes, err := s.episodes.GetCanonicalEpisodes(ctx, animeData, false) if err != nil {