From 91e10560a6470ce62b2168b397ef187bd678dea7 Mon Sep 17 00:00:00 2001 From: mkelvers Date: Wed, 8 Apr 2026 15:37:32 +0200 Subject: [PATCH] feat(auth): implement strict and secure user registration --- internal/features/auth/auth.go | 44 ++++++++++++-- internal/features/auth/handler.go | 45 ++++++++++++++ internal/server/routes.go | 13 +++- internal/shared/middleware/auth.go | 4 +- internal/shared/middleware/csrf.go | 48 +++++++++++++++ internal/shared/middleware/ratelimit.go | 81 +++++++++++++++++++++++++ internal/templates/auth.templ | 31 +++++++++- internal/templates/auth_templ.go | 49 ++++++++++++++- 8 files changed, 304 insertions(+), 11 deletions(-) create mode 100644 internal/shared/middleware/csrf.go create mode 100644 internal/shared/middleware/ratelimit.go diff --git a/internal/features/auth/auth.go b/internal/features/auth/auth.go index e4df188..9c19004 100644 --- a/internal/features/auth/auth.go +++ b/internal/features/auth/auth.go @@ -8,7 +8,9 @@ import ( "errors" "fmt" "net/http" + "os" "time" + "unicode" "github.com/google/uuid" "golang.org/x/crypto/bcrypt" @@ -20,6 +22,7 @@ var ( ErrInvalidCredentials = errors.New("invalid username or password") ErrUserExists = errors.New("username already exists") ErrNotAuthenticated = errors.New("not authenticated") + ErrInvalidPassword = errors.New("password does not meet security requirements") ) type Service struct { @@ -39,8 +42,37 @@ func generateSessionToken() (string, error) { return base64.URLEncoding.EncodeToString(b), nil } +func ValidatePassword(password string) error { + if len(password) < 12 { + return fmt.Errorf("password must be at least 12 characters long") + } + + var hasUpper, hasLower, hasNumber, hasSpecial bool + for _, c := range password { + switch { + case unicode.IsNumber(c): + hasNumber = true + case unicode.IsUpper(c): + hasUpper = true + case unicode.IsLower(c): + hasLower = true + case unicode.IsPunct(c) || unicode.IsSymbol(c) || !unicode.IsLetter(c) && !unicode.IsNumber(c): + hasSpecial = true + } + } + + if !hasUpper || !hasLower || !hasNumber || !hasSpecial { + return fmt.Errorf("password must contain at least one uppercase letter, one lowercase letter, one number, and one special character") + } + return nil +} + func (s *Service) RegisterUser(ctx context.Context, username, password string) (*database.User, error) { - hash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) + if err := ValidatePassword(password); err != nil { + return nil, fmt.Errorf("%w: %v", ErrInvalidPassword, err) + } + + hash, err := bcrypt.GenerateFromPassword([]byte(password), 12) // higher cost if err != nil { return nil, fmt.Errorf("failed to hash password: %w", err) } @@ -117,25 +149,27 @@ func (s *Service) ValidateSession(ctx context.Context, sessionID string) (*datab } func SetSessionCookie(w http.ResponseWriter, sessionID string, expiresAt time.Time) { + isProd := os.Getenv("ENV") == "production" http.SetCookie(w, &http.Cookie{ Name: "session_id", Value: sessionID, Expires: expiresAt, HttpOnly: true, - Secure: false, // False for local development without TLS - SameSite: http.SameSiteLaxMode, + Secure: isProd, + SameSite: http.SameSiteStrictMode, Path: "/", }) } func ClearSessionCookie(w http.ResponseWriter) { + isProd := os.Getenv("ENV") == "production" http.SetCookie(w, &http.Cookie{ Name: "session_id", Value: "", Expires: time.Unix(0, 0), HttpOnly: true, - Secure: false, // False for local development without TLS - SameSite: http.SameSiteLaxMode, + Secure: isProd, + SameSite: http.SameSiteStrictMode, Path: "/", }) } diff --git a/internal/features/auth/handler.go b/internal/features/auth/handler.go index 538ffc2..8e228a3 100644 --- a/internal/features/auth/handler.go +++ b/internal/features/auth/handler.go @@ -39,6 +39,47 @@ func (h *Handler) HandleLogin(w http.ResponseWriter, r *http.Request) { http.Redirect(w, r, "/", http.StatusFound) } +func (h *Handler) HandleRegister(w http.ResponseWriter, r *http.Request) { + if err := r.ParseForm(); err != nil { + http.Error(w, "invalid request", http.StatusBadRequest) + return + } + + username := r.FormValue("username") + password := r.FormValue("password") + + if username == "" || password == "" { + http.Error(w, "username and password are required", http.StatusBadRequest) + return + } + + _, err := h.authService.RegisterUser(r.Context(), username, password) + if err != nil { + if err == ErrInvalidPassword { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + if err == ErrUserExists { + http.Error(w, "username already taken", http.StatusConflict) + return + } + http.Error(w, "registration failed", http.StatusInternalServerError) + return + } + + // Auto-login after successful registration + session, err := h.authService.Login(r.Context(), username, password) + if err != nil { + http.Redirect(w, r, "/login", http.StatusFound) + return + } + + SetSessionCookie(w, session.ID, session.ExpiresAt) + + w.Header().Set("HX-Redirect", "/") + http.Redirect(w, r, "/", http.StatusFound) +} + func (h *Handler) HandleLogout(w http.ResponseWriter, r *http.Request) { cookie, err := r.Cookie("session_id") if err == nil { @@ -53,3 +94,7 @@ func (h *Handler) HandleLogout(w http.ResponseWriter, r *http.Request) { func (h *Handler) HandleLoginPage(w http.ResponseWriter, r *http.Request) { templates.Login().Render(r.Context(), w) } + +func (h *Handler) HandleRegisterPage(w http.ResponseWriter, r *http.Request) { + templates.Register().Render(r.Context(), w) +} diff --git a/internal/server/routes.go b/internal/server/routes.go index 15e8ebb..768ae2a 100644 --- a/internal/server/routes.go +++ b/internal/server/routes.go @@ -53,10 +53,19 @@ func NewRouter(cfg Config) http.Handler { if r.Method == http.MethodGet { authHandler.HandleLoginPage(w, r) } else { - authHandler.HandleLogin(w, r) + middleware.RateLimitAuth(middleware.VerifyOrigin(http.HandlerFunc(authHandler.HandleLogin))).ServeHTTP(w, r) } }) - mux.HandleFunc("/logout", authHandler.HandleLogout) + mux.HandleFunc("/register", func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodGet { + authHandler.HandleRegisterPage(w, r) + } else { + middleware.RateLimitAuth(middleware.VerifyOrigin(http.HandlerFunc(authHandler.HandleRegister))).ServeHTTP(w, r) + } + }) + mux.HandleFunc("/logout", func(w http.ResponseWriter, r *http.Request) { + middleware.VerifyOrigin(http.HandlerFunc(authHandler.HandleLogout)).ServeHTTP(w, r) + }) // Watchlist POST endpoint (Protected) mux.Handle("/api/watchlist/export", middleware.RequireAuth(http.HandlerFunc(watchlistHandler.HandleExportWatchlist))) diff --git a/internal/shared/middleware/auth.go b/internal/shared/middleware/auth.go index 567d25d..7224d99 100644 --- a/internal/shared/middleware/auth.go +++ b/internal/shared/middleware/auth.go @@ -61,8 +61,8 @@ func RequireAuth(next http.Handler) http.Handler { // RequireGlobalAuth ensures that a valid user is in the context for all routes except login and static func RequireGlobalAuth(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // Allow unauthenticated access to login, search, and static files - if r.URL.Path == "/login" || strings.HasPrefix(r.URL.Path, "/static/") || + // Allow unauthenticated access to login, register, search, and static files + if r.URL.Path == "/login" || r.URL.Path == "/register" || strings.HasPrefix(r.URL.Path, "/static/") || r.URL.Path == "/search" || r.URL.Path == "/api/search" || r.URL.Path == "/api/search-quick" || r.URL.Path == "/" { next.ServeHTTP(w, r) diff --git a/internal/shared/middleware/csrf.go b/internal/shared/middleware/csrf.go new file mode 100644 index 0000000..9099c96 --- /dev/null +++ b/internal/shared/middleware/csrf.go @@ -0,0 +1,48 @@ +package middleware + +import ( + "net/http" + "net/url" +) + +// VerifyOrigin prevents simple CSRF by ensuring the Origin or Referer header matches the Host header +// for state-changing endpoints (POST/PUT/DELETE). +func VerifyOrigin(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodGet || r.Method == http.MethodHead || r.Method == http.MethodOptions { + next.ServeHTTP(w, r) + return + } + + origin := r.Header.Get("Origin") + if origin == "" { + referer := r.Header.Get("Referer") + if referer == "" { + // If neither is present, and it's a POST/PUT/DELETE, reject it (strict policy) + http.Error(w, "Missing Origin or Referer header", http.StatusForbidden) + return + } + + refURL, err := url.Parse(referer) + if err != nil { + http.Error(w, "Invalid Referer header", http.StatusForbidden) + return + } + origin = refURL.Scheme + "://" + refURL.Host + } + + host := r.Host + // Optional: strip port if you only care about domain + + // If origin doesn't match host (accounting for potential schema prefixes) + expectedHTTP := "http://" + host + expectedHTTPS := "https://" + host + + if origin != expectedHTTP && origin != expectedHTTPS { + http.Error(w, "Cross-Site Request Forgery (CSRF) origin mismatch", http.StatusForbidden) + return + } + + next.ServeHTTP(w, r) + }) +} diff --git a/internal/shared/middleware/ratelimit.go b/internal/shared/middleware/ratelimit.go new file mode 100644 index 0000000..2d61366 --- /dev/null +++ b/internal/shared/middleware/ratelimit.go @@ -0,0 +1,81 @@ +package middleware + +import ( + "net/http" + "strings" + "sync" + "time" +) + +type visitor struct { + attempts int + lastSeen time.Time +} + +var ( + visitors = make(map[string]*visitor) + mu sync.Mutex +) + +func init() { + go cleanupVisitors() +} + +func cleanupVisitors() { + for { + time.Sleep(time.Minute) + mu.Lock() + for ip, v := range visitors { + if time.Since(v.lastSeen) > 3*time.Minute { + delete(visitors, ip) + } + } + mu.Unlock() + } +} + +// getIP attempts to get the real IP, falling back to RemoteAddr +func getIP(r *http.Request) string { + if xff := r.Header.Get("X-Forwarded-For"); xff != "" { + ips := strings.Split(xff, ",") + return strings.TrimSpace(ips[0]) + } + if realIP := r.Header.Get("X-Real-IP"); realIP != "" { + return realIP + } + ip := r.RemoteAddr + if colonIdx := strings.LastIndex(ip, ":"); colonIdx != -1 { + ip = ip[:colonIdx] + } + return ip +} + +// RateLimitAuth limits login/register attempts to prevent brute force +func RateLimitAuth(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ip := getIP(r) + + mu.Lock() + v, exists := visitors[ip] + if !exists { + visitors[ip] = &visitor{1, time.Now()} + } else { + // Reset attempts if it's been more than a minute + if time.Since(v.lastSeen) > time.Minute { + v.attempts = 0 + } + v.attempts++ + v.lastSeen = time.Now() + } + + // If more than 5 attempts within a minute, block + if exists && v.attempts > 5 { + mu.Unlock() + http.Error(w, "Too many requests. Please try again later.", http.StatusTooManyRequests) + return + } + mu.Unlock() + + next.ServeHTTP(w, r) + }) +} diff --git a/internal/templates/auth.templ b/internal/templates/auth.templ index 607f2f8..292c917 100644 --- a/internal/templates/auth.templ +++ b/internal/templates/auth.templ @@ -7,7 +7,7 @@ templ Login() {

enter your credentials to continue

- +
@@ -16,6 +16,35 @@ templ Login() {
+

+ don't have an account? register +

+ + } +} + +templ Register() { + @Layout("Register") { +
+

register

+ + +

+ already have an account? sign in +

} } diff --git a/internal/templates/auth_templ.go b/internal/templates/auth_templ.go index 6ca9262..19017e5 100644 --- a/internal/templates/auth_templ.go +++ b/internal/templates/auth_templ.go @@ -41,7 +41,7 @@ func Login() templ.Component { }() } ctx = templ.InitializeContext(ctx) - templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 1, "

