feat(auth): implement strict and secure user registration
This commit is contained in:
@@ -8,7 +8,9 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"time"
|
||||
"unicode"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
@@ -20,6 +22,7 @@ var (
|
||||
ErrInvalidCredentials = errors.New("invalid username or password")
|
||||
ErrUserExists = errors.New("username already exists")
|
||||
ErrNotAuthenticated = errors.New("not authenticated")
|
||||
ErrInvalidPassword = errors.New("password does not meet security requirements")
|
||||
)
|
||||
|
||||
type Service struct {
|
||||
@@ -39,8 +42,37 @@ func generateSessionToken() (string, error) {
|
||||
return base64.URLEncoding.EncodeToString(b), nil
|
||||
}
|
||||
|
||||
func ValidatePassword(password string) error {
|
||||
if len(password) < 12 {
|
||||
return fmt.Errorf("password must be at least 12 characters long")
|
||||
}
|
||||
|
||||
var hasUpper, hasLower, hasNumber, hasSpecial bool
|
||||
for _, c := range password {
|
||||
switch {
|
||||
case unicode.IsNumber(c):
|
||||
hasNumber = true
|
||||
case unicode.IsUpper(c):
|
||||
hasUpper = true
|
||||
case unicode.IsLower(c):
|
||||
hasLower = true
|
||||
case unicode.IsPunct(c) || unicode.IsSymbol(c) || !unicode.IsLetter(c) && !unicode.IsNumber(c):
|
||||
hasSpecial = true
|
||||
}
|
||||
}
|
||||
|
||||
if !hasUpper || !hasLower || !hasNumber || !hasSpecial {
|
||||
return fmt.Errorf("password must contain at least one uppercase letter, one lowercase letter, one number, and one special character")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Service) RegisterUser(ctx context.Context, username, password string) (*database.User, error) {
|
||||
hash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
|
||||
if err := ValidatePassword(password); err != nil {
|
||||
return nil, fmt.Errorf("%w: %v", ErrInvalidPassword, err)
|
||||
}
|
||||
|
||||
hash, err := bcrypt.GenerateFromPassword([]byte(password), 12) // higher cost
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to hash password: %w", err)
|
||||
}
|
||||
@@ -117,25 +149,27 @@ func (s *Service) ValidateSession(ctx context.Context, sessionID string) (*datab
|
||||
}
|
||||
|
||||
func SetSessionCookie(w http.ResponseWriter, sessionID string, expiresAt time.Time) {
|
||||
isProd := os.Getenv("ENV") == "production"
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: "session_id",
|
||||
Value: sessionID,
|
||||
Expires: expiresAt,
|
||||
HttpOnly: true,
|
||||
Secure: false, // False for local development without TLS
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
Secure: isProd,
|
||||
SameSite: http.SameSiteStrictMode,
|
||||
Path: "/",
|
||||
})
|
||||
}
|
||||
|
||||
func ClearSessionCookie(w http.ResponseWriter) {
|
||||
isProd := os.Getenv("ENV") == "production"
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: "session_id",
|
||||
Value: "",
|
||||
Expires: time.Unix(0, 0),
|
||||
HttpOnly: true,
|
||||
Secure: false, // False for local development without TLS
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
Secure: isProd,
|
||||
SameSite: http.SameSiteStrictMode,
|
||||
Path: "/",
|
||||
})
|
||||
}
|
||||
|
||||
@@ -39,6 +39,47 @@ func (h *Handler) HandleLogin(w http.ResponseWriter, r *http.Request) {
|
||||
http.Redirect(w, r, "/", http.StatusFound)
|
||||
}
|
||||
|
||||
func (h *Handler) HandleRegister(w http.ResponseWriter, r *http.Request) {
|
||||
if err := r.ParseForm(); err != nil {
|
||||
http.Error(w, "invalid request", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
username := r.FormValue("username")
|
||||
password := r.FormValue("password")
|
||||
|
||||
if username == "" || password == "" {
|
||||
http.Error(w, "username and password are required", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
_, err := h.authService.RegisterUser(r.Context(), username, password)
|
||||
if err != nil {
|
||||
if err == ErrInvalidPassword {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if err == ErrUserExists {
|
||||
http.Error(w, "username already taken", http.StatusConflict)
|
||||
return
|
||||
}
|
||||
http.Error(w, "registration failed", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Auto-login after successful registration
|
||||
session, err := h.authService.Login(r.Context(), username, password)
|
||||
if err != nil {
|
||||
http.Redirect(w, r, "/login", http.StatusFound)
|
||||
return
|
||||
}
|
||||
|
||||
SetSessionCookie(w, session.ID, session.ExpiresAt)
|
||||
|
||||
w.Header().Set("HX-Redirect", "/")
|
||||
http.Redirect(w, r, "/", http.StatusFound)
|
||||
}
|
||||
|
||||
func (h *Handler) HandleLogout(w http.ResponseWriter, r *http.Request) {
|
||||
cookie, err := r.Cookie("session_id")
|
||||
if err == nil {
|
||||
@@ -53,3 +94,7 @@ func (h *Handler) HandleLogout(w http.ResponseWriter, r *http.Request) {
|
||||
func (h *Handler) HandleLoginPage(w http.ResponseWriter, r *http.Request) {
|
||||
templates.Login().Render(r.Context(), w)
|
||||
}
|
||||
|
||||
func (h *Handler) HandleRegisterPage(w http.ResponseWriter, r *http.Request) {
|
||||
templates.Register().Render(r.Context(), w)
|
||||
}
|
||||
|
||||
@@ -53,10 +53,19 @@ func NewRouter(cfg Config) http.Handler {
|
||||
if r.Method == http.MethodGet {
|
||||
authHandler.HandleLoginPage(w, r)
|
||||
} else {
|
||||
authHandler.HandleLogin(w, r)
|
||||
middleware.RateLimitAuth(middleware.VerifyOrigin(http.HandlerFunc(authHandler.HandleLogin))).ServeHTTP(w, r)
|
||||
}
|
||||
})
|
||||
mux.HandleFunc("/logout", authHandler.HandleLogout)
|
||||
mux.HandleFunc("/register", func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method == http.MethodGet {
|
||||
authHandler.HandleRegisterPage(w, r)
|
||||
} else {
|
||||
middleware.RateLimitAuth(middleware.VerifyOrigin(http.HandlerFunc(authHandler.HandleRegister))).ServeHTTP(w, r)
|
||||
}
|
||||
})
|
||||
mux.HandleFunc("/logout", func(w http.ResponseWriter, r *http.Request) {
|
||||
middleware.VerifyOrigin(http.HandlerFunc(authHandler.HandleLogout)).ServeHTTP(w, r)
|
||||
})
|
||||
|
||||
// Watchlist POST endpoint (Protected)
|
||||
mux.Handle("/api/watchlist/export", middleware.RequireAuth(http.HandlerFunc(watchlistHandler.HandleExportWatchlist)))
|
||||
|
||||
@@ -61,8 +61,8 @@ func RequireAuth(next http.Handler) http.Handler {
|
||||
// RequireGlobalAuth ensures that a valid user is in the context for all routes except login and static
|
||||
func RequireGlobalAuth(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Allow unauthenticated access to login, search, and static files
|
||||
if r.URL.Path == "/login" || strings.HasPrefix(r.URL.Path, "/static/") ||
|
||||
// Allow unauthenticated access to login, register, search, and static files
|
||||
if r.URL.Path == "/login" || r.URL.Path == "/register" || strings.HasPrefix(r.URL.Path, "/static/") ||
|
||||
r.URL.Path == "/search" || r.URL.Path == "/api/search" || r.URL.Path == "/api/search-quick" ||
|
||||
r.URL.Path == "/" {
|
||||
next.ServeHTTP(w, r)
|
||||
|
||||
48
internal/shared/middleware/csrf.go
Normal file
48
internal/shared/middleware/csrf.go
Normal file
@@ -0,0 +1,48 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
)
|
||||
|
||||
// VerifyOrigin prevents simple CSRF by ensuring the Origin or Referer header matches the Host header
|
||||
// for state-changing endpoints (POST/PUT/DELETE).
|
||||
func VerifyOrigin(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method == http.MethodGet || r.Method == http.MethodHead || r.Method == http.MethodOptions {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
origin := r.Header.Get("Origin")
|
||||
if origin == "" {
|
||||
referer := r.Header.Get("Referer")
|
||||
if referer == "" {
|
||||
// If neither is present, and it's a POST/PUT/DELETE, reject it (strict policy)
|
||||
http.Error(w, "Missing Origin or Referer header", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
refURL, err := url.Parse(referer)
|
||||
if err != nil {
|
||||
http.Error(w, "Invalid Referer header", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
origin = refURL.Scheme + "://" + refURL.Host
|
||||
}
|
||||
|
||||
host := r.Host
|
||||
// Optional: strip port if you only care about domain
|
||||
|
||||
// If origin doesn't match host (accounting for potential schema prefixes)
|
||||
expectedHTTP := "http://" + host
|
||||
expectedHTTPS := "https://" + host
|
||||
|
||||
if origin != expectedHTTP && origin != expectedHTTPS {
|
||||
http.Error(w, "Cross-Site Request Forgery (CSRF) origin mismatch", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
81
internal/shared/middleware/ratelimit.go
Normal file
81
internal/shared/middleware/ratelimit.go
Normal file
@@ -0,0 +1,81 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type visitor struct {
|
||||
attempts int
|
||||
lastSeen time.Time
|
||||
}
|
||||
|
||||
var (
|
||||
visitors = make(map[string]*visitor)
|
||||
mu sync.Mutex
|
||||
)
|
||||
|
||||
func init() {
|
||||
go cleanupVisitors()
|
||||
}
|
||||
|
||||
func cleanupVisitors() {
|
||||
for {
|
||||
time.Sleep(time.Minute)
|
||||
mu.Lock()
|
||||
for ip, v := range visitors {
|
||||
if time.Since(v.lastSeen) > 3*time.Minute {
|
||||
delete(visitors, ip)
|
||||
}
|
||||
}
|
||||
mu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
// getIP attempts to get the real IP, falling back to RemoteAddr
|
||||
func getIP(r *http.Request) string {
|
||||
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
|
||||
ips := strings.Split(xff, ",")
|
||||
return strings.TrimSpace(ips[0])
|
||||
}
|
||||
if realIP := r.Header.Get("X-Real-IP"); realIP != "" {
|
||||
return realIP
|
||||
}
|
||||
ip := r.RemoteAddr
|
||||
if colonIdx := strings.LastIndex(ip, ":"); colonIdx != -1 {
|
||||
ip = ip[:colonIdx]
|
||||
}
|
||||
return ip
|
||||
}
|
||||
|
||||
// RateLimitAuth limits login/register attempts to prevent brute force
|
||||
func RateLimitAuth(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
ip := getIP(r)
|
||||
|
||||
mu.Lock()
|
||||
v, exists := visitors[ip]
|
||||
if !exists {
|
||||
visitors[ip] = &visitor{1, time.Now()}
|
||||
} else {
|
||||
// Reset attempts if it's been more than a minute
|
||||
if time.Since(v.lastSeen) > time.Minute {
|
||||
v.attempts = 0
|
||||
}
|
||||
v.attempts++
|
||||
v.lastSeen = time.Now()
|
||||
}
|
||||
|
||||
// If more than 5 attempts within a minute, block
|
||||
if exists && v.attempts > 5 {
|
||||
mu.Unlock()
|
||||
http.Error(w, "Too many requests. Please try again later.", http.StatusTooManyRequests)
|
||||
return
|
||||
}
|
||||
mu.Unlock()
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
@@ -7,7 +7,7 @@ templ Login() {
|
||||
<p class="login-subtitle">enter your credentials to continue</p>
|
||||
<form action="/login" method="POST" class="login-form">
|
||||
<div class="form-group">
|
||||
<label for="username">email</label>
|
||||
<label for="username">username / email</label>
|
||||
<input type="text" id="username" name="username" required placeholder="you@example.com"/>
|
||||
</div>
|
||||
<div class="form-group">
|
||||
@@ -16,6 +16,35 @@ templ Login() {
|
||||
</div>
|
||||
<button type="submit" class="login-button">sign in</button>
|
||||
</form>
|
||||
<p style="margin-top: 1rem; text-align: center; color: var(--text-muted); font-size: var(--text-sm);">
|
||||
don't have an account? <a href="/register" style="color: var(--primary);">register</a>
|
||||
</p>
|
||||
</div>
|
||||
}
|
||||
}
|
||||
|
||||
templ Register() {
|
||||
@Layout("Register") {
|
||||
<div class="login-container">
|
||||
<h2>register</h2>
|
||||
<p class="login-subtitle">create a new account to track anime</p>
|
||||
<form action="/register" method="POST" class="login-form">
|
||||
<div class="form-group">
|
||||
<label for="username">username / email</label>
|
||||
<input type="text" id="username" name="username" required placeholder="you@example.com"/>
|
||||
</div>
|
||||
<div class="form-group">
|
||||
<label for="password">password</label>
|
||||
<input type="password" id="password" name="password" required placeholder="minimum 12 chars"/>
|
||||
</div>
|
||||
<p style="font-size: 0.75rem; color: var(--text-muted); margin-bottom: 1rem; line-height: 1.4;">
|
||||
Password must be at least 12 characters and include an uppercase letter, lowercase letter, number, and special character.
|
||||
</p>
|
||||
<button type="submit" class="login-button">create account</button>
|
||||
</form>
|
||||
<p style="margin-top: 1rem; text-align: center; color: var(--text-muted); font-size: var(--text-sm);">
|
||||
already have an account? <a href="/login" style="color: var(--primary);">sign in</a>
|
||||
</p>
|
||||
</div>
|
||||
}
|
||||
}
|
||||
|
||||
@@ -41,7 +41,7 @@ func Login() templ.Component {
|
||||
}()
|
||||
}
|
||||
ctx = templ.InitializeContext(ctx)
|
||||
templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 1, "<div class=\"login-container\"><h2>sign in</h2><p class=\"login-subtitle\">enter your credentials to continue</p><form action=\"/login\" method=\"POST\" class=\"login-form\"><div class=\"form-group\"><label for=\"username\">email</label> <input type=\"text\" id=\"username\" name=\"username\" required placeholder=\"you@example.com\"></div><div class=\"form-group\"><label for=\"password\">password</label> <input type=\"password\" id=\"password\" name=\"password\" required placeholder=\"your password\"></div><button type=\"submit\" class=\"login-button\">sign in</button></form></div>")
|
||||
templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 1, "<div class=\"login-container\"><h2>sign in</h2><p class=\"login-subtitle\">enter your credentials to continue</p><form action=\"/login\" method=\"POST\" class=\"login-form\"><div class=\"form-group\"><label for=\"username\">username / email</label> <input type=\"text\" id=\"username\" name=\"username\" required placeholder=\"you@example.com\"></div><div class=\"form-group\"><label for=\"password\">password</label> <input type=\"password\" id=\"password\" name=\"password\" required placeholder=\"your password\"></div><button type=\"submit\" class=\"login-button\">sign in</button></form><p style=\"margin-top: 1rem; text-align: center; color: var(--text-muted); font-size: var(--text-sm);\">don't have an account? <a href=\"/register\" style=\"color: var(--primary);\">register</a></p></div>")
|
||||
if templ_7745c5c3_Err != nil {
|
||||
return templ_7745c5c3_Err
|
||||
}
|
||||
@@ -55,4 +55,51 @@ func Login() templ.Component {
|
||||
})
|
||||
}
|
||||
|
||||
func Register() templ.Component {
|
||||
return templruntime.GeneratedTemplate(func(templ_7745c5c3_Input templruntime.GeneratedComponentInput) (templ_7745c5c3_Err error) {
|
||||
templ_7745c5c3_W, ctx := templ_7745c5c3_Input.Writer, templ_7745c5c3_Input.Context
|
||||
if templ_7745c5c3_CtxErr := ctx.Err(); templ_7745c5c3_CtxErr != nil {
|
||||
return templ_7745c5c3_CtxErr
|
||||
}
|
||||
templ_7745c5c3_Buffer, templ_7745c5c3_IsBuffer := templruntime.GetBuffer(templ_7745c5c3_W)
|
||||
if !templ_7745c5c3_IsBuffer {
|
||||
defer func() {
|
||||
templ_7745c5c3_BufErr := templruntime.ReleaseBuffer(templ_7745c5c3_Buffer)
|
||||
if templ_7745c5c3_Err == nil {
|
||||
templ_7745c5c3_Err = templ_7745c5c3_BufErr
|
||||
}
|
||||
}()
|
||||
}
|
||||
ctx = templ.InitializeContext(ctx)
|
||||
templ_7745c5c3_Var3 := templ.GetChildren(ctx)
|
||||
if templ_7745c5c3_Var3 == nil {
|
||||
templ_7745c5c3_Var3 = templ.NopComponent
|
||||
}
|
||||
ctx = templ.ClearChildren(ctx)
|
||||
templ_7745c5c3_Var4 := templruntime.GeneratedTemplate(func(templ_7745c5c3_Input templruntime.GeneratedComponentInput) (templ_7745c5c3_Err error) {
|
||||
templ_7745c5c3_W, ctx := templ_7745c5c3_Input.Writer, templ_7745c5c3_Input.Context
|
||||
templ_7745c5c3_Buffer, templ_7745c5c3_IsBuffer := templruntime.GetBuffer(templ_7745c5c3_W)
|
||||
if !templ_7745c5c3_IsBuffer {
|
||||
defer func() {
|
||||
templ_7745c5c3_BufErr := templruntime.ReleaseBuffer(templ_7745c5c3_Buffer)
|
||||
if templ_7745c5c3_Err == nil {
|
||||
templ_7745c5c3_Err = templ_7745c5c3_BufErr
|
||||
}
|
||||
}()
|
||||
}
|
||||
ctx = templ.InitializeContext(ctx)
|
||||
templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 2, "<div class=\"login-container\"><h2>register</h2><p class=\"login-subtitle\">create a new account to track anime</p><form action=\"/register\" method=\"POST\" class=\"login-form\"><div class=\"form-group\"><label for=\"username\">username / email</label> <input type=\"text\" id=\"username\" name=\"username\" required placeholder=\"you@example.com\"></div><div class=\"form-group\"><label for=\"password\">password</label> <input type=\"password\" id=\"password\" name=\"password\" required placeholder=\"minimum 12 chars\"></div><p style=\"font-size: 0.75rem; color: var(--text-muted); margin-bottom: 1rem; line-height: 1.4;\">Password must be at least 12 characters and include an uppercase letter, lowercase letter, number, and special character.</p><button type=\"submit\" class=\"login-button\">create account</button></form><p style=\"margin-top: 1rem; text-align: center; color: var(--text-muted); font-size: var(--text-sm);\">already have an account? <a href=\"/login\" style=\"color: var(--primary);\">sign in</a></p></div>")
|
||||
if templ_7745c5c3_Err != nil {
|
||||
return templ_7745c5c3_Err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
templ_7745c5c3_Err = Layout("Register").Render(templ.WithChildren(ctx, templ_7745c5c3_Var4), templ_7745c5c3_Buffer)
|
||||
if templ_7745c5c3_Err != nil {
|
||||
return templ_7745c5c3_Err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
var _ = templruntime.GeneratedTemplate
|
||||
|
||||
Reference in New Issue
Block a user