feat(auth): implement strict and secure user registration

This commit is contained in:
2026-04-08 15:37:32 +02:00
parent fd9aca9ffc
commit 91e10560a6
8 changed files with 304 additions and 11 deletions

View File

@@ -8,7 +8,9 @@ import (
"errors"
"fmt"
"net/http"
"os"
"time"
"unicode"
"github.com/google/uuid"
"golang.org/x/crypto/bcrypt"
@@ -20,6 +22,7 @@ var (
ErrInvalidCredentials = errors.New("invalid username or password")
ErrUserExists = errors.New("username already exists")
ErrNotAuthenticated = errors.New("not authenticated")
ErrInvalidPassword = errors.New("password does not meet security requirements")
)
type Service struct {
@@ -39,8 +42,37 @@ func generateSessionToken() (string, error) {
return base64.URLEncoding.EncodeToString(b), nil
}
func ValidatePassword(password string) error {
if len(password) < 12 {
return fmt.Errorf("password must be at least 12 characters long")
}
var hasUpper, hasLower, hasNumber, hasSpecial bool
for _, c := range password {
switch {
case unicode.IsNumber(c):
hasNumber = true
case unicode.IsUpper(c):
hasUpper = true
case unicode.IsLower(c):
hasLower = true
case unicode.IsPunct(c) || unicode.IsSymbol(c) || !unicode.IsLetter(c) && !unicode.IsNumber(c):
hasSpecial = true
}
}
if !hasUpper || !hasLower || !hasNumber || !hasSpecial {
return fmt.Errorf("password must contain at least one uppercase letter, one lowercase letter, one number, and one special character")
}
return nil
}
func (s *Service) RegisterUser(ctx context.Context, username, password string) (*database.User, error) {
hash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
if err := ValidatePassword(password); err != nil {
return nil, fmt.Errorf("%w: %v", ErrInvalidPassword, err)
}
hash, err := bcrypt.GenerateFromPassword([]byte(password), 12) // higher cost
if err != nil {
return nil, fmt.Errorf("failed to hash password: %w", err)
}
@@ -117,25 +149,27 @@ func (s *Service) ValidateSession(ctx context.Context, sessionID string) (*datab
}
func SetSessionCookie(w http.ResponseWriter, sessionID string, expiresAt time.Time) {
isProd := os.Getenv("ENV") == "production"
http.SetCookie(w, &http.Cookie{
Name: "session_id",
Value: sessionID,
Expires: expiresAt,
HttpOnly: true,
Secure: false, // False for local development without TLS
SameSite: http.SameSiteLaxMode,
Secure: isProd,
SameSite: http.SameSiteStrictMode,
Path: "/",
})
}
func ClearSessionCookie(w http.ResponseWriter) {
isProd := os.Getenv("ENV") == "production"
http.SetCookie(w, &http.Cookie{
Name: "session_id",
Value: "",
Expires: time.Unix(0, 0),
HttpOnly: true,
Secure: false, // False for local development without TLS
SameSite: http.SameSiteLaxMode,
Secure: isProd,
SameSite: http.SameSiteStrictMode,
Path: "/",
})
}

View File

@@ -39,6 +39,47 @@ func (h *Handler) HandleLogin(w http.ResponseWriter, r *http.Request) {
http.Redirect(w, r, "/", http.StatusFound)
}
func (h *Handler) HandleRegister(w http.ResponseWriter, r *http.Request) {
if err := r.ParseForm(); err != nil {
http.Error(w, "invalid request", http.StatusBadRequest)
return
}
username := r.FormValue("username")
password := r.FormValue("password")
if username == "" || password == "" {
http.Error(w, "username and password are required", http.StatusBadRequest)
return
}
_, err := h.authService.RegisterUser(r.Context(), username, password)
if err != nil {
if err == ErrInvalidPassword {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
if err == ErrUserExists {
http.Error(w, "username already taken", http.StatusConflict)
return
}
http.Error(w, "registration failed", http.StatusInternalServerError)
return
}
// Auto-login after successful registration
session, err := h.authService.Login(r.Context(), username, password)
if err != nil {
http.Redirect(w, r, "/login", http.StatusFound)
return
}
SetSessionCookie(w, session.ID, session.ExpiresAt)
w.Header().Set("HX-Redirect", "/")
http.Redirect(w, r, "/", http.StatusFound)
}
func (h *Handler) HandleLogout(w http.ResponseWriter, r *http.Request) {
cookie, err := r.Cookie("session_id")
if err == nil {
@@ -53,3 +94,7 @@ func (h *Handler) HandleLogout(w http.ResponseWriter, r *http.Request) {
func (h *Handler) HandleLoginPage(w http.ResponseWriter, r *http.Request) {
templates.Login().Render(r.Context(), w)
}
func (h *Handler) HandleRegisterPage(w http.ResponseWriter, r *http.Request) {
templates.Register().Render(r.Context(), w)
}

View File

@@ -53,10 +53,19 @@ func NewRouter(cfg Config) http.Handler {
if r.Method == http.MethodGet {
authHandler.HandleLoginPage(w, r)
} else {
authHandler.HandleLogin(w, r)
middleware.RateLimitAuth(middleware.VerifyOrigin(http.HandlerFunc(authHandler.HandleLogin))).ServeHTTP(w, r)
}
})
mux.HandleFunc("/logout", authHandler.HandleLogout)
mux.HandleFunc("/register", func(w http.ResponseWriter, r *http.Request) {
if r.Method == http.MethodGet {
authHandler.HandleRegisterPage(w, r)
} else {
middleware.RateLimitAuth(middleware.VerifyOrigin(http.HandlerFunc(authHandler.HandleRegister))).ServeHTTP(w, r)
}
})
mux.HandleFunc("/logout", func(w http.ResponseWriter, r *http.Request) {
middleware.VerifyOrigin(http.HandlerFunc(authHandler.HandleLogout)).ServeHTTP(w, r)
})
// Watchlist POST endpoint (Protected)
mux.Handle("/api/watchlist/export", middleware.RequireAuth(http.HandlerFunc(watchlistHandler.HandleExportWatchlist)))

View File

@@ -61,8 +61,8 @@ func RequireAuth(next http.Handler) http.Handler {
// RequireGlobalAuth ensures that a valid user is in the context for all routes except login and static
func RequireGlobalAuth(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Allow unauthenticated access to login, search, and static files
if r.URL.Path == "/login" || strings.HasPrefix(r.URL.Path, "/static/") ||
// Allow unauthenticated access to login, register, search, and static files
if r.URL.Path == "/login" || r.URL.Path == "/register" || strings.HasPrefix(r.URL.Path, "/static/") ||
r.URL.Path == "/search" || r.URL.Path == "/api/search" || r.URL.Path == "/api/search-quick" ||
r.URL.Path == "/" {
next.ServeHTTP(w, r)

View File

@@ -0,0 +1,48 @@
package middleware
import (
"net/http"
"net/url"
)
// VerifyOrigin prevents simple CSRF by ensuring the Origin or Referer header matches the Host header
// for state-changing endpoints (POST/PUT/DELETE).
func VerifyOrigin(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method == http.MethodGet || r.Method == http.MethodHead || r.Method == http.MethodOptions {
next.ServeHTTP(w, r)
return
}
origin := r.Header.Get("Origin")
if origin == "" {
referer := r.Header.Get("Referer")
if referer == "" {
// If neither is present, and it's a POST/PUT/DELETE, reject it (strict policy)
http.Error(w, "Missing Origin or Referer header", http.StatusForbidden)
return
}
refURL, err := url.Parse(referer)
if err != nil {
http.Error(w, "Invalid Referer header", http.StatusForbidden)
return
}
origin = refURL.Scheme + "://" + refURL.Host
}
host := r.Host
// Optional: strip port if you only care about domain
// If origin doesn't match host (accounting for potential schema prefixes)
expectedHTTP := "http://" + host
expectedHTTPS := "https://" + host
if origin != expectedHTTP && origin != expectedHTTPS {
http.Error(w, "Cross-Site Request Forgery (CSRF) origin mismatch", http.StatusForbidden)
return
}
next.ServeHTTP(w, r)
})
}

View File

@@ -0,0 +1,81 @@
package middleware
import (
"net/http"
"strings"
"sync"
"time"
)
type visitor struct {
attempts int
lastSeen time.Time
}
var (
visitors = make(map[string]*visitor)
mu sync.Mutex
)
func init() {
go cleanupVisitors()
}
func cleanupVisitors() {
for {
time.Sleep(time.Minute)
mu.Lock()
for ip, v := range visitors {
if time.Since(v.lastSeen) > 3*time.Minute {
delete(visitors, ip)
}
}
mu.Unlock()
}
}
// getIP attempts to get the real IP, falling back to RemoteAddr
func getIP(r *http.Request) string {
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
ips := strings.Split(xff, ",")
return strings.TrimSpace(ips[0])
}
if realIP := r.Header.Get("X-Real-IP"); realIP != "" {
return realIP
}
ip := r.RemoteAddr
if colonIdx := strings.LastIndex(ip, ":"); colonIdx != -1 {
ip = ip[:colonIdx]
}
return ip
}
// RateLimitAuth limits login/register attempts to prevent brute force
func RateLimitAuth(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ip := getIP(r)
mu.Lock()
v, exists := visitors[ip]
if !exists {
visitors[ip] = &visitor{1, time.Now()}
} else {
// Reset attempts if it's been more than a minute
if time.Since(v.lastSeen) > time.Minute {
v.attempts = 0
}
v.attempts++
v.lastSeen = time.Now()
}
// If more than 5 attempts within a minute, block
if exists && v.attempts > 5 {
mu.Unlock()
http.Error(w, "Too many requests. Please try again later.", http.StatusTooManyRequests)
return
}
mu.Unlock()
next.ServeHTTP(w, r)
})
}

View File

@@ -7,7 +7,7 @@ templ Login() {
<p class="login-subtitle">enter your credentials to continue</p>
<form action="/login" method="POST" class="login-form">
<div class="form-group">
<label for="username">email</label>
<label for="username">username / email</label>
<input type="text" id="username" name="username" required placeholder="you@example.com"/>
</div>
<div class="form-group">
@@ -16,6 +16,35 @@ templ Login() {
</div>
<button type="submit" class="login-button">sign in</button>
</form>
<p style="margin-top: 1rem; text-align: center; color: var(--text-muted); font-size: var(--text-sm);">
don't have an account? <a href="/register" style="color: var(--primary);">register</a>
</p>
</div>
}
}
templ Register() {
@Layout("Register") {
<div class="login-container">
<h2>register</h2>
<p class="login-subtitle">create a new account to track anime</p>
<form action="/register" method="POST" class="login-form">
<div class="form-group">
<label for="username">username / email</label>
<input type="text" id="username" name="username" required placeholder="you@example.com"/>
</div>
<div class="form-group">
<label for="password">password</label>
<input type="password" id="password" name="password" required placeholder="minimum 12 chars"/>
</div>
<p style="font-size: 0.75rem; color: var(--text-muted); margin-bottom: 1rem; line-height: 1.4;">
Password must be at least 12 characters and include an uppercase letter, lowercase letter, number, and special character.
</p>
<button type="submit" class="login-button">create account</button>
</form>
<p style="margin-top: 1rem; text-align: center; color: var(--text-muted); font-size: var(--text-sm);">
already have an account? <a href="/login" style="color: var(--primary);">sign in</a>
</p>
</div>
}
}

View File

@@ -41,7 +41,7 @@ func Login() templ.Component {
}()
}
ctx = templ.InitializeContext(ctx)
templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 1, "<div class=\"login-container\"><h2>sign in</h2><p class=\"login-subtitle\">enter your credentials to continue</p><form action=\"/login\" method=\"POST\" class=\"login-form\"><div class=\"form-group\"><label for=\"username\">email</label> <input type=\"text\" id=\"username\" name=\"username\" required placeholder=\"you@example.com\"></div><div class=\"form-group\"><label for=\"password\">password</label> <input type=\"password\" id=\"password\" name=\"password\" required placeholder=\"your password\"></div><button type=\"submit\" class=\"login-button\">sign in</button></form></div>")
templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 1, "<div class=\"login-container\"><h2>sign in</h2><p class=\"login-subtitle\">enter your credentials to continue</p><form action=\"/login\" method=\"POST\" class=\"login-form\"><div class=\"form-group\"><label for=\"username\">username / email</label> <input type=\"text\" id=\"username\" name=\"username\" required placeholder=\"you@example.com\"></div><div class=\"form-group\"><label for=\"password\">password</label> <input type=\"password\" id=\"password\" name=\"password\" required placeholder=\"your password\"></div><button type=\"submit\" class=\"login-button\">sign in</button></form><p style=\"margin-top: 1rem; text-align: center; color: var(--text-muted); font-size: var(--text-sm);\">don't have an account? <a href=\"/register\" style=\"color: var(--primary);\">register</a></p></div>")
if templ_7745c5c3_Err != nil {
return templ_7745c5c3_Err
}
@@ -55,4 +55,51 @@ func Login() templ.Component {
})
}
func Register() templ.Component {
return templruntime.GeneratedTemplate(func(templ_7745c5c3_Input templruntime.GeneratedComponentInput) (templ_7745c5c3_Err error) {
templ_7745c5c3_W, ctx := templ_7745c5c3_Input.Writer, templ_7745c5c3_Input.Context
if templ_7745c5c3_CtxErr := ctx.Err(); templ_7745c5c3_CtxErr != nil {
return templ_7745c5c3_CtxErr
}
templ_7745c5c3_Buffer, templ_7745c5c3_IsBuffer := templruntime.GetBuffer(templ_7745c5c3_W)
if !templ_7745c5c3_IsBuffer {
defer func() {
templ_7745c5c3_BufErr := templruntime.ReleaseBuffer(templ_7745c5c3_Buffer)
if templ_7745c5c3_Err == nil {
templ_7745c5c3_Err = templ_7745c5c3_BufErr
}
}()
}
ctx = templ.InitializeContext(ctx)
templ_7745c5c3_Var3 := templ.GetChildren(ctx)
if templ_7745c5c3_Var3 == nil {
templ_7745c5c3_Var3 = templ.NopComponent
}
ctx = templ.ClearChildren(ctx)
templ_7745c5c3_Var4 := templruntime.GeneratedTemplate(func(templ_7745c5c3_Input templruntime.GeneratedComponentInput) (templ_7745c5c3_Err error) {
templ_7745c5c3_W, ctx := templ_7745c5c3_Input.Writer, templ_7745c5c3_Input.Context
templ_7745c5c3_Buffer, templ_7745c5c3_IsBuffer := templruntime.GetBuffer(templ_7745c5c3_W)
if !templ_7745c5c3_IsBuffer {
defer func() {
templ_7745c5c3_BufErr := templruntime.ReleaseBuffer(templ_7745c5c3_Buffer)
if templ_7745c5c3_Err == nil {
templ_7745c5c3_Err = templ_7745c5c3_BufErr
}
}()
}
ctx = templ.InitializeContext(ctx)
templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 2, "<div class=\"login-container\"><h2>register</h2><p class=\"login-subtitle\">create a new account to track anime</p><form action=\"/register\" method=\"POST\" class=\"login-form\"><div class=\"form-group\"><label for=\"username\">username / email</label> <input type=\"text\" id=\"username\" name=\"username\" required placeholder=\"you@example.com\"></div><div class=\"form-group\"><label for=\"password\">password</label> <input type=\"password\" id=\"password\" name=\"password\" required placeholder=\"minimum 12 chars\"></div><p style=\"font-size: 0.75rem; color: var(--text-muted); margin-bottom: 1rem; line-height: 1.4;\">Password must be at least 12 characters and include an uppercase letter, lowercase letter, number, and special character.</p><button type=\"submit\" class=\"login-button\">create account</button></form><p style=\"margin-top: 1rem; text-align: center; color: var(--text-muted); font-size: var(--text-sm);\">already have an account? <a href=\"/login\" style=\"color: var(--primary);\">sign in</a></p></div>")
if templ_7745c5c3_Err != nil {
return templ_7745c5c3_Err
}
return nil
})
templ_7745c5c3_Err = Layout("Register").Render(templ.WithChildren(ctx, templ_7745c5c3_Var4), templ_7745c5c3_Buffer)
if templ_7745c5c3_Err != nil {
return templ_7745c5c3_Err
}
return nil
})
}
var _ = templruntime.GeneratedTemplate