diff --git a/internal/playback/repository.go b/internal/playback/repository.go index a13e588..b5575aa 100644 --- a/internal/playback/repository.go +++ b/internal/playback/repository.go @@ -3,8 +3,8 @@ package playback import ( "context" "database/sql" + "errors" "mal/internal/db" - "mal/internal/dbtx" "mal/internal/domain" ) @@ -18,9 +18,24 @@ func NewPlaybackRepository(sqlDB *sql.DB, queries *db.Queries) domain.PlaybackRe } func (r *playbackRepository) InTx(ctx context.Context, fn func(ctx context.Context, repo domain.PlaybackRepository) error) error { - return dbtx.Run(ctx, r.sqlDB, domain.PlaybackRepository(r), func(tx *sql.Tx) domain.PlaybackRepository { - return &playbackRepository{sqlDB: nil, queries: r.queries.WithTx(tx)} - }, fn) + if r.sqlDB == nil { + return fn(ctx, r) + } + + tx, err := r.sqlDB.BeginTx(ctx, nil) + if err != nil { + return err + } + + txRepo := &playbackRepository{sqlDB: nil, queries: r.queries.WithTx(tx)} + if err := fn(ctx, txRepo); err != nil { + if rollbackErr := tx.Rollback(); rollbackErr != nil { + return errors.Join(err, rollbackErr) + } + return err + } + + return tx.Commit() } func (r *playbackRepository) UpsertAnime(ctx context.Context, params db.UpsertAnimeParams) (db.Anime, error) { diff --git a/internal/dbtx/tx.go b/internal/tx.go similarity index 96% rename from internal/dbtx/tx.go rename to internal/tx.go index c4cd332..a43b4ea 100644 --- a/internal/dbtx/tx.go +++ b/internal/tx.go @@ -1,4 +1,4 @@ -package dbtx +package internal import ( "context" diff --git a/internal/watchlist/repository.go b/internal/watchlist/repository.go index 62c5005..a5443fc 100644 --- a/internal/watchlist/repository.go +++ b/internal/watchlist/repository.go @@ -3,8 +3,8 @@ package watchlist import ( "context" "database/sql" + "errors" "mal/internal/db" - "mal/internal/dbtx" "mal/internal/domain" ) @@ -18,9 +18,24 @@ func NewWatchlistRepository(sqlDB *sql.DB, queries *db.Queries) domain.Watchlist } func (r *watchlistRepository) InTx(ctx context.Context, fn func(ctx context.Context, repo domain.WatchlistRepository) error) error { - return dbtx.Run(ctx, r.sqlDB, domain.WatchlistRepository(r), func(tx *sql.Tx) domain.WatchlistRepository { - return &watchlistRepository{sqlDB: nil, queries: r.queries.WithTx(tx)} - }, fn) + if r.sqlDB == nil { + return fn(ctx, r) + } + + tx, err := r.sqlDB.BeginTx(ctx, nil) + if err != nil { + return err + } + + txRepo := &watchlistRepository{sqlDB: nil, queries: r.queries.WithTx(tx)} + if err := fn(ctx, txRepo); err != nil { + if rollbackErr := tx.Rollback(); rollbackErr != nil { + return errors.Join(err, rollbackErr) + } + return err + } + + return tx.Commit() } func (r *watchlistRepository) UpsertAnime(ctx context.Context, arg db.UpsertAnimeParams) (db.Anime, error) {