From ede17ce8aaccee07de7c2d6de12d2c8abd56cf44 Mon Sep 17 00:00:00 2001 From: mkelvers Date: Fri, 5 Jun 2026 16:38:27 +0200 Subject: [PATCH] test: verify diversity reranker spreads repeated genres --- internal/anime/recommendations_test.go | 58 ++++++++++++++++++++++++++ 1 file changed, 58 insertions(+) diff --git a/internal/anime/recommendations_test.go b/internal/anime/recommendations_test.go index 7c9c5ea..4b327e5 100644 --- a/internal/anime/recommendations_test.go +++ b/internal/anime/recommendations_test.go @@ -4,6 +4,7 @@ import ( "database/sql" "mal/integrations/jikan" "mal/internal/db" + "mal/internal/domain" "testing" "time" ) @@ -147,6 +148,63 @@ func TestBuildProfileSearchQueriesIncludesTasteSignals(t *testing.T) { } } +func TestRerankRecommendationCandidatesSpreadsRepeatedGenres(t *testing.T) { + const sportsGenreID = 30 + + candidates := []recommendationCandidate{ + {anime: testRecommendationAnime(1, sportsGenreID), score: 10}, + {anime: testRecommendationAnime(2, sportsGenreID), score: 9.9}, + {anime: testRecommendationAnime(3, sportsGenreID), score: 9.8}, + {anime: testRecommendationAnime(4, sportsGenreID), score: 9.7}, + {anime: testRecommendationAnime(5, sportsGenreID), score: 9.6}, + {anime: testRecommendationAnime(6, 1), score: 9.5}, + {anime: testRecommendationAnime(7, 2), score: 9.4}, + {anime: testRecommendationAnime(8, 3), score: 9.3}, + } + + reranked := rerankRecommendationCandidates(candidates, 8) + if len(reranked) < 5 { + t.Fatalf("expected enough reranked candidates, got %d", len(reranked)) + } + + for i := 0; i <= len(reranked)-5; i++ { + if allHaveGenre(reranked[i:i+5], sportsGenreID) { + t.Fatalf("expected reranker to avoid five sports anime in a row, got %+v", animeIDs(reranked)) + } + } +} + +func testRecommendationAnime(id int, genreID int) jikan.Anime { + return jikan.Anime{ + MalID: id, + Genres: []jikan.NamedEntity{{MalID: genreID, Name: "Genre"}}, + } +} + +func allHaveGenre(animes []domain.Anime, genreID int) bool { + for _, anime := range animes { + hasGenre := false + for _, genre := range anime.Genres { + if genre.MalID == genreID { + hasGenre = true + break + } + } + if !hasGenre { + return false + } + } + return true +} + +func animeIDs(animes []domain.Anime) []int { + ids := make([]int, 0, len(animes)) + for _, anime := range animes { + ids = append(ids, anime.MalID) + } + return ids +} + func hasGenreSearchQuery(queries []profileSearchQuery, genreID int) bool { for _, query := range queries { for _, id := range query.genreIDs {