test: add unit tests for rate limiter
This commit is contained in:
53
pkg/middleware/ratelimit_test.go
Normal file
53
pkg/middleware/ratelimit_test.go
Normal file
@@ -0,0 +1,53 @@
|
|||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func testHandler() http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLimiter(t *testing.T) {
|
||||||
|
cfg := Config{MaxAttempts: 2, Window: 100 * time.Millisecond}
|
||||||
|
l := NewLimiter(cfg)
|
||||||
|
handler := l.Middleware(testHandler())
|
||||||
|
|
||||||
|
// First attempt
|
||||||
|
req := httptest.NewRequest("GET", "/", nil)
|
||||||
|
req.RemoteAddr = "1.2.3.4:1234"
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
handler.ServeHTTP(rr, req)
|
||||||
|
if rr.Code != http.StatusOK {
|
||||||
|
t.Errorf("expected 200, got %d", rr.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Second attempt
|
||||||
|
rr = httptest.NewRecorder()
|
||||||
|
handler.ServeHTTP(rr, req)
|
||||||
|
if rr.Code != http.StatusOK {
|
||||||
|
t.Errorf("expected 200, got %d", rr.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Third attempt (should fail)
|
||||||
|
rr = httptest.NewRecorder()
|
||||||
|
handler.ServeHTTP(rr, req)
|
||||||
|
if rr.Code != http.StatusTooManyRequests {
|
||||||
|
t.Errorf("expected 429, got %d", rr.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for window to expire
|
||||||
|
time.Sleep(150 * time.Millisecond)
|
||||||
|
|
||||||
|
// Fourth attempt (should pass again)
|
||||||
|
rr = httptest.NewRecorder()
|
||||||
|
handler.ServeHTTP(rr, req)
|
||||||
|
if rr.Code != http.StatusOK {
|
||||||
|
t.Errorf("expected 200, got %d", rr.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user