diff --git a/internal/database/db.go b/internal/database/db.go index 37a578c..85d4b8c 100644 --- a/internal/database/db.go +++ b/internal/database/db.go @@ -10,10 +10,10 @@ import ( ) type DBTX interface { - ExecContext(context.Context, string, ...any) (sql.Result, error) + ExecContext(context.Context, string, ...interface{}) (sql.Result, error) PrepareContext(context.Context, string) (*sql.Stmt, error) - QueryContext(context.Context, string, ...any) (*sql.Rows, error) - QueryRowContext(context.Context, string, ...any) *sql.Row + QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error) + QueryRowContext(context.Context, string, ...interface{}) *sql.Row } func New(db DBTX) *Queries { diff --git a/internal/database/models.go b/internal/database/models.go index 34cd7d0..f1990d9 100644 --- a/internal/database/models.go +++ b/internal/database/models.go @@ -57,10 +57,11 @@ type Session struct { } type User struct { - ID string `json:"id"` - Username string `json:"username"` - PasswordHash string `json:"password_hash"` - CreatedAt time.Time `json:"created_at"` + ID string `json:"id"` + Username string `json:"username"` + PasswordHash string `json:"password_hash"` + CreatedAt time.Time `json:"created_at"` + RecoveryKeyHash string `json:"recovery_key_hash"` } type WatchListEntry struct { diff --git a/internal/database/querier.go b/internal/database/querier.go index 4d94fcc..4e2b487 100644 --- a/internal/database/querier.go +++ b/internal/database/querier.go @@ -22,12 +22,14 @@ type Querier interface { GetUpcomingSeasons(ctx context.Context, userID string) ([]GetUpcomingSeasonsRow, error) GetUser(ctx context.Context, id string) (User, error) GetUserByUsername(ctx context.Context, username string) (User, error) + GetUserByUsernameAndRecoveryKeyHash(ctx context.Context, arg GetUserByUsernameAndRecoveryKeyHashParams) (User, error) GetUserWatchList(ctx context.Context, userID string) ([]GetUserWatchListRow, error) GetWatchListEntry(ctx context.Context, arg GetWatchListEntryParams) (WatchListEntry, error) GetWatchingAnime(ctx context.Context, userID string) ([]GetWatchingAnimeRow, error) MarkRelationsSynced(ctx context.Context, id int64) error SetJikanCache(ctx context.Context, arg SetJikanCacheParams) error UpdateAnimeStatus(ctx context.Context, arg UpdateAnimeStatusParams) error + UpdateUserPasswordAndRecoveryKeyHash(ctx context.Context, arg UpdateUserPasswordAndRecoveryKeyHashParams) error UpsertAnime(ctx context.Context, arg UpsertAnimeParams) (Anime, error) UpsertAnimeRelation(ctx context.Context, arg UpsertAnimeRelationParams) error UpsertWatchListEntry(ctx context.Context, arg UpsertWatchListEntryParams) (WatchListEntry, error) diff --git a/internal/database/queries.sql b/internal/database/queries.sql index 293f698..3447a41 100644 --- a/internal/database/queries.sql +++ b/internal/database/queries.sql @@ -5,10 +5,18 @@ SELECT * FROM user WHERE id = ? LIMIT 1; SELECT * FROM user WHERE username = ? LIMIT 1; -- name: CreateUser :one -INSERT INTO user (id, username, password_hash) -VALUES (?, ?, ?) +INSERT INTO user (id, username, password_hash, recovery_key_hash) +VALUES (?, ?, ?, ?) RETURNING *; +-- name: GetUserByUsernameAndRecoveryKeyHash :one +SELECT * FROM user WHERE username = ? AND recovery_key_hash = ? LIMIT 1; + +-- name: UpdateUserPasswordAndRecoveryKeyHash :exec +UPDATE user +SET password_hash = ?, recovery_key_hash = ? +WHERE id = ?; + -- name: CreateSession :one INSERT INTO session (id, user_id, expires_at) VALUES (?, ?, ?) diff --git a/internal/database/queries.sql.go b/internal/database/queries.sql.go index e2bd9cc..031e336 100644 --- a/internal/database/queries.sql.go +++ b/internal/database/queries.sql.go @@ -36,25 +36,32 @@ func (q *Queries) CreateSession(ctx context.Context, arg CreateSessionParams) (S } const createUser = `-- name: CreateUser :one -INSERT INTO user (id, username, password_hash) -VALUES (?, ?, ?) -RETURNING id, username, password_hash, created_at +INSERT INTO user (id, username, password_hash, recovery_key_hash) +VALUES (?, ?, ?, ?) +RETURNING id, username, password_hash, created_at, recovery_key_hash ` type CreateUserParams struct { - ID string `json:"id"` - Username string `json:"username"` - PasswordHash string `json:"password_hash"` + ID string `json:"id"` + Username string `json:"username"` + PasswordHash string `json:"password_hash"` + RecoveryKeyHash string `json:"recovery_key_hash"` } func (q *Queries) CreateUser(ctx context.Context, arg CreateUserParams) (User, error) { - row := q.db.QueryRowContext(ctx, createUser, arg.ID, arg.Username, arg.PasswordHash) + row := q.db.QueryRowContext(ctx, createUser, + arg.ID, + arg.Username, + arg.PasswordHash, + arg.RecoveryKeyHash, + ) var i User err := row.Scan( &i.ID, &i.Username, &i.PasswordHash, &i.CreatedAt, + &i.RecoveryKeyHash, ) return i, err } @@ -289,7 +296,7 @@ func (q *Queries) GetUpcomingSeasons(ctx context.Context, userID string) ([]GetU } const getUser = `-- name: GetUser :one -SELECT id, username, password_hash, created_at FROM user WHERE id = ? LIMIT 1 +SELECT id, username, password_hash, created_at, recovery_key_hash FROM user WHERE id = ? LIMIT 1 ` func (q *Queries) GetUser(ctx context.Context, id string) (User, error) { @@ -300,12 +307,13 @@ func (q *Queries) GetUser(ctx context.Context, id string) (User, error) { &i.Username, &i.PasswordHash, &i.CreatedAt, + &i.RecoveryKeyHash, ) return i, err } const getUserByUsername = `-- name: GetUserByUsername :one -SELECT id, username, password_hash, created_at FROM user WHERE username = ? LIMIT 1 +SELECT id, username, password_hash, created_at, recovery_key_hash FROM user WHERE username = ? LIMIT 1 ` func (q *Queries) GetUserByUsername(ctx context.Context, username string) (User, error) { @@ -316,6 +324,29 @@ func (q *Queries) GetUserByUsername(ctx context.Context, username string) (User, &i.Username, &i.PasswordHash, &i.CreatedAt, + &i.RecoveryKeyHash, + ) + return i, err +} + +const getUserByUsernameAndRecoveryKeyHash = `-- name: GetUserByUsernameAndRecoveryKeyHash :one +SELECT id, username, password_hash, created_at, recovery_key_hash FROM user WHERE username = ? AND recovery_key_hash = ? LIMIT 1 +` + +type GetUserByUsernameAndRecoveryKeyHashParams struct { + Username string `json:"username"` + RecoveryKeyHash string `json:"recovery_key_hash"` +} + +func (q *Queries) GetUserByUsernameAndRecoveryKeyHash(ctx context.Context, arg GetUserByUsernameAndRecoveryKeyHashParams) (User, error) { + row := q.db.QueryRowContext(ctx, getUserByUsernameAndRecoveryKeyHash, arg.Username, arg.RecoveryKeyHash) + var i User + err := row.Scan( + &i.ID, + &i.Username, + &i.PasswordHash, + &i.CreatedAt, + &i.RecoveryKeyHash, ) return i, err } @@ -523,6 +554,23 @@ func (q *Queries) UpdateAnimeStatus(ctx context.Context, arg UpdateAnimeStatusPa return err } +const updateUserPasswordAndRecoveryKeyHash = `-- name: UpdateUserPasswordAndRecoveryKeyHash :exec +UPDATE user +SET password_hash = ?, recovery_key_hash = ? +WHERE id = ? +` + +type UpdateUserPasswordAndRecoveryKeyHashParams struct { + PasswordHash string `json:"password_hash"` + RecoveryKeyHash string `json:"recovery_key_hash"` + ID string `json:"id"` +} + +func (q *Queries) UpdateUserPasswordAndRecoveryKeyHash(ctx context.Context, arg UpdateUserPasswordAndRecoveryKeyHashParams) error { + _, err := q.db.ExecContext(ctx, updateUserPasswordAndRecoveryKeyHash, arg.PasswordHash, arg.RecoveryKeyHash, arg.ID) + return err +} + const upsertAnime = `-- name: UpsertAnime :one INSERT INTO anime (id, title_original, title_english, title_japanese, image_url, airing) VALUES (?, ?, ?, ?, ?, ?) diff --git a/internal/features/auth/auth.go b/internal/features/auth/auth.go index 9c19004..64ff5e2 100644 --- a/internal/features/auth/auth.go +++ b/internal/features/auth/auth.go @@ -3,8 +3,10 @@ package auth import ( "context" "crypto/rand" + "crypto/sha256" "database/sql" "encoding/base64" + "encoding/hex" "errors" "fmt" "net/http" @@ -23,6 +25,7 @@ var ( ErrUserExists = errors.New("username already exists") ErrNotAuthenticated = errors.New("not authenticated") ErrInvalidPassword = errors.New("password does not meet security requirements") + ErrInvalidRecoveryKey = errors.New("invalid recovery details") ) type Service struct { @@ -33,15 +36,27 @@ func NewService(db database.Querier) *Service { return &Service{db: db} } -// generateSessionToken generates a secure random 32-byte token. -func generateSessionToken() (string, error) { - b := make([]byte, 32) +func generateToken(size int) (string, error) { + b := make([]byte, size) if _, err := rand.Read(b); err != nil { return "", err } return base64.URLEncoding.EncodeToString(b), nil } +func generateSessionToken() (string, error) { + return generateToken(32) +} + +func generateRecoveryKey() (string, error) { + return generateToken(24) +} + +func hashRecoveryKey(recoveryKey string) string { + sum := sha256.Sum256([]byte(recoveryKey)) + return hex.EncodeToString(sum[:]) +} + func ValidatePassword(password string) error { if len(password) < 12 { return fmt.Errorf("password must be at least 12 characters long") @@ -67,28 +82,135 @@ func ValidatePassword(password string) error { return nil } -func (s *Service) RegisterUser(ctx context.Context, username, password string) (*database.User, error) { +func (s *Service) RegisterUser(ctx context.Context, username, password string) (*database.User, string, error) { if err := ValidatePassword(password); err != nil { - return nil, fmt.Errorf("%w: %v", ErrInvalidPassword, err) + return nil, "", fmt.Errorf("%w: %v", ErrInvalidPassword, err) } hash, err := bcrypt.GenerateFromPassword([]byte(password), 12) // higher cost if err != nil { - return nil, fmt.Errorf("failed to hash password: %w", err) + return nil, "", fmt.Errorf("failed to hash password: %w", err) + } + + recoveryKey, err := generateRecoveryKey() + if err != nil { + return nil, "", fmt.Errorf("failed to generate recovery key: %w", err) } id := uuid.New().String() user, err := s.db.CreateUser(ctx, database.CreateUserParams{ - ID: id, - Username: username, - PasswordHash: string(hash), + ID: id, + Username: username, + PasswordHash: string(hash), + RecoveryKeyHash: hashRecoveryKey(recoveryKey), }) if err != nil { // Assuming unique constraint failure for username - return nil, ErrUserExists + return nil, "", ErrUserExists } - return &user, nil + return &user, recoveryKey, nil +} + +func (s *Service) RecoverAccount(ctx context.Context, username, recoveryKey, newPassword string) (string, error) { + if err := ValidatePassword(newPassword); err != nil { + return "", fmt.Errorf("%w: %v", ErrInvalidPassword, err) + } + + user, err := s.db.GetUserByUsernameAndRecoveryKeyHash(ctx, database.GetUserByUsernameAndRecoveryKeyHashParams{ + Username: username, + RecoveryKeyHash: hashRecoveryKey(recoveryKey), + }) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return "", ErrInvalidRecoveryKey + } + return "", fmt.Errorf("failed to lookup user for recovery: %w", err) + } + + newPasswordHash, err := bcrypt.GenerateFromPassword([]byte(newPassword), 12) + if err != nil { + return "", fmt.Errorf("failed to hash new password: %w", err) + } + + newRecoveryKey, err := generateRecoveryKey() + if err != nil { + return "", fmt.Errorf("failed to generate new recovery key: %w", err) + } + + err = s.db.UpdateUserPasswordAndRecoveryKeyHash(ctx, database.UpdateUserPasswordAndRecoveryKeyHashParams{ + ID: user.ID, + PasswordHash: string(newPasswordHash), + RecoveryKeyHash: hashRecoveryKey(newRecoveryKey), + }) + if err != nil { + return "", fmt.Errorf("failed to update recovered account: %w", err) + } + + err = s.db.DeleteUserSessions(ctx, user.ID) + if err != nil { + return "", fmt.Errorf("failed to clear existing sessions: %w", err) + } + + return newRecoveryKey, nil +} + +func (s *Service) ChangePassword(ctx context.Context, userID, currentPassword, newPassword string) error { + if err := ValidatePassword(newPassword); err != nil { + return fmt.Errorf("%w: %v", ErrInvalidPassword, err) + } + + user, err := s.db.GetUser(ctx, userID) + if err != nil { + return fmt.Errorf("failed to lookup user: %w", err) + } + + if err := bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(currentPassword)); err != nil { + return ErrInvalidCredentials + } + + newPasswordHash, err := bcrypt.GenerateFromPassword([]byte(newPassword), 12) + if err != nil { + return fmt.Errorf("failed to hash new password: %w", err) + } + + err = s.db.UpdateUserPasswordAndRecoveryKeyHash(ctx, database.UpdateUserPasswordAndRecoveryKeyHashParams{ + ID: user.ID, + PasswordHash: string(newPasswordHash), + RecoveryKeyHash: user.RecoveryKeyHash, + }) + if err != nil { + return fmt.Errorf("failed to update password: %w", err) + } + + return nil +} + +func (s *Service) RegenerateRecoveryKey(ctx context.Context, userID, password string) (string, error) { + user, err := s.db.GetUser(ctx, userID) + if err != nil { + return "", fmt.Errorf("failed to lookup user: %w", err) + } + + if err := bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(password)); err != nil { + return "", ErrInvalidCredentials + } + + newRecoveryKey, err := generateRecoveryKey() + if err != nil { + return "", fmt.Errorf("failed to generate new recovery key: %w", err) + } + + err = s.db.UpdateUserPasswordAndRecoveryKeyHash(ctx, database.UpdateUserPasswordAndRecoveryKeyHashParams{ + ID: user.ID, + PasswordHash: user.PasswordHash, + RecoveryKeyHash: hashRecoveryKey(newRecoveryKey), + }) + if err != nil { + return "", fmt.Errorf("failed to rotate recovery key: %w", err) + } + + return newRecoveryKey, nil } func (s *Service) Login(ctx context.Context, username, password string) (*database.Session, error) { diff --git a/internal/features/auth/handler.go b/internal/features/auth/handler.go index 8e228a3..8a0a3bf 100644 --- a/internal/features/auth/handler.go +++ b/internal/features/auth/handler.go @@ -1,8 +1,11 @@ package auth import ( + "errors" "net/http" + "time" + "mal/internal/database" "mal/internal/templates" ) @@ -10,6 +13,40 @@ type Handler struct { authService *Service } +const rateLimitFormError = "Too many attempts in a short time. Please wait a minute and try again." + +const ( + accountPasswordChangedMessage = "Password updated successfully." + accountRecoveryKeyRotatedMessage = "Recovery key rotated. Save this new key now." + accountPasswordErrorMessage = "Unable to update password with those details." + accountRecoveryErrorMessage = "Unable to rotate recovery key with those details." + accountUnexpectedErrorMessage = "Something went wrong. Please try again." + accountMissingFieldsErrorMessage = "Please complete all required fields." + accountPasswordMismatchErrorMessage = "New password and confirm password must match." +) + +func (h *Handler) accountUserFromRequest(r *http.Request) (*database.User, bool) { + cookie, err := r.Cookie("session_id") + if err != nil { + return nil, false + } + + user, err := h.authService.ValidateSession(r.Context(), cookie.Value) + if err != nil { + return nil, false + } + + return user, true +} + +func accountCreatedAt(createdAt time.Time) string { + return createdAt.Local().Format("Jan 2, 2006 at 15:04") +} + +func renderAccountPage(w http.ResponseWriter, r *http.Request, user *database.User, passwordError string, passwordSuccess string, recoveryError string, recoverySuccess string, recoveryKey string) { + templates.Account(user.Username, accountCreatedAt(user.CreatedAt), passwordError, passwordSuccess, recoveryError, recoverySuccess, recoveryKey).Render(r.Context(), w) +} + func NewHandler(authService *Service) *Handler { return &Handler{authService: authService} } @@ -18,17 +55,21 @@ func NewHandler(authService *Service) *Handler { func (h *Handler) HandleLogin(w http.ResponseWriter, r *http.Request) { if err := r.ParseForm(); err != nil { - http.Error(w, "invalid request", http.StatusBadRequest) + templates.Login("Something went wrong. Please try again.", "").Render(r.Context(), w) return } username := r.FormValue("username") password := r.FormValue("password") + if username == "" || password == "" { + templates.Login("The email or password is wrong.", username).Render(r.Context(), w) + return + } + session, err := h.authService.Login(r.Context(), username, password) if err != nil { - // Just handle generically for now, perhaps via HTMX toast - http.Error(w, "invalid credentials", http.StatusUnauthorized) + templates.Login("The email or password is wrong.", username).Render(r.Context(), w) return } @@ -41,7 +82,7 @@ func (h *Handler) HandleLogin(w http.ResponseWriter, r *http.Request) { func (h *Handler) HandleRegister(w http.ResponseWriter, r *http.Request) { if err := r.ParseForm(); err != nil { - http.Error(w, "invalid request", http.StatusBadRequest) + templates.Register("Something went wrong. Please try again.", "").Render(r.Context(), w) return } @@ -49,21 +90,17 @@ func (h *Handler) HandleRegister(w http.ResponseWriter, r *http.Request) { password := r.FormValue("password") if username == "" || password == "" { - http.Error(w, "username and password are required", http.StatusBadRequest) + templates.Register("Please enter both email and password.", username).Render(r.Context(), w) return } - _, err := h.authService.RegisterUser(r.Context(), username, password) + _, recoveryKey, err := h.authService.RegisterUser(r.Context(), username, password) if err != nil { - if err == ErrInvalidPassword { - http.Error(w, err.Error(), http.StatusBadRequest) + if errors.Is(err, ErrInvalidPassword) || errors.Is(err, ErrUserExists) { + templates.Register("Unable to create account with those details.", username).Render(r.Context(), w) return } - if err == ErrUserExists { - http.Error(w, "username already taken", http.StatusConflict) - return - } - http.Error(w, "registration failed", http.StatusInternalServerError) + templates.Register("Something went wrong. Please try again.", username).Render(r.Context(), w) return } @@ -75,9 +112,7 @@ func (h *Handler) HandleRegister(w http.ResponseWriter, r *http.Request) { } SetSessionCookie(w, session.ID, session.ExpiresAt) - - w.Header().Set("HX-Redirect", "/") - http.Redirect(w, r, "/", http.StatusFound) + templates.RegistrationRecoveryKey(recoveryKey).Render(r.Context(), w) } func (h *Handler) HandleLogout(w http.ResponseWriter, r *http.Request) { @@ -92,9 +127,148 @@ func (h *Handler) HandleLogout(w http.ResponseWriter, r *http.Request) { } func (h *Handler) HandleLoginPage(w http.ResponseWriter, r *http.Request) { - templates.Login().Render(r.Context(), w) + formError := "" + if r.URL.Query().Get("error") == "rate_limited" { + formError = rateLimitFormError + } + templates.Login(formError, "").Render(r.Context(), w) } func (h *Handler) HandleRegisterPage(w http.ResponseWriter, r *http.Request) { - templates.Register().Render(r.Context(), w) + formError := "" + if r.URL.Query().Get("error") == "rate_limited" { + formError = rateLimitFormError + } + templates.Register(formError, "").Render(r.Context(), w) +} + +func (h *Handler) HandleRecoverPage(w http.ResponseWriter, r *http.Request) { + formError := "" + if r.URL.Query().Get("error") == "rate_limited" { + formError = rateLimitFormError + } + templates.Recover(formError, "", "").Render(r.Context(), w) +} + +func (h *Handler) HandleRecover(w http.ResponseWriter, r *http.Request) { + if err := r.ParseForm(); err != nil { + templates.Recover("Something went wrong. Please try again.", "", "").Render(r.Context(), w) + return + } + + username := r.FormValue("username") + recoveryKey := r.FormValue("recovery_key") + newPassword := r.FormValue("new_password") + + if username == "" || recoveryKey == "" || newPassword == "" { + templates.Recover("Unable to recover account with those details.", username, recoveryKey).Render(r.Context(), w) + return + } + + newRecoveryKey, err := h.authService.RecoverAccount(r.Context(), username, recoveryKey, newPassword) + if err != nil { + if errors.Is(err, ErrInvalidRecoveryKey) || errors.Is(err, ErrInvalidPassword) { + templates.Recover("Unable to recover account with those details.", username, recoveryKey).Render(r.Context(), w) + return + } + templates.Recover("Something went wrong. Please try again.", username, recoveryKey).Render(r.Context(), w) + return + } + + templates.RecoveryComplete(newRecoveryKey).Render(r.Context(), w) +} + +func (h *Handler) HandleAccountPage(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + user, ok := h.accountUserFromRequest(r) + if !ok { + http.Redirect(w, r, "/login", http.StatusFound) + return + } + + renderAccountPage(w, r, user, "", "", "", "", "") +} + +func (h *Handler) HandleAccountPassword(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + user, ok := h.accountUserFromRequest(r) + if !ok { + http.Redirect(w, r, "/login", http.StatusFound) + return + } + + if err := r.ParseForm(); err != nil { + renderAccountPage(w, r, user, accountUnexpectedErrorMessage, "", "", "", "") + return + } + + currentPassword := r.FormValue("current_password") + newPassword := r.FormValue("new_password") + confirmNewPassword := r.FormValue("confirm_new_password") + + if currentPassword == "" || newPassword == "" || confirmNewPassword == "" { + renderAccountPage(w, r, user, accountMissingFieldsErrorMessage, "", "", "", "") + return + } + + if newPassword != confirmNewPassword { + renderAccountPage(w, r, user, accountPasswordMismatchErrorMessage, "", "", "", "") + return + } + + err := h.authService.ChangePassword(r.Context(), user.ID, currentPassword, newPassword) + if err != nil { + if errors.Is(err, ErrInvalidCredentials) || errors.Is(err, ErrInvalidPassword) { + renderAccountPage(w, r, user, accountPasswordErrorMessage, "", "", "", "") + return + } + renderAccountPage(w, r, user, accountUnexpectedErrorMessage, "", "", "", "") + return + } + + renderAccountPage(w, r, user, "", accountPasswordChangedMessage, "", "", "") +} + +func (h *Handler) HandleAccountRecoveryKey(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + user, ok := h.accountUserFromRequest(r) + if !ok { + http.Redirect(w, r, "/login", http.StatusFound) + return + } + + if err := r.ParseForm(); err != nil { + renderAccountPage(w, r, user, "", "", accountUnexpectedErrorMessage, "", "") + return + } + + password := r.FormValue("password") + if password == "" { + renderAccountPage(w, r, user, "", "", accountMissingFieldsErrorMessage, "", "") + return + } + + newRecoveryKey, err := h.authService.RegenerateRecoveryKey(r.Context(), user.ID, password) + if err != nil { + if errors.Is(err, ErrInvalidCredentials) { + renderAccountPage(w, r, user, "", "", accountRecoveryErrorMessage, "", "") + return + } + renderAccountPage(w, r, user, "", "", accountUnexpectedErrorMessage, "", "") + return + } + + renderAccountPage(w, r, user, "", "", "", accountRecoveryKeyRotatedMessage, newRecoveryKey) } diff --git a/internal/server/routes.go b/internal/server/routes.go index 262b907..8fd9b00 100644 --- a/internal/server/routes.go +++ b/internal/server/routes.go @@ -60,9 +60,23 @@ func NewRouter(cfg Config) http.Handler { middleware.RateLimitAuth(middleware.VerifyOrigin(http.HandlerFunc(authHandler.HandleRegister))).ServeHTTP(w, r) } }) + mux.HandleFunc("/recover", func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodGet { + authHandler.HandleRecoverPage(w, r) + } else { + middleware.RateLimitAuth(middleware.VerifyOrigin(http.HandlerFunc(authHandler.HandleRecover))).ServeHTTP(w, r) + } + }) mux.HandleFunc("/logout", func(w http.ResponseWriter, r *http.Request) { middleware.VerifyOrigin(http.HandlerFunc(authHandler.HandleLogout)).ServeHTTP(w, r) }) + mux.HandleFunc("/account", authHandler.HandleAccountPage) + mux.HandleFunc("/account/password", func(w http.ResponseWriter, r *http.Request) { + middleware.VerifyOrigin(http.HandlerFunc(authHandler.HandleAccountPassword)).ServeHTTP(w, r) + }) + mux.HandleFunc("/account/recovery-key", func(w http.ResponseWriter, r *http.Request) { + middleware.VerifyOrigin(http.HandlerFunc(authHandler.HandleAccountRecoveryKey)).ServeHTTP(w, r) + }) // Watchlist Endpoints mux.HandleFunc("/api/watchlist/export", watchlistHandler.HandleExportWatchlist) diff --git a/internal/shared/middleware/auth.go b/internal/shared/middleware/auth.go index f466c0d..bbcfa07 100644 --- a/internal/shared/middleware/auth.go +++ b/internal/shared/middleware/auth.go @@ -60,7 +60,7 @@ func RequireAuth(next http.Handler) http.Handler { func RequireGlobalAuth(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // Allow unauthenticated access to login, register, search, and static files - if r.URL.Path == "/login" || r.URL.Path == "/register" || strings.HasPrefix(r.URL.Path, "/static/") || + if r.URL.Path == "/login" || r.URL.Path == "/register" || r.URL.Path == "/recover" || strings.HasPrefix(r.URL.Path, "/static/") || r.URL.Path == "/search" || r.URL.Path == "/api/search" || r.URL.Path == "/api/search-quick" || r.URL.Path == "/" { next.ServeHTTP(w, r) diff --git a/internal/shared/middleware/ratelimit.go b/internal/shared/middleware/ratelimit.go index e844327..fff305e 100644 --- a/internal/shared/middleware/ratelimit.go +++ b/internal/shared/middleware/ratelimit.go @@ -1,6 +1,7 @@ package middleware import ( + "fmt" "net/http" "strings" "sync" @@ -69,6 +70,10 @@ func RateLimitAuth(next http.Handler) http.Handler { // If more than 5 attempts within a minute, block if exists && v.attempts > 5 { mu.Unlock() + if strings.HasPrefix(r.URL.Path, "/") { + http.Redirect(w, r, fmt.Sprintf("%s?error=rate_limited", r.URL.Path), http.StatusFound) + return + } http.Error(w, "Too many requests. Please try again later.", http.StatusTooManyRequests) return } diff --git a/migrations/008_add_recovery_key.sql b/migrations/008_add_recovery_key.sql new file mode 100644 index 0000000..1852193 --- /dev/null +++ b/migrations/008_add_recovery_key.sql @@ -0,0 +1,2 @@ +ALTER TABLE user +ADD COLUMN recovery_key_hash TEXT NOT NULL DEFAULT '';