diff --git a/api/watchlist/service.go b/api/watchlist/service.go index a9f29ce..c78704e 100644 --- a/api/watchlist/service.go +++ b/api/watchlist/service.go @@ -7,6 +7,7 @@ import ( "errors" "fmt" "io" + "log" "strconv" "strings" @@ -191,6 +192,12 @@ func (s *Service) DeleteContinueWatching(ctx context.Context, userID string, ani } func (s *Service) ImportWatchlist(ctx context.Context, userID string, r io.Reader) error { + txQueries, tx, err := db.BeginTx(ctx, s.sqlDB) + if err != nil { + return fmt.Errorf("failed to begin transaction: %w", err) + } + defer tx.Rollback() + reader := csv.NewReader(r) // Read header if _, err := reader.Read(); err != nil { @@ -204,12 +211,13 @@ func (s *Service) ImportWatchlist(ctx context.Context, userID string, r io.Reade for i, record := range records { if len(record) < 4 { - continue // Skip malformed rows + log.Printf("skipping row %d: insufficient columns", i+2) // i+2 because i is 0-indexed record after header + continue } animeID, err := strconv.ParseInt(record[0], 10, 64) if err != nil { - return fmt.Errorf("row %d: invalid anime id: %w", i+1, err) + return fmt.Errorf("row %d: invalid anime id: %w", i+2, err) } status := record[1] @@ -221,10 +229,10 @@ func (s *Service) ImportWatchlist(ctx context.Context, userID string, r io.Reade currentTimeSeconds, _ := strconv.ParseFloat(record[3], 64) if err := s.ensureAnimeExists(ctx, animeID); err != nil { - return fmt.Errorf("row %d: failed to ensure anime: %w", i+1, err) + return fmt.Errorf("row %d: failed to ensure anime: %w", i+2, err) } - _, err = s.db.UpsertWatchListEntry(ctx, db.UpsertWatchListEntryParams{ + _, err = txQueries.UpsertWatchListEntry(ctx, db.UpsertWatchListEntryParams{ ID: uuid.New().String(), UserID: userID, AnimeID: animeID, @@ -233,9 +241,9 @@ func (s *Service) ImportWatchlist(ctx context.Context, userID string, r io.Reade CurrentTimeSeconds: currentTimeSeconds, }) if err != nil { - return fmt.Errorf("row %d: failed to upsert entry: %w", i+1, err) + return fmt.Errorf("row %d: failed to upsert entry: %w", i+2, err) } } - return nil + return tx.Commit() } diff --git a/api/watchlist/service_test.go b/api/watchlist/service_test.go index 093a52b..d6cb094 100644 --- a/api/watchlist/service_test.go +++ b/api/watchlist/service_test.go @@ -2,10 +2,14 @@ package watchlist import ( "context" + "database/sql" + "os" "strings" "testing" "mal/internal/db" + + _ "github.com/mattn/go-sqlite3" ) type fakeQuerier struct { @@ -73,29 +77,84 @@ func TestAddEntry_RejectsInvalidStatus(t *testing.T) { } func TestImportWatchlist(t *testing.T) { - t.Parallel() + dbFile := "test_watchlist.db" + defer os.Remove(dbFile) - q := &fakeQuerier{} - svc := NewService(q, nil, nil) + sqlDB, err := sql.Open("sqlite3", dbFile) + if err != nil { + t.Fatal(err) + } + defer sqlDB.Close() + + // Minimal schema for testing + _, err = sqlDB.Exec(` + CREATE TABLE anime ( + id INTEGER PRIMARY KEY, + title_original TEXT NOT NULL, + image_url TEXT NOT NULL, + created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + title_english TEXT, + title_japanese TEXT, + airing BOOLEAN, + status TEXT, + relations_synced_at DATETIME, + duration_seconds REAL + ); + CREATE TABLE watch_list_entry ( + id TEXT PRIMARY KEY, + user_id TEXT NOT NULL, + anime_id INTEGER NOT NULL REFERENCES anime(id), + status TEXT NOT NULL, + created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + current_episode INTEGER DEFAULT 0, + last_episode_at DATETIME, + current_time_seconds REAL NOT NULL DEFAULT 0, + UNIQUE(user_id, anime_id) + ); + `) + if err != nil { + t.Fatal(err) + } + + queries := db.New(sqlDB) + svc := NewService(queries, sqlDB, nil) + + // Pre-insert anime so ensureAnimeExists succeeds + _, err = sqlDB.Exec(`INSERT INTO anime (id, title_original, image_url) VALUES (1, 'Test 1', '');`) + if err != nil { + t.Fatal(err) + } + _, err = sqlDB.Exec(`INSERT INTO anime (id, title_original, image_url) VALUES (2, 'Test 2', '');`) + if err != nil { + t.Fatal(err) + } csvData := `anime_id,status,current_episode,current_time_seconds 1,watching,5,120.5 2,invalid,10,0 ` - err := svc.ImportWatchlist(context.Background(), "user-1", strings.NewReader(csvData)) + err = svc.ImportWatchlist(context.Background(), "user-1", strings.NewReader(csvData)) if err != nil { t.Fatalf("ImportWatchlist failed: %v", err) } - if !q.upsertEntryCalled { - t.Fatal("expected entries to be upserted") + // Verify entries + var count int + err = sqlDB.QueryRow("SELECT COUNT(*) FROM watch_list_entry WHERE user_id = 'user-1'").Scan(&count) + if err != nil { + t.Fatal(err) + } + if count != 2 { + t.Errorf("expected 2 entries, got %d", count) } - // Verify the second record with invalid status was defaulted - // Note: We need a way to track all calls if we want to check the second record specifically, - // but the current fake only tracks the last call. - // For now, let's just check the last call which was record 2. - if q.upsertEntryParams.Status != "plan_to_watch" { - t.Errorf("expected status to be defaulted to plan_to_watch, got %s", q.upsertEntryParams.Status) + var status string + err = sqlDB.QueryRow("SELECT status FROM watch_list_entry WHERE anime_id = 2").Scan(&status) + if err != nil { + t.Fatal(err) + } + if status != "plan_to_watch" { + t.Errorf("expected status to be defaulted to plan_to_watch, got %s", status) } } diff --git a/templates/watchlist.gohtml b/templates/watchlist.gohtml index b09415e..12147db 100644 --- a/templates/watchlist.gohtml +++ b/templates/watchlist.gohtml @@ -32,6 +32,19 @@ + +