feat: add transactional InTx to playback and watchlist repos

This commit is contained in:
2026-05-28 12:17:19 +02:00
parent 4329bce4a7
commit dd4c7f80f3
6 changed files with 115 additions and 62 deletions

View File

@@ -89,6 +89,7 @@ type EpisodeData struct {
}
type PlaybackRepository interface {
InTx(ctx context.Context, fn func(ctx context.Context, repo PlaybackRepository) error) error
GetWatchListEntry(ctx context.Context, params db.GetWatchListEntryParams) (db.WatchListEntry, error)
GetContinueWatchingEntry(ctx context.Context, params db.GetContinueWatchingEntryParams) (db.ContinueWatchingEntry, error)
SaveWatchProgress(ctx context.Context, params db.SaveWatchProgressParams) error

View File

@@ -21,6 +21,7 @@ type WatchlistService interface {
}
type WatchlistRepository interface {
InTx(ctx context.Context, fn func(ctx context.Context, repo WatchlistRepository) error) error
UpsertAnime(ctx context.Context, arg db.UpsertAnimeParams) (db.Anime, error)
GetAnime(ctx context.Context, id int64) (db.Anime, error)
UpsertWatchListEntry(ctx context.Context, arg db.UpsertWatchListEntryParams) (db.WatchListEntry, error)

View File

@@ -2,16 +2,37 @@ package repository
import (
"context"
"database/sql"
"mal/internal/db"
"mal/internal/domain"
)
type playbackRepository struct {
sqlDB *sql.DB
queries *db.Queries
}
func NewPlaybackRepository(queries *db.Queries) domain.PlaybackRepository {
return &playbackRepository{queries: queries}
func NewPlaybackRepository(sqlDB *sql.DB, queries *db.Queries) domain.PlaybackRepository {
return &playbackRepository{sqlDB: sqlDB, queries: queries}
}
func (r *playbackRepository) InTx(ctx context.Context, fn func(ctx context.Context, repo domain.PlaybackRepository) error) error {
if r.sqlDB == nil {
return fn(ctx, r)
}
tx, err := r.sqlDB.BeginTx(ctx, nil)
if err != nil {
return err
}
txRepo := &playbackRepository{sqlDB: nil, queries: r.queries.WithTx(tx)}
if err := fn(ctx, txRepo); err != nil {
_ = tx.Rollback()
return err
}
return tx.Commit()
}
func (r *playbackRepository) GetWatchListEntry(ctx context.Context, params db.GetWatchListEntryParams) (db.WatchListEntry, error) {

View File

@@ -301,12 +301,13 @@ func (s *playbackService) BuildWatchData(ctx context.Context, animeID int, title
}
func (s *playbackService) CompleteAnime(ctx context.Context, userID string, animeID int64) error {
entry, err := s.repo.GetWatchListEntry(ctx, db.GetWatchListEntryParams{
if err := s.repo.InTx(ctx, func(txCtx context.Context, repo domain.PlaybackRepository) error {
entry, err := repo.GetWatchListEntry(txCtx, db.GetWatchListEntryParams{
UserID: userID,
AnimeID: animeID,
})
if err != nil || entry.Status != "completed" {
_, err = s.repo.UpsertWatchListEntry(ctx, db.UpsertWatchListEntryParams{
_, err = repo.UpsertWatchListEntry(txCtx, db.UpsertWatchListEntryParams{
ID: uuid.New().String(),
UserID: userID,
AnimeID: animeID,
@@ -319,20 +320,22 @@ func (s *playbackService) CompleteAnime(ctx context.Context, userID string, anim
}
}
if err := s.repo.DeleteContinueWatchingEntry(ctx, db.DeleteContinueWatchingEntryParams{
if err := repo.DeleteContinueWatchingEntry(txCtx, db.DeleteContinueWatchingEntryParams{
UserID: userID,
AnimeID: animeID,
}); err != nil {
return err
}
if err := s.repo.SaveWatchProgress(ctx, db.SaveWatchProgressParams{
return repo.SaveWatchProgress(txCtx, db.SaveWatchProgressParams{
UserID: userID,
AnimeID: animeID,
CurrentEpisode: sql.NullInt64{Valid: false},
CurrentTimeSeconds: 0,
})
}); err != nil {
return err
}
if err := s.auditSvc.Record(ctx, domain.AuditEvent{
UserID: userID,
Action: "watch_completed",

View File

@@ -2,16 +2,37 @@ package repository
import (
"context"
"database/sql"
"mal/internal/db"
"mal/internal/domain"
)
type watchlistRepository struct {
sqlDB *sql.DB
queries *db.Queries
}
func NewWatchlistRepository(queries *db.Queries) domain.WatchlistRepository {
return &watchlistRepository{queries: queries}
func NewWatchlistRepository(sqlDB *sql.DB, queries *db.Queries) domain.WatchlistRepository {
return &watchlistRepository{sqlDB: sqlDB, queries: queries}
}
func (r *watchlistRepository) InTx(ctx context.Context, fn func(ctx context.Context, repo domain.WatchlistRepository) error) error {
if r.sqlDB == nil {
return fn(ctx, r)
}
tx, err := r.sqlDB.BeginTx(ctx, nil)
if err != nil {
return err
}
txRepo := &watchlistRepository{sqlDB: nil, queries: r.queries.WithTx(tx)}
if err := fn(ctx, txRepo); err != nil {
_ = tx.Rollback()
return err
}
return tx.Commit()
}
func (r *watchlistRepository) UpsertAnime(ctx context.Context, arg db.UpsertAnimeParams) (db.Anime, error) {

View File

@@ -20,13 +20,16 @@ func NewWatchlistService(repo domain.WatchlistRepository, jikan *jikan.Client) d
}
func (s *watchlistService) UpdateEntry(ctx context.Context, userID string, animeID int64, status string) error {
_, err := s.repo.GetAnime(ctx, animeID)
if err != nil {
anime, err := s.jikan.GetAnimeByID(ctx, int(animeID))
if err != nil {
return err
anime, fetchErr := s.jikan.GetAnimeByID(ctx, int(animeID))
if fetchErr != nil {
// still allow status updates for already-known anime rows
anime = jikan.Anime{}
}
if _, err := s.repo.UpsertAnime(ctx, db.UpsertAnimeParams{
return s.repo.InTx(ctx, func(txCtx context.Context, repo domain.WatchlistRepository) error {
_, err := repo.GetAnime(txCtx, animeID)
if err != nil && fetchErr == nil {
if _, err := repo.UpsertAnime(txCtx, db.UpsertAnimeParams{
ID: int64(anime.MalID),
TitleOriginal: anime.Title,
TitleEnglish: sql.NullString{String: anime.TitleEnglish, Valid: anime.TitleEnglish != ""},
@@ -38,13 +41,14 @@ func (s *watchlistService) UpdateEntry(ctx context.Context, userID string, anime
}
}
_, err = s.repo.UpsertWatchListEntry(ctx, db.UpsertWatchListEntryParams{
_, err = repo.UpsertWatchListEntry(txCtx, db.UpsertWatchListEntryParams{
ID: uuid.New().String(),
UserID: userID,
AnimeID: animeID,
Status: status,
})
return err
})
}
func (s *watchlistService) RemoveEntry(ctx context.Context, userID string, animeID int64) error {
@@ -99,16 +103,18 @@ func (s *watchlistService) GetContinueWatchingEntry(ctx context.Context, userID
}
func (s *watchlistService) DeleteContinueWatching(ctx context.Context, userID string, animeID int64) error {
if err := s.repo.DeleteContinueWatchingEntry(ctx, db.DeleteContinueWatchingEntryParams{
return s.repo.InTx(ctx, func(txCtx context.Context, repo domain.WatchlistRepository) error {
if err := repo.DeleteContinueWatchingEntry(txCtx, db.DeleteContinueWatchingEntryParams{
UserID: userID,
AnimeID: animeID,
}); err != nil {
return err
}
return s.repo.SaveWatchProgress(ctx, db.SaveWatchProgressParams{
return repo.SaveWatchProgress(txCtx, db.SaveWatchProgressParams{
UserID: userID,
AnimeID: animeID,
CurrentEpisode: sql.NullInt64{Valid: false},
CurrentTimeSeconds: 0,
})
})
}