diff --git a/pkg/middleware/ratelimit_test.go b/pkg/middleware/ratelimit_test.go new file mode 100644 index 0000000..04ea606 --- /dev/null +++ b/pkg/middleware/ratelimit_test.go @@ -0,0 +1,53 @@ +package middleware + +import ( + "net/http" + "net/http/httptest" + "testing" + "time" +) + +func testHandler() http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }) +} + +func TestLimiter(t *testing.T) { + cfg := Config{MaxAttempts: 2, Window: 100 * time.Millisecond} + l := NewLimiter(cfg) + handler := l.Middleware(testHandler()) + + // First attempt + req := httptest.NewRequest("GET", "/", nil) + req.RemoteAddr = "1.2.3.4:1234" + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + if rr.Code != http.StatusOK { + t.Errorf("expected 200, got %d", rr.Code) + } + + // Second attempt + rr = httptest.NewRecorder() + handler.ServeHTTP(rr, req) + if rr.Code != http.StatusOK { + t.Errorf("expected 200, got %d", rr.Code) + } + + // Third attempt (should fail) + rr = httptest.NewRecorder() + handler.ServeHTTP(rr, req) + if rr.Code != http.StatusTooManyRequests { + t.Errorf("expected 429, got %d", rr.Code) + } + + // Wait for window to expire + time.Sleep(150 * time.Millisecond) + + // Fourth attempt (should pass again) + rr = httptest.NewRecorder() + handler.ServeHTTP(rr, req) + if rr.Code != http.StatusOK { + t.Errorf("expected 200, got %d", rr.Code) + } +}