sign in

enter your credentials to continue

") + templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 1, "

sign in

enter your credentials to continue

don't have an account? register

") if templ_7745c5c3_Err != nil { return templ_7745c5c3_Err } @@ -55,4 +55,51 @@ func Login() templ.Component { }) } +func Register() templ.Component { + return templruntime.GeneratedTemplate(func(templ_7745c5c3_Input templruntime.GeneratedComponentInput) (templ_7745c5c3_Err error) { + templ_7745c5c3_W, ctx := templ_7745c5c3_Input.Writer, templ_7745c5c3_Input.Context + if templ_7745c5c3_CtxErr := ctx.Err(); templ_7745c5c3_CtxErr != nil { + return templ_7745c5c3_CtxErr + } + templ_7745c5c3_Buffer, templ_7745c5c3_IsBuffer := templruntime.GetBuffer(templ_7745c5c3_W) + if !templ_7745c5c3_IsBuffer { + defer func() { + templ_7745c5c3_BufErr := templruntime.ReleaseBuffer(templ_7745c5c3_Buffer) + if templ_7745c5c3_Err == nil { + templ_7745c5c3_Err = templ_7745c5c3_BufErr + } + }() + } + ctx = templ.InitializeContext(ctx) + templ_7745c5c3_Var3 := templ.GetChildren(ctx) + if templ_7745c5c3_Var3 == nil { + templ_7745c5c3_Var3 = templ.NopComponent + } + ctx = templ.ClearChildren(ctx) + templ_7745c5c3_Var4 := templruntime.GeneratedTemplate(func(templ_7745c5c3_Input templruntime.GeneratedComponentInput) (templ_7745c5c3_Err error) { + templ_7745c5c3_W, ctx := templ_7745c5c3_Input.Writer, templ_7745c5c3_Input.Context + templ_7745c5c3_Buffer, templ_7745c5c3_IsBuffer := templruntime.GetBuffer(templ_7745c5c3_W) + if !templ_7745c5c3_IsBuffer { + defer func() { + templ_7745c5c3_BufErr := templruntime.ReleaseBuffer(templ_7745c5c3_Buffer) + if templ_7745c5c3_Err == nil { + templ_7745c5c3_Err = templ_7745c5c3_BufErr + } + }() + } + ctx = templ.InitializeContext(ctx) + templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 2, "

register

create a new account to track anime

Password must be at least 12 characters and include an uppercase letter, lowercase letter, number, and special character.

already have an account? sign in

") + if templ_7745c5c3_Err != nil { + return templ_7745c5c3_Err + } + return nil + }) + templ_7745c5c3_Err = Layout("Register").Render(templ.WithChildren(ctx, templ_7745c5c3_Var4), templ_7745c5c3_Buffer) + if templ_7745c5c3_Err != nil { + return templ_7745c5c3_Err + } + return nil + }) +} + var _ = templruntime.GeneratedTemplate