auth: add recovery and account security

This commit is contained in:
2026-04-11 18:05:51 +02:00
parent 810a50c606
commit 6b83f6bde6
11 changed files with 424 additions and 48 deletions

View File

@@ -10,10 +10,10 @@ import (
)
type DBTX interface {
ExecContext(context.Context, string, ...any) (sql.Result, error)
ExecContext(context.Context, string, ...interface{}) (sql.Result, error)
PrepareContext(context.Context, string) (*sql.Stmt, error)
QueryContext(context.Context, string, ...any) (*sql.Rows, error)
QueryRowContext(context.Context, string, ...any) *sql.Row
QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error)
QueryRowContext(context.Context, string, ...interface{}) *sql.Row
}
func New(db DBTX) *Queries {

View File

@@ -57,10 +57,11 @@ type Session struct {
}
type User struct {
ID string `json:"id"`
Username string `json:"username"`
PasswordHash string `json:"password_hash"`
CreatedAt time.Time `json:"created_at"`
ID string `json:"id"`
Username string `json:"username"`
PasswordHash string `json:"password_hash"`
CreatedAt time.Time `json:"created_at"`
RecoveryKeyHash string `json:"recovery_key_hash"`
}
type WatchListEntry struct {

View File

@@ -22,12 +22,14 @@ type Querier interface {
GetUpcomingSeasons(ctx context.Context, userID string) ([]GetUpcomingSeasonsRow, error)
GetUser(ctx context.Context, id string) (User, error)
GetUserByUsername(ctx context.Context, username string) (User, error)
GetUserByUsernameAndRecoveryKeyHash(ctx context.Context, arg GetUserByUsernameAndRecoveryKeyHashParams) (User, error)
GetUserWatchList(ctx context.Context, userID string) ([]GetUserWatchListRow, error)
GetWatchListEntry(ctx context.Context, arg GetWatchListEntryParams) (WatchListEntry, error)
GetWatchingAnime(ctx context.Context, userID string) ([]GetWatchingAnimeRow, error)
MarkRelationsSynced(ctx context.Context, id int64) error
SetJikanCache(ctx context.Context, arg SetJikanCacheParams) error
UpdateAnimeStatus(ctx context.Context, arg UpdateAnimeStatusParams) error
UpdateUserPasswordAndRecoveryKeyHash(ctx context.Context, arg UpdateUserPasswordAndRecoveryKeyHashParams) error
UpsertAnime(ctx context.Context, arg UpsertAnimeParams) (Anime, error)
UpsertAnimeRelation(ctx context.Context, arg UpsertAnimeRelationParams) error
UpsertWatchListEntry(ctx context.Context, arg UpsertWatchListEntryParams) (WatchListEntry, error)

View File

@@ -5,10 +5,18 @@ SELECT * FROM user WHERE id = ? LIMIT 1;
SELECT * FROM user WHERE username = ? LIMIT 1;
-- name: CreateUser :one
INSERT INTO user (id, username, password_hash)
VALUES (?, ?, ?)
INSERT INTO user (id, username, password_hash, recovery_key_hash)
VALUES (?, ?, ?, ?)
RETURNING *;
-- name: GetUserByUsernameAndRecoveryKeyHash :one
SELECT * FROM user WHERE username = ? AND recovery_key_hash = ? LIMIT 1;
-- name: UpdateUserPasswordAndRecoveryKeyHash :exec
UPDATE user
SET password_hash = ?, recovery_key_hash = ?
WHERE id = ?;
-- name: CreateSession :one
INSERT INTO session (id, user_id, expires_at)
VALUES (?, ?, ?)

View File

@@ -36,25 +36,32 @@ func (q *Queries) CreateSession(ctx context.Context, arg CreateSessionParams) (S
}
const createUser = `-- name: CreateUser :one
INSERT INTO user (id, username, password_hash)
VALUES (?, ?, ?)
RETURNING id, username, password_hash, created_at
INSERT INTO user (id, username, password_hash, recovery_key_hash)
VALUES (?, ?, ?, ?)
RETURNING id, username, password_hash, created_at, recovery_key_hash
`
type CreateUserParams struct {
ID string `json:"id"`
Username string `json:"username"`
PasswordHash string `json:"password_hash"`
ID string `json:"id"`
Username string `json:"username"`
PasswordHash string `json:"password_hash"`
RecoveryKeyHash string `json:"recovery_key_hash"`
}
func (q *Queries) CreateUser(ctx context.Context, arg CreateUserParams) (User, error) {
row := q.db.QueryRowContext(ctx, createUser, arg.ID, arg.Username, arg.PasswordHash)
row := q.db.QueryRowContext(ctx, createUser,
arg.ID,
arg.Username,
arg.PasswordHash,
arg.RecoveryKeyHash,
)
var i User
err := row.Scan(
&i.ID,
&i.Username,
&i.PasswordHash,
&i.CreatedAt,
&i.RecoveryKeyHash,
)
return i, err
}
@@ -289,7 +296,7 @@ func (q *Queries) GetUpcomingSeasons(ctx context.Context, userID string) ([]GetU
}
const getUser = `-- name: GetUser :one
SELECT id, username, password_hash, created_at FROM user WHERE id = ? LIMIT 1
SELECT id, username, password_hash, created_at, recovery_key_hash FROM user WHERE id = ? LIMIT 1
`
func (q *Queries) GetUser(ctx context.Context, id string) (User, error) {
@@ -300,12 +307,13 @@ func (q *Queries) GetUser(ctx context.Context, id string) (User, error) {
&i.Username,
&i.PasswordHash,
&i.CreatedAt,
&i.RecoveryKeyHash,
)
return i, err
}
const getUserByUsername = `-- name: GetUserByUsername :one
SELECT id, username, password_hash, created_at FROM user WHERE username = ? LIMIT 1
SELECT id, username, password_hash, created_at, recovery_key_hash FROM user WHERE username = ? LIMIT 1
`
func (q *Queries) GetUserByUsername(ctx context.Context, username string) (User, error) {
@@ -316,6 +324,29 @@ func (q *Queries) GetUserByUsername(ctx context.Context, username string) (User,
&i.Username,
&i.PasswordHash,
&i.CreatedAt,
&i.RecoveryKeyHash,
)
return i, err
}
const getUserByUsernameAndRecoveryKeyHash = `-- name: GetUserByUsernameAndRecoveryKeyHash :one
SELECT id, username, password_hash, created_at, recovery_key_hash FROM user WHERE username = ? AND recovery_key_hash = ? LIMIT 1
`
type GetUserByUsernameAndRecoveryKeyHashParams struct {
Username string `json:"username"`
RecoveryKeyHash string `json:"recovery_key_hash"`
}
func (q *Queries) GetUserByUsernameAndRecoveryKeyHash(ctx context.Context, arg GetUserByUsernameAndRecoveryKeyHashParams) (User, error) {
row := q.db.QueryRowContext(ctx, getUserByUsernameAndRecoveryKeyHash, arg.Username, arg.RecoveryKeyHash)
var i User
err := row.Scan(
&i.ID,
&i.Username,
&i.PasswordHash,
&i.CreatedAt,
&i.RecoveryKeyHash,
)
return i, err
}
@@ -523,6 +554,23 @@ func (q *Queries) UpdateAnimeStatus(ctx context.Context, arg UpdateAnimeStatusPa
return err
}
const updateUserPasswordAndRecoveryKeyHash = `-- name: UpdateUserPasswordAndRecoveryKeyHash :exec
UPDATE user
SET password_hash = ?, recovery_key_hash = ?
WHERE id = ?
`
type UpdateUserPasswordAndRecoveryKeyHashParams struct {
PasswordHash string `json:"password_hash"`
RecoveryKeyHash string `json:"recovery_key_hash"`
ID string `json:"id"`
}
func (q *Queries) UpdateUserPasswordAndRecoveryKeyHash(ctx context.Context, arg UpdateUserPasswordAndRecoveryKeyHashParams) error {
_, err := q.db.ExecContext(ctx, updateUserPasswordAndRecoveryKeyHash, arg.PasswordHash, arg.RecoveryKeyHash, arg.ID)
return err
}
const upsertAnime = `-- name: UpsertAnime :one
INSERT INTO anime (id, title_original, title_english, title_japanese, image_url, airing)
VALUES (?, ?, ?, ?, ?, ?)

View File

@@ -3,8 +3,10 @@ package auth
import (
"context"
"crypto/rand"
"crypto/sha256"
"database/sql"
"encoding/base64"
"encoding/hex"
"errors"
"fmt"
"net/http"
@@ -23,6 +25,7 @@ var (
ErrUserExists = errors.New("username already exists")
ErrNotAuthenticated = errors.New("not authenticated")
ErrInvalidPassword = errors.New("password does not meet security requirements")
ErrInvalidRecoveryKey = errors.New("invalid recovery details")
)
type Service struct {
@@ -33,15 +36,27 @@ func NewService(db database.Querier) *Service {
return &Service{db: db}
}
// generateSessionToken generates a secure random 32-byte token.
func generateSessionToken() (string, error) {
b := make([]byte, 32)
func generateToken(size int) (string, error) {
b := make([]byte, size)
if _, err := rand.Read(b); err != nil {
return "", err
}
return base64.URLEncoding.EncodeToString(b), nil
}
func generateSessionToken() (string, error) {
return generateToken(32)
}
func generateRecoveryKey() (string, error) {
return generateToken(24)
}
func hashRecoveryKey(recoveryKey string) string {
sum := sha256.Sum256([]byte(recoveryKey))
return hex.EncodeToString(sum[:])
}
func ValidatePassword(password string) error {
if len(password) < 12 {
return fmt.Errorf("password must be at least 12 characters long")
@@ -67,28 +82,135 @@ func ValidatePassword(password string) error {
return nil
}
func (s *Service) RegisterUser(ctx context.Context, username, password string) (*database.User, error) {
func (s *Service) RegisterUser(ctx context.Context, username, password string) (*database.User, string, error) {
if err := ValidatePassword(password); err != nil {
return nil, fmt.Errorf("%w: %v", ErrInvalidPassword, err)
return nil, "", fmt.Errorf("%w: %v", ErrInvalidPassword, err)
}
hash, err := bcrypt.GenerateFromPassword([]byte(password), 12) // higher cost
if err != nil {
return nil, fmt.Errorf("failed to hash password: %w", err)
return nil, "", fmt.Errorf("failed to hash password: %w", err)
}
recoveryKey, err := generateRecoveryKey()
if err != nil {
return nil, "", fmt.Errorf("failed to generate recovery key: %w", err)
}
id := uuid.New().String()
user, err := s.db.CreateUser(ctx, database.CreateUserParams{
ID: id,
Username: username,
PasswordHash: string(hash),
ID: id,
Username: username,
PasswordHash: string(hash),
RecoveryKeyHash: hashRecoveryKey(recoveryKey),
})
if err != nil {
// Assuming unique constraint failure for username
return nil, ErrUserExists
return nil, "", ErrUserExists
}
return &user, nil
return &user, recoveryKey, nil
}
func (s *Service) RecoverAccount(ctx context.Context, username, recoveryKey, newPassword string) (string, error) {
if err := ValidatePassword(newPassword); err != nil {
return "", fmt.Errorf("%w: %v", ErrInvalidPassword, err)
}
user, err := s.db.GetUserByUsernameAndRecoveryKeyHash(ctx, database.GetUserByUsernameAndRecoveryKeyHashParams{
Username: username,
RecoveryKeyHash: hashRecoveryKey(recoveryKey),
})
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return "", ErrInvalidRecoveryKey
}
return "", fmt.Errorf("failed to lookup user for recovery: %w", err)
}
newPasswordHash, err := bcrypt.GenerateFromPassword([]byte(newPassword), 12)
if err != nil {
return "", fmt.Errorf("failed to hash new password: %w", err)
}
newRecoveryKey, err := generateRecoveryKey()
if err != nil {
return "", fmt.Errorf("failed to generate new recovery key: %w", err)
}
err = s.db.UpdateUserPasswordAndRecoveryKeyHash(ctx, database.UpdateUserPasswordAndRecoveryKeyHashParams{
ID: user.ID,
PasswordHash: string(newPasswordHash),
RecoveryKeyHash: hashRecoveryKey(newRecoveryKey),
})
if err != nil {
return "", fmt.Errorf("failed to update recovered account: %w", err)
}
err = s.db.DeleteUserSessions(ctx, user.ID)
if err != nil {
return "", fmt.Errorf("failed to clear existing sessions: %w", err)
}
return newRecoveryKey, nil
}
func (s *Service) ChangePassword(ctx context.Context, userID, currentPassword, newPassword string) error {
if err := ValidatePassword(newPassword); err != nil {
return fmt.Errorf("%w: %v", ErrInvalidPassword, err)
}
user, err := s.db.GetUser(ctx, userID)
if err != nil {
return fmt.Errorf("failed to lookup user: %w", err)
}
if err := bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(currentPassword)); err != nil {
return ErrInvalidCredentials
}
newPasswordHash, err := bcrypt.GenerateFromPassword([]byte(newPassword), 12)
if err != nil {
return fmt.Errorf("failed to hash new password: %w", err)
}
err = s.db.UpdateUserPasswordAndRecoveryKeyHash(ctx, database.UpdateUserPasswordAndRecoveryKeyHashParams{
ID: user.ID,
PasswordHash: string(newPasswordHash),
RecoveryKeyHash: user.RecoveryKeyHash,
})
if err != nil {
return fmt.Errorf("failed to update password: %w", err)
}
return nil
}
func (s *Service) RegenerateRecoveryKey(ctx context.Context, userID, password string) (string, error) {
user, err := s.db.GetUser(ctx, userID)
if err != nil {
return "", fmt.Errorf("failed to lookup user: %w", err)
}
if err := bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(password)); err != nil {
return "", ErrInvalidCredentials
}
newRecoveryKey, err := generateRecoveryKey()
if err != nil {
return "", fmt.Errorf("failed to generate new recovery key: %w", err)
}
err = s.db.UpdateUserPasswordAndRecoveryKeyHash(ctx, database.UpdateUserPasswordAndRecoveryKeyHashParams{
ID: user.ID,
PasswordHash: user.PasswordHash,
RecoveryKeyHash: hashRecoveryKey(newRecoveryKey),
})
if err != nil {
return "", fmt.Errorf("failed to rotate recovery key: %w", err)
}
return newRecoveryKey, nil
}
func (s *Service) Login(ctx context.Context, username, password string) (*database.Session, error) {

View File

@@ -1,8 +1,11 @@
package auth
import (
"errors"
"net/http"
"time"
"mal/internal/database"
"mal/internal/templates"
)
@@ -10,6 +13,40 @@ type Handler struct {
authService *Service
}
const rateLimitFormError = "Too many attempts in a short time. Please wait a minute and try again."
const (
accountPasswordChangedMessage = "Password updated successfully."
accountRecoveryKeyRotatedMessage = "Recovery key rotated. Save this new key now."
accountPasswordErrorMessage = "Unable to update password with those details."
accountRecoveryErrorMessage = "Unable to rotate recovery key with those details."
accountUnexpectedErrorMessage = "Something went wrong. Please try again."
accountMissingFieldsErrorMessage = "Please complete all required fields."
accountPasswordMismatchErrorMessage = "New password and confirm password must match."
)
func (h *Handler) accountUserFromRequest(r *http.Request) (*database.User, bool) {
cookie, err := r.Cookie("session_id")
if err != nil {
return nil, false
}
user, err := h.authService.ValidateSession(r.Context(), cookie.Value)
if err != nil {
return nil, false
}
return user, true
}
func accountCreatedAt(createdAt time.Time) string {
return createdAt.Local().Format("Jan 2, 2006 at 15:04")
}
func renderAccountPage(w http.ResponseWriter, r *http.Request, user *database.User, passwordError string, passwordSuccess string, recoveryError string, recoverySuccess string, recoveryKey string) {
templates.Account(user.Username, accountCreatedAt(user.CreatedAt), passwordError, passwordSuccess, recoveryError, recoverySuccess, recoveryKey).Render(r.Context(), w)
}
func NewHandler(authService *Service) *Handler {
return &Handler{authService: authService}
}
@@ -18,17 +55,21 @@ func NewHandler(authService *Service) *Handler {
func (h *Handler) HandleLogin(w http.ResponseWriter, r *http.Request) {
if err := r.ParseForm(); err != nil {
http.Error(w, "invalid request", http.StatusBadRequest)
templates.Login("Something went wrong. Please try again.", "").Render(r.Context(), w)
return
}
username := r.FormValue("username")
password := r.FormValue("password")
if username == "" || password == "" {
templates.Login("The email or password is wrong.", username).Render(r.Context(), w)
return
}
session, err := h.authService.Login(r.Context(), username, password)
if err != nil {
// Just handle generically for now, perhaps via HTMX toast
http.Error(w, "invalid credentials", http.StatusUnauthorized)
templates.Login("The email or password is wrong.", username).Render(r.Context(), w)
return
}
@@ -41,7 +82,7 @@ func (h *Handler) HandleLogin(w http.ResponseWriter, r *http.Request) {
func (h *Handler) HandleRegister(w http.ResponseWriter, r *http.Request) {
if err := r.ParseForm(); err != nil {
http.Error(w, "invalid request", http.StatusBadRequest)
templates.Register("Something went wrong. Please try again.", "").Render(r.Context(), w)
return
}
@@ -49,21 +90,17 @@ func (h *Handler) HandleRegister(w http.ResponseWriter, r *http.Request) {
password := r.FormValue("password")
if username == "" || password == "" {
http.Error(w, "username and password are required", http.StatusBadRequest)
templates.Register("Please enter both email and password.", username).Render(r.Context(), w)
return
}
_, err := h.authService.RegisterUser(r.Context(), username, password)
_, recoveryKey, err := h.authService.RegisterUser(r.Context(), username, password)
if err != nil {
if err == ErrInvalidPassword {
http.Error(w, err.Error(), http.StatusBadRequest)
if errors.Is(err, ErrInvalidPassword) || errors.Is(err, ErrUserExists) {
templates.Register("Unable to create account with those details.", username).Render(r.Context(), w)
return
}
if err == ErrUserExists {
http.Error(w, "username already taken", http.StatusConflict)
return
}
http.Error(w, "registration failed", http.StatusInternalServerError)
templates.Register("Something went wrong. Please try again.", username).Render(r.Context(), w)
return
}
@@ -75,9 +112,7 @@ func (h *Handler) HandleRegister(w http.ResponseWriter, r *http.Request) {
}
SetSessionCookie(w, session.ID, session.ExpiresAt)
w.Header().Set("HX-Redirect", "/")
http.Redirect(w, r, "/", http.StatusFound)
templates.RegistrationRecoveryKey(recoveryKey).Render(r.Context(), w)
}
func (h *Handler) HandleLogout(w http.ResponseWriter, r *http.Request) {
@@ -92,9 +127,148 @@ func (h *Handler) HandleLogout(w http.ResponseWriter, r *http.Request) {
}
func (h *Handler) HandleLoginPage(w http.ResponseWriter, r *http.Request) {
templates.Login().Render(r.Context(), w)
formError := ""
if r.URL.Query().Get("error") == "rate_limited" {
formError = rateLimitFormError
}
templates.Login(formError, "").Render(r.Context(), w)
}
func (h *Handler) HandleRegisterPage(w http.ResponseWriter, r *http.Request) {
templates.Register().Render(r.Context(), w)
formError := ""
if r.URL.Query().Get("error") == "rate_limited" {
formError = rateLimitFormError
}
templates.Register(formError, "").Render(r.Context(), w)
}
func (h *Handler) HandleRecoverPage(w http.ResponseWriter, r *http.Request) {
formError := ""
if r.URL.Query().Get("error") == "rate_limited" {
formError = rateLimitFormError
}
templates.Recover(formError, "", "").Render(r.Context(), w)
}
func (h *Handler) HandleRecover(w http.ResponseWriter, r *http.Request) {
if err := r.ParseForm(); err != nil {
templates.Recover("Something went wrong. Please try again.", "", "").Render(r.Context(), w)
return
}
username := r.FormValue("username")
recoveryKey := r.FormValue("recovery_key")
newPassword := r.FormValue("new_password")
if username == "" || recoveryKey == "" || newPassword == "" {
templates.Recover("Unable to recover account with those details.", username, recoveryKey).Render(r.Context(), w)
return
}
newRecoveryKey, err := h.authService.RecoverAccount(r.Context(), username, recoveryKey, newPassword)
if err != nil {
if errors.Is(err, ErrInvalidRecoveryKey) || errors.Is(err, ErrInvalidPassword) {
templates.Recover("Unable to recover account with those details.", username, recoveryKey).Render(r.Context(), w)
return
}
templates.Recover("Something went wrong. Please try again.", username, recoveryKey).Render(r.Context(), w)
return
}
templates.RecoveryComplete(newRecoveryKey).Render(r.Context(), w)
}
func (h *Handler) HandleAccountPage(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
user, ok := h.accountUserFromRequest(r)
if !ok {
http.Redirect(w, r, "/login", http.StatusFound)
return
}
renderAccountPage(w, r, user, "", "", "", "", "")
}
func (h *Handler) HandleAccountPassword(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
user, ok := h.accountUserFromRequest(r)
if !ok {
http.Redirect(w, r, "/login", http.StatusFound)
return
}
if err := r.ParseForm(); err != nil {
renderAccountPage(w, r, user, accountUnexpectedErrorMessage, "", "", "", "")
return
}
currentPassword := r.FormValue("current_password")
newPassword := r.FormValue("new_password")
confirmNewPassword := r.FormValue("confirm_new_password")
if currentPassword == "" || newPassword == "" || confirmNewPassword == "" {
renderAccountPage(w, r, user, accountMissingFieldsErrorMessage, "", "", "", "")
return
}
if newPassword != confirmNewPassword {
renderAccountPage(w, r, user, accountPasswordMismatchErrorMessage, "", "", "", "")
return
}
err := h.authService.ChangePassword(r.Context(), user.ID, currentPassword, newPassword)
if err != nil {
if errors.Is(err, ErrInvalidCredentials) || errors.Is(err, ErrInvalidPassword) {
renderAccountPage(w, r, user, accountPasswordErrorMessage, "", "", "", "")
return
}
renderAccountPage(w, r, user, accountUnexpectedErrorMessage, "", "", "", "")
return
}
renderAccountPage(w, r, user, "", accountPasswordChangedMessage, "", "", "")
}
func (h *Handler) HandleAccountRecoveryKey(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
user, ok := h.accountUserFromRequest(r)
if !ok {
http.Redirect(w, r, "/login", http.StatusFound)
return
}
if err := r.ParseForm(); err != nil {
renderAccountPage(w, r, user, "", "", accountUnexpectedErrorMessage, "", "")
return
}
password := r.FormValue("password")
if password == "" {
renderAccountPage(w, r, user, "", "", accountMissingFieldsErrorMessage, "", "")
return
}
newRecoveryKey, err := h.authService.RegenerateRecoveryKey(r.Context(), user.ID, password)
if err != nil {
if errors.Is(err, ErrInvalidCredentials) {
renderAccountPage(w, r, user, "", "", accountRecoveryErrorMessage, "", "")
return
}
renderAccountPage(w, r, user, "", "", accountUnexpectedErrorMessage, "", "")
return
}
renderAccountPage(w, r, user, "", "", "", accountRecoveryKeyRotatedMessage, newRecoveryKey)
}

View File

@@ -60,9 +60,23 @@ func NewRouter(cfg Config) http.Handler {
middleware.RateLimitAuth(middleware.VerifyOrigin(http.HandlerFunc(authHandler.HandleRegister))).ServeHTTP(w, r)
}
})
mux.HandleFunc("/recover", func(w http.ResponseWriter, r *http.Request) {
if r.Method == http.MethodGet {
authHandler.HandleRecoverPage(w, r)
} else {
middleware.RateLimitAuth(middleware.VerifyOrigin(http.HandlerFunc(authHandler.HandleRecover))).ServeHTTP(w, r)
}
})
mux.HandleFunc("/logout", func(w http.ResponseWriter, r *http.Request) {
middleware.VerifyOrigin(http.HandlerFunc(authHandler.HandleLogout)).ServeHTTP(w, r)
})
mux.HandleFunc("/account", authHandler.HandleAccountPage)
mux.HandleFunc("/account/password", func(w http.ResponseWriter, r *http.Request) {
middleware.VerifyOrigin(http.HandlerFunc(authHandler.HandleAccountPassword)).ServeHTTP(w, r)
})
mux.HandleFunc("/account/recovery-key", func(w http.ResponseWriter, r *http.Request) {
middleware.VerifyOrigin(http.HandlerFunc(authHandler.HandleAccountRecoveryKey)).ServeHTTP(w, r)
})
// Watchlist Endpoints
mux.HandleFunc("/api/watchlist/export", watchlistHandler.HandleExportWatchlist)

View File

@@ -60,7 +60,7 @@ func RequireAuth(next http.Handler) http.Handler {
func RequireGlobalAuth(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Allow unauthenticated access to login, register, search, and static files
if r.URL.Path == "/login" || r.URL.Path == "/register" || strings.HasPrefix(r.URL.Path, "/static/") ||
if r.URL.Path == "/login" || r.URL.Path == "/register" || r.URL.Path == "/recover" || strings.HasPrefix(r.URL.Path, "/static/") ||
r.URL.Path == "/search" || r.URL.Path == "/api/search" || r.URL.Path == "/api/search-quick" ||
r.URL.Path == "/" {
next.ServeHTTP(w, r)

View File

@@ -1,6 +1,7 @@
package middleware
import (
"fmt"
"net/http"
"strings"
"sync"
@@ -69,6 +70,10 @@ func RateLimitAuth(next http.Handler) http.Handler {
// If more than 5 attempts within a minute, block
if exists && v.attempts > 5 {
mu.Unlock()
if strings.HasPrefix(r.URL.Path, "/") {
http.Redirect(w, r, fmt.Sprintf("%s?error=rate_limited", r.URL.Path), http.StatusFound)
return
}
http.Error(w, "Too many requests. Please try again later.", http.StatusTooManyRequests)
return
}

View File

@@ -0,0 +1,2 @@
ALTER TABLE user
ADD COLUMN recovery_key_hash TEXT NOT NULL DEFAULT '';