From b9ad50b67adbbf3d11fcc901874dd042ebc2e13f Mon Sep 17 00:00:00 2001 From: mkelvers Date: Mon, 1 Jun 2026 22:22:14 +0200 Subject: [PATCH] refactor: dedupe repo tx --- internal/dbtx/tx.go | 25 +++++++++++++++++++++++++ internal/playback/repository.go | 20 ++++---------------- internal/watchlist/repository.go | 20 ++++---------------- 3 files changed, 33 insertions(+), 32 deletions(-) create mode 100644 internal/dbtx/tx.go diff --git a/internal/dbtx/tx.go b/internal/dbtx/tx.go new file mode 100644 index 0000000..2654ab5 --- /dev/null +++ b/internal/dbtx/tx.go @@ -0,0 +1,25 @@ +package dbtx + +import ( + "context" + "database/sql" +) + +func Run[T any](ctx context.Context, sqlDB *sql.DB, repo T, withTx func(*sql.Tx) T, fn func(context.Context, T) error) error { + if sqlDB == nil { + return fn(ctx, repo) + } + + tx, err := sqlDB.BeginTx(ctx, nil) + if err != nil { + return err + } + + txRepo := withTx(tx) + if err := fn(ctx, txRepo); err != nil { + _ = tx.Rollback() + return err + } + + return tx.Commit() +} diff --git a/internal/playback/repository.go b/internal/playback/repository.go index 947e0c4..02ecf69 100644 --- a/internal/playback/repository.go +++ b/internal/playback/repository.go @@ -4,6 +4,7 @@ import ( "context" "database/sql" "mal/internal/db" + "mal/internal/dbtx" "mal/internal/domain" ) @@ -17,22 +18,9 @@ func NewPlaybackRepository(sqlDB *sql.DB, queries *db.Queries) domain.PlaybackRe } func (r *playbackRepository) InTx(ctx context.Context, fn func(ctx context.Context, repo domain.PlaybackRepository) error) error { - if r.sqlDB == nil { - return fn(ctx, r) - } - - tx, err := r.sqlDB.BeginTx(ctx, nil) - if err != nil { - return err - } - - txRepo := &playbackRepository{sqlDB: nil, queries: r.queries.WithTx(tx)} - if err := fn(ctx, txRepo); err != nil { - _ = tx.Rollback() - return err - } - - return tx.Commit() + return dbtx.Run(ctx, r.sqlDB, domain.PlaybackRepository(r), func(tx *sql.Tx) domain.PlaybackRepository { + return &playbackRepository{sqlDB: nil, queries: r.queries.WithTx(tx)} + }, fn) } func (r *playbackRepository) GetWatchListEntry(ctx context.Context, params db.GetWatchListEntryParams) (db.WatchListEntry, error) { diff --git a/internal/watchlist/repository.go b/internal/watchlist/repository.go index d78d6a8..62c5005 100644 --- a/internal/watchlist/repository.go +++ b/internal/watchlist/repository.go @@ -4,6 +4,7 @@ import ( "context" "database/sql" "mal/internal/db" + "mal/internal/dbtx" "mal/internal/domain" ) @@ -17,22 +18,9 @@ func NewWatchlistRepository(sqlDB *sql.DB, queries *db.Queries) domain.Watchlist } func (r *watchlistRepository) InTx(ctx context.Context, fn func(ctx context.Context, repo domain.WatchlistRepository) error) error { - if r.sqlDB == nil { - return fn(ctx, r) - } - - tx, err := r.sqlDB.BeginTx(ctx, nil) - if err != nil { - return err - } - - txRepo := &watchlistRepository{sqlDB: nil, queries: r.queries.WithTx(tx)} - if err := fn(ctx, txRepo); err != nil { - _ = tx.Rollback() - return err - } - - return tx.Commit() + return dbtx.Run(ctx, r.sqlDB, domain.WatchlistRepository(r), func(tx *sql.Tx) domain.WatchlistRepository { + return &watchlistRepository{sqlDB: nil, queries: r.queries.WithTx(tx)} + }, fn) } func (r *watchlistRepository) UpsertAnime(ctx context.Context, arg db.UpsertAnimeParams) (db.Anime, error) {