diff --git a/internal/playback/skip_segments.go b/internal/playback/skip_segments.go index f737a32..4c2a529 100644 --- a/internal/playback/skip_segments.go +++ b/internal/playback/skip_segments.go @@ -66,9 +66,18 @@ func (s *playbackService) fetchSkipSegments(ctx context.Context, userID string, segments, err := s.fetchAniSkipSegments(ctx, malID, episode) if err != nil { + overrides := s.loadSkipSegmentOverrides(ctx, userID, malID, episode) + if len(overrides) > 0 { + return mergeSkipSegments(nil, overrides), nil + } return nil, fmt.Errorf("aniskip: %w", err) } - return s.applySkipSegmentOverrides(ctx, segments, userID, malID, episode), nil + + overrides := s.loadSkipSegmentOverrides(ctx, userID, malID, episode) + if len(overrides) == 0 { + return segments, nil + } + return mergeSkipSegments(segments, overrides), nil } func (s *playbackService) fetchAniSkipSegments(ctx context.Context, malID int, episode string) ([]domain.SkipSegment, error) { @@ -139,28 +148,23 @@ func normalizeSkipSegmentLabel(skipType string) string { } } -func (s *playbackService) applySkipSegmentOverrides(ctx context.Context, segments []domain.SkipSegment, userID string, malID int, episode string) []domain.SkipSegment { +func (s *playbackService) loadSkipSegmentOverrides(ctx context.Context, userID string, malID int, episode string) map[string]domain.SkipSegment { epNum, err := strconv.ParseInt(strings.TrimSpace(episode), 10, 64) if userID == "" || err != nil || epNum <= 0 { - return segments + return nil } ok, err := s.repo.HasSkipSegmentOverrideTable(ctx) if err != nil || !ok { - return segments + return nil } overrides, err := s.repo.ListSkipSegmentOverrides(ctx, userID, int64(malID), epNum) if err != nil { - return segments + return nil } - overrideByType := buildOverrideSegments(overrides) - if len(overrideByType) == 0 { - return segments - } - - return mergeSkipSegments(segments, overrideByType) + return buildOverrideSegments(overrides) } func buildOverrideSegments(overrides []db.SkipSegmentOverrideRow) map[string]domain.SkipSegment { diff --git a/internal/playback/skip_segments_test.go b/internal/playback/skip_segments_test.go new file mode 100644 index 0000000..3a76e20 --- /dev/null +++ b/internal/playback/skip_segments_test.go @@ -0,0 +1,150 @@ +package playback + +import ( + "context" + "io" + "net/http" + "strings" + "testing" + + "mal/internal/db" + "mal/internal/domain" +) + +type wantSkipSegment struct { + segmentType string + start float64 + end float64 + source string +} + +func TestFetchSkipSegmentsFallsBackToOverridesWhenAniSkipFails(t *testing.T) { + t.Parallel() + + svc := &playbackService{ + repo: &skipSegmentRepo{ + hasTable: true, + overrides: []db.SkipSegmentOverrideRow{{ + UserID: "user-1", + AnimeID: 2167, + Episode: 13, + SkipType: "op", + StartTime: 90, + EndTime: 180, + }}, + }, + httpClient: &http.Client{Transport: roundTripFunc(func(*http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: http.StatusInternalServerError, + Body: io.NopCloser(strings.NewReader("")), + Header: make(http.Header), + }, nil + })}, + } + + got, err := svc.fetchSkipSegments(context.Background(), "user-1", 2167, "13") + if err != nil { + t.Fatalf("fetchSkipSegments returned error with local fallback: %v", err) + } + if len(got) != 1 { + t.Fatalf("len(got) = %d, want 1", len(got)) + } + assertSkipSegment(t, got[0], wantSkipSegment{segmentType: "opening", start: 90, end: 180, source: "override"}) +} + +func TestFetchSkipSegmentsReturnsAniSkipErrorWhenNoOverrideFallbackExists(t *testing.T) { + t.Parallel() + + svc := &playbackService{ + repo: &skipSegmentRepo{hasTable: true}, + httpClient: &http.Client{Transport: roundTripFunc(func(*http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: http.StatusInternalServerError, + Body: io.NopCloser(strings.NewReader("")), + Header: make(http.Header), + }, nil + })}, + } + + got, err := svc.fetchSkipSegments(context.Background(), "user-1", 2167, "13") + if err == nil { + t.Fatal("fetchSkipSegments returned nil error without local fallback") + } + if got != nil { + t.Fatalf("got segments = %+v, want nil", got) + } + if !strings.Contains(err.Error(), "aniskip: unexpected status: 500") { + t.Fatalf("err = %v, want ani-skip status context", err) + } +} + +func TestFetchSkipSegmentsMergesOverridesWhenAniSkipSucceeds(t *testing.T) { + t.Parallel() + + svc := &playbackService{ + repo: &skipSegmentRepo{ + hasTable: true, + overrides: []db.SkipSegmentOverrideRow{{ + UserID: "user-1", + AnimeID: 2167, + Episode: 13, + SkipType: "op", + StartTime: 95, + EndTime: 185, + }}, + }, + httpClient: &http.Client{Transport: roundTripFunc(func(*http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(`{ + "found": true, + "results": [ + {"skip_type": "op", "interval": {"start_time": 80, "end_time": 170}}, + {"skip_type": "ed", "interval": {"start_time": 1300, "end_time": 1390}} + ] + }`)), + Header: make(http.Header), + }, nil + })}, + } + + got, err := svc.fetchSkipSegments(context.Background(), "user-1", 2167, "13") + if err != nil { + t.Fatalf("fetchSkipSegments returned error: %v", err) + } + if len(got) != 2 { + t.Fatalf("len(got) = %d, want 2", len(got)) + } + assertSkipSegment(t, got[0], wantSkipSegment{segmentType: "opening", start: 95, end: 185, source: "override"}) + assertSkipSegment(t, got[1], wantSkipSegment{segmentType: "ending", start: 1300, end: 1390, source: "aniskip"}) +} + +func assertSkipSegment(t *testing.T, got domain.SkipSegment, want wantSkipSegment) { + t.Helper() + + if got.Type != want.segmentType || got.Start != want.start || got.End != want.end || got.Source != want.source { + t.Fatalf("got segment = %+v, want type=%q start=%v end=%v source=%q", got, want.segmentType, want.start, want.end, want.source) + } +} + +type roundTripFunc func(*http.Request) (*http.Response, error) + +func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) { + return f(req) +} + +type skipSegmentRepo struct { + domain.PlaybackRepository + hasTable bool + hasErr error + overrides []db.SkipSegmentOverrideRow + listErr error +} + +func (r *skipSegmentRepo) HasSkipSegmentOverrideTable(context.Context) (bool, error) { + return r.hasTable, r.hasErr +} + +func (r *skipSegmentRepo) ListSkipSegmentOverrides(context.Context, string, int64, int64) ([]db.SkipSegmentOverrideRow, error) { + return r.overrides, r.listErr +}