refactor: dedupe repo tx
This commit is contained in:
25
internal/dbtx/tx.go
Normal file
25
internal/dbtx/tx.go
Normal file
@@ -0,0 +1,25 @@
|
||||
package dbtx
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
)
|
||||
|
||||
func Run[T any](ctx context.Context, sqlDB *sql.DB, repo T, withTx func(*sql.Tx) T, fn func(context.Context, T) error) error {
|
||||
if sqlDB == nil {
|
||||
return fn(ctx, repo)
|
||||
}
|
||||
|
||||
tx, err := sqlDB.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
txRepo := withTx(tx)
|
||||
if err := fn(ctx, txRepo); err != nil {
|
||||
_ = tx.Rollback()
|
||||
return err
|
||||
}
|
||||
|
||||
return tx.Commit()
|
||||
}
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"mal/internal/db"
|
||||
"mal/internal/dbtx"
|
||||
"mal/internal/domain"
|
||||
)
|
||||
|
||||
@@ -17,22 +18,9 @@ func NewPlaybackRepository(sqlDB *sql.DB, queries *db.Queries) domain.PlaybackRe
|
||||
}
|
||||
|
||||
func (r *playbackRepository) InTx(ctx context.Context, fn func(ctx context.Context, repo domain.PlaybackRepository) error) error {
|
||||
if r.sqlDB == nil {
|
||||
return fn(ctx, r)
|
||||
}
|
||||
|
||||
tx, err := r.sqlDB.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
txRepo := &playbackRepository{sqlDB: nil, queries: r.queries.WithTx(tx)}
|
||||
if err := fn(ctx, txRepo); err != nil {
|
||||
_ = tx.Rollback()
|
||||
return err
|
||||
}
|
||||
|
||||
return tx.Commit()
|
||||
return dbtx.Run(ctx, r.sqlDB, domain.PlaybackRepository(r), func(tx *sql.Tx) domain.PlaybackRepository {
|
||||
return &playbackRepository{sqlDB: nil, queries: r.queries.WithTx(tx)}
|
||||
}, fn)
|
||||
}
|
||||
|
||||
func (r *playbackRepository) GetWatchListEntry(ctx context.Context, params db.GetWatchListEntryParams) (db.WatchListEntry, error) {
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"mal/internal/db"
|
||||
"mal/internal/dbtx"
|
||||
"mal/internal/domain"
|
||||
)
|
||||
|
||||
@@ -17,22 +18,9 @@ func NewWatchlistRepository(sqlDB *sql.DB, queries *db.Queries) domain.Watchlist
|
||||
}
|
||||
|
||||
func (r *watchlistRepository) InTx(ctx context.Context, fn func(ctx context.Context, repo domain.WatchlistRepository) error) error {
|
||||
if r.sqlDB == nil {
|
||||
return fn(ctx, r)
|
||||
}
|
||||
|
||||
tx, err := r.sqlDB.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
txRepo := &watchlistRepository{sqlDB: nil, queries: r.queries.WithTx(tx)}
|
||||
if err := fn(ctx, txRepo); err != nil {
|
||||
_ = tx.Rollback()
|
||||
return err
|
||||
}
|
||||
|
||||
return tx.Commit()
|
||||
return dbtx.Run(ctx, r.sqlDB, domain.WatchlistRepository(r), func(tx *sql.Tx) domain.WatchlistRepository {
|
||||
return &watchlistRepository{sqlDB: nil, queries: r.queries.WithTx(tx)}
|
||||
}, fn)
|
||||
}
|
||||
|
||||
func (r *watchlistRepository) UpsertAnime(ctx context.Context, arg db.UpsertAnimeParams) (db.Anime, error) {
|
||||
|
||||
Reference in New Issue
Block a user