diff --git a/internal/anime/recommendations.go b/internal/anime/recommendations.go
index 033bd8f..4476e92 100644
--- a/internal/anime/recommendations.go
+++ b/internal/anime/recommendations.go
@@ -27,6 +27,11 @@ const (
 	forYouThemeMatchWeight       = 1.0
 	forYouStudioMatchWeight      = 0.7
 	forYouDemographicMatchWeight = 0.9
+	forYouRecentDiversityWindow  = 3
+	forYouGenreDiversityPenalty  = 1.7
+	forYouThemeDiversityPenalty  = 1.2
+	forYouDemoDiversityPenalty   = 1.0
+	forYouStudioDiversityPenalty = 0.7
 )
 
 type recommendationSeed struct {
@@ -321,53 +326,203 @@ func weightedEntityMatch(weights map[int]float64, entities []jikan.NamedEntity)
 
+// rerankRecommendationCandidates greedily picks up to limit candidates. Each
+// round takes the remaining candidate with the best diversity-adjusted score,
+// skips it if its MalID was already selected, and otherwise records its
+// features in the running tally and in the capped window of recent picks.
 func rerankRecommendationCandidates(candidates []recommendationCandidate, limit int) []domain.Anime {
 	selected := make([]domain.Anime, 0, min(limit, len(candidates)))
-	seenGenres := make(map[int]int)
+	remaining := slices.Clone(candidates)
+	seenFeatures := newDiversityFeatureCounts()
+	recentFeatures := make([]diversityFeatureSet, 0, forYouRecentDiversityWindow)
 
-	for _, candidate := range candidates {
-		if len(selected) >= limit {
-			break
-		}
+	for len(selected) < limit && len(remaining) > 0 {
+		bestIndex := bestDiverseCandidateIndex(remaining, seenFeatures, recentFeatures)
+		candidate := remaining[bestIndex]
+		remaining = slices.Delete(remaining, bestIndex, bestIndex+1)
 
-		if isGenreOverrepresented(candidate.anime, seenGenres) {
-			continue
-		}
-
-		selected = append(selected, domain.Anime{Anime: candidate.anime})
-		for _, genre := range candidate.anime.Genres {
-			seenGenres[genre.MalID]++
-		}
-	}
-
-	if len(selected) >= limit {
-		return selected
-	}
-
-	for _, candidate := range candidates {
-		if len(selected) >= limit {
-			break
-		}
 		if slices.ContainsFunc(selected, func(anime domain.Anime) bool {
 			return anime.MalID == candidate.anime.MalID
 		}) {
 			continue
 		}
+
 		selected = append(selected, domain.Anime{Anime: candidate.anime})
+		features := diversityFeatures(candidate.anime)
+		seenFeatures.add(features)
+		recentFeatures = append(recentFeatures, features)
+		if len(recentFeatures) > forYouRecentDiversityWindow {
+			recentFeatures = recentFeatures[1:]
+		}
 	}
 
 	return selected
 }
 
-func isGenreOverrepresented(anime jikan.Anime, seenGenres map[int]int) bool {
-	if len(anime.Genres) == 0 {
-		return false
-	}
+// diversityFeatureSet holds the genre, theme, demographic, and studio IDs of
+// one anime, each as a set.
+type diversityFeatureSet struct {
+	genres       map[int]struct{}
+	themes       map[int]struct{}
+	demographics map[int]struct{}
+	studios      map[int]struct{}
+}
 
-	matchedGenres := 0
-	for _, genre := range anime.Genres {
-		if seenGenres[genre.MalID] >= 3 {
-			matchedGenres++
+// diversityFeatureCounts tallies how many times each feature ID has been added.
+type diversityFeatureCounts struct {
+	genres       map[int]int
+	themes       map[int]int
+	demographics map[int]int
+	studios      map[int]int
+}
+
+// newDiversityFeatureCounts returns counts with every map allocated and empty.
+func newDiversityFeatureCounts() diversityFeatureCounts {
+	return diversityFeatureCounts{
+		genres:       make(map[int]int),
+		themes:       make(map[int]int),
+		demographics: make(map[int]int),
+		studios:      make(map[int]int),
+	}
+}
+
+// add increments the tally of every ID present in features. The receiver is a
+// value, but its maps are shared with the caller, so the update is visible.
+func (counts diversityFeatureCounts) add(features diversityFeatureSet) {
+	addDiversityCounts(counts.genres, features.genres)
+	addDiversityCounts(counts.themes, features.themes)
+	addDiversityCounts(counts.demographics, features.demographics)
+	addDiversityCounts(counts.studios, features.studios)
+}
+
+// addDiversityCounts increments target once for each ID in features.
+func addDiversityCounts(target map[int]int, features map[int]struct{}) {
+	for id := range features {
+		target[id]++
+	}
+}
+
+// bestDiverseCandidateIndex returns the index of the candidate whose score
+// minus its diversity penalty is highest. When adjusted scores tie, the
+// candidate with the higher raw score wins; otherwise the earliest one is kept.
+// It returns 0 when candidates is empty, so callers must check the length.
+func bestDiverseCandidateIndex(
+	candidates []recommendationCandidate,
+	seen diversityFeatureCounts,
+	recent []diversityFeatureSet,
+) int {
+	bestIndex := 0
+	bestScore := math.Inf(-1)
+
+	for i, candidate := range candidates {
+		score := candidate.score - diversityPenalty(diversityFeatures(candidate.anime), seen, recent)
+		// Equal adjusted scores are settled by the higher raw score; if that
+		// ties too, the earlier candidate stays selected.
+		winsTie := score == bestScore && candidate.score > candidates[bestIndex].score
+		if score > bestScore || winsTie {
+			bestScore = score
+			bestIndex = i
 		}
 	}
 
-	return matchedGenres == len(anime.Genres)
+	return bestIndex
+}
+
+// diversityFeatures collects the anime's genre, theme, demographic, and studio
+// IDs into sets.
+func diversityFeatures(anime jikan.Anime) diversityFeatureSet {
+	return diversityFeatureSet{
+		genres:       entityIDSet(anime.Genres),
+		themes:       entityIDSet(anime.Themes),
+		demographics: entityIDSet(anime.Demographics),
+		studios:      entityIDSet(anime.Studios),
+	}
+}
+
+// entityIDSet returns the distinct MalIDs of entities, skipping IDs <= 0.
+func entityIDSet(entities []jikan.NamedEntity) map[int]struct{} {
+	ids := make(map[int]struct{}, len(entities))
+	for _, entity := range entities {
+		if entity.MalID <= 0 {
+			continue
+		}
+		ids[entity.MalID] = struct{}{}
+	}
+	return ids
+}
+
+// diversityPenalty sums the repeat penalties of features across genres, themes,
+// demographics, and studios, each scaled by its own weight constant.
+func diversityPenalty(
+	features diversityFeatureSet,
+	seen diversityFeatureCounts,
+	recent []diversityFeatureSet,
+) float64 {
+	penalty := 0.0
+	penalty += repeatedFeaturePenalty(features.genres, seen.genres, recentGenreCounts(recent), forYouGenreDiversityPenalty)
+	penalty += repeatedFeaturePenalty(features.themes, seen.themes, recentThemeCounts(recent), forYouThemeDiversityPenalty)
+	penalty += repeatedFeaturePenalty(
+		features.demographics,
+		seen.demographics,
+		recentDemographicCounts(recent),
+		forYouDemoDiversityPenalty,
+	)
+	penalty += repeatedFeaturePenalty(features.studios, seen.studios, recentStudioCounts(recent), forYouStudioDiversityPenalty)
+
+	return penalty
+}
+
+// repeatedFeaturePenalty charges, for each ID in features, weight times its
+// count in recent plus 0.35 * weight times its count in seen.
+func repeatedFeaturePenalty(
+	features map[int]struct{},
+	seen map[int]int,
+	recent map[int]int,
+	weight float64,
+) float64 {
+	total := 0.0
+	for id := range features {
+		total += float64(seen[id]) * weight * 0.35
+		total += float64(recent[id]) * weight
+	}
+	return total
+}
+
+// recentGenreCounts tallies genre IDs across the recent feature sets.
+func recentGenreCounts(recent []diversityFeatureSet) map[int]int {
+	return recentFeatureCounts(recent, func(features diversityFeatureSet) map[int]struct{} {
+		return features.genres
+	})
+}
+
+// recentThemeCounts tallies theme IDs across the recent feature sets.
+func recentThemeCounts(recent []diversityFeatureSet) map[int]int {
+	return recentFeatureCounts(recent, func(features diversityFeatureSet) map[int]struct{} {
+		return features.themes
+	})
+}
+
+// recentDemographicCounts tallies demographic IDs across the recent feature sets.
+func recentDemographicCounts(recent []diversityFeatureSet) map[int]int {
+	return recentFeatureCounts(recent, func(features diversityFeatureSet) map[int]struct{} {
+		return features.demographics
+	})
+}
+
+// recentStudioCounts tallies studio IDs across the recent feature sets.
+func recentStudioCounts(recent []diversityFeatureSet) map[int]int {
+	return recentFeatureCounts(recent, func(features diversityFeatureSet) map[int]struct{} {
+		return features.studios
+	})
+}
+
+// recentFeatureCounts tallies the IDs that selectFeatures extracts from each
+// entry of recent.
+func recentFeatureCounts(
+	recent []diversityFeatureSet,
+	selectFeatures func(diversityFeatureSet) map[int]struct{},
+) map[int]int {
+	counts := make(map[int]int)
+	for _, features := range recent {
+		addDiversityCounts(counts, selectFeatures(features))
+	}
+	return counts
 }