diff --git a/integrations/jikan/rate/limiter.go b/integrations/jikan/rate/limiter.go index 148ba69..5a0a4e1 100644 --- a/integrations/jikan/rate/limiter.go +++ b/integrations/jikan/rate/limiter.go @@ -9,7 +9,7 @@ import ( type Limiter struct { mu sync.Mutex - lastReqTime time.Time + nextReqTime time.Time interval time.Duration } @@ -19,24 +19,32 @@ func NewLimiter(interval time.Duration) *Limiter { // Wait enforces minimum spacing between upstream Jikan requests. func (l *Limiter) Wait(ctx context.Context) error { - l.mu.Lock() - defer l.mu.Unlock() - - now := time.Now() - nextAllowed := l.lastReqTime.Add(l.interval) - if now.Before(nextAllowed) { - timer := time.NewTimer(nextAllowed.Sub(now)) - defer timer.Stop() - - select { - case <-timer.C: - case <-ctx.Done(): - return fmt.Errorf("request canceled while waiting for rate limit: %w", ctx.Err()) - } - l.lastReqTime = time.Now() + waitUntil := l.reserve(time.Now()) + if waitUntil.IsZero() { return nil } - l.lastReqTime = now - return nil + timer := time.NewTimer(time.Until(waitUntil)) + defer timer.Stop() + + select { + case <-timer.C: + return nil + case <-ctx.Done(): + return fmt.Errorf("request canceled while waiting for rate limit: %w", ctx.Err()) + } +} + +func (l *Limiter) reserve(now time.Time) time.Time { + l.mu.Lock() + defer l.mu.Unlock() + + if l.nextReqTime.IsZero() || now.After(l.nextReqTime) { + l.nextReqTime = now.Add(l.interval) + return time.Time{} + } + + waitUntil := l.nextReqTime + l.nextReqTime = l.nextReqTime.Add(l.interval) + return waitUntil } diff --git a/integrations/jikan/rate/limiter_test.go b/integrations/jikan/rate/limiter_test.go new file mode 100644 index 0000000..56d516a --- /dev/null +++ b/integrations/jikan/rate/limiter_test.go @@ -0,0 +1,40 @@ +package rate + +import ( + "context" + "testing" + "time" +) + +func TestLimiterDoesNotHoldLockWhileWaiting(t *testing.T) { + limiter := NewLimiter(250 * time.Millisecond) + if err := limiter.Wait(context.Background()); err != nil { + t.Fatalf("initial wait: %v", err) + } + + firstCtx, cancelFirst := context.WithCancel(context.Background()) + defer cancelFirst() + + firstDone := make(chan error, 1) + go func() { + firstDone <- limiter.Wait(firstCtx) + }() + + time.Sleep(20 * time.Millisecond) + + secondCtx, cancelSecond := context.WithTimeout(context.Background(), 30*time.Millisecond) + defer cancelSecond() + + startedAt := time.Now() + err := limiter.Wait(secondCtx) + elapsed := time.Since(startedAt) + if err == nil { + t.Fatal("second wait succeeded, want context timeout") + } + if elapsed > 150*time.Millisecond { + t.Fatalf("second wait took %s, want it to observe context timeout without waiting behind first caller", elapsed) + } + + cancelFirst() + <-firstDone +}