refactor: remove recovery auth surface

This commit is contained in:
2026-04-19 19:40:18 +02:00
parent 4eaa6542ff
commit 1e1a3e8205
11 changed files with 70 additions and 291 deletions

View File

@@ -76,11 +76,10 @@ type Session struct {
}
type User struct {
ID string `json:"id"`
Username string `json:"username"`
PasswordHash string `json:"password_hash"`
CreatedAt time.Time `json:"created_at"`
RecoveryKeyHash string `json:"recovery_key_hash"`
ID string `json:"id"`
Username string `json:"username"`
PasswordHash string `json:"password_hash"`
CreatedAt time.Time `json:"created_at"`
}
type WatchListEntry struct {

View File

@@ -30,7 +30,6 @@ type Querier interface {
GetUpcomingSeasons(ctx context.Context, userID string) ([]GetUpcomingSeasonsRow, error)
GetUser(ctx context.Context, id string) (User, error)
GetUserByUsername(ctx context.Context, username string) (User, error)
GetUserByUsernameAndRecoveryKeyHash(ctx context.Context, arg GetUserByUsernameAndRecoveryKeyHashParams) (User, error)
GetUserWatchList(ctx context.Context, userID string) ([]GetUserWatchListRow, error)
GetWatchListEntry(ctx context.Context, arg GetWatchListEntryParams) (WatchListEntry, error)
GetWatchingAnime(ctx context.Context, userID string) ([]GetWatchingAnimeRow, error)
@@ -39,7 +38,6 @@ type Querier interface {
SaveWatchProgress(ctx context.Context, arg SaveWatchProgressParams) error
SetJikanCache(ctx context.Context, arg SetJikanCacheParams) error
UpdateAnimeStatus(ctx context.Context, arg UpdateAnimeStatusParams) error
UpdateUserPasswordAndRecoveryKeyHash(ctx context.Context, arg UpdateUserPasswordAndRecoveryKeyHashParams) error
UpsertAnime(ctx context.Context, arg UpsertAnimeParams) (Anime, error)
UpsertAnimeRelation(ctx context.Context, arg UpsertAnimeRelationParams) error
UpsertContinueWatchingEntry(ctx context.Context, arg UpsertContinueWatchingEntryParams) (ContinueWatchingEntry, error)

View File

@@ -5,18 +5,10 @@ SELECT * FROM user WHERE id = ? LIMIT 1;
SELECT * FROM user WHERE username = ? LIMIT 1;
-- name: CreateUser :one
INSERT INTO user (id, username, password_hash, recovery_key_hash)
VALUES (?, ?, ?, ?)
INSERT INTO user (id, username, password_hash)
VALUES (?, ?, ?)
RETURNING *;
-- name: GetUserByUsernameAndRecoveryKeyHash :one
SELECT * FROM user WHERE username = ? AND recovery_key_hash = ? LIMIT 1;
-- name: UpdateUserPasswordAndRecoveryKeyHash :exec
UPDATE user
SET password_hash = ?, recovery_key_hash = ?
WHERE id = ?;
-- name: CreateSession :one
INSERT INTO session (id, user_id, expires_at)
VALUES (?, ?, ?)

View File

@@ -49,32 +49,25 @@ func (q *Queries) CreateSession(ctx context.Context, arg CreateSessionParams) (S
}
const createUser = `-- name: CreateUser :one
INSERT INTO user (id, username, password_hash, recovery_key_hash)
VALUES (?, ?, ?, ?)
RETURNING id, username, password_hash, created_at, recovery_key_hash
INSERT INTO user (id, username, password_hash)
VALUES (?, ?, ?)
RETURNING id, username, password_hash, created_at
`
type CreateUserParams struct {
ID string `json:"id"`
Username string `json:"username"`
PasswordHash string `json:"password_hash"`
RecoveryKeyHash string `json:"recovery_key_hash"`
ID string `json:"id"`
Username string `json:"username"`
PasswordHash string `json:"password_hash"`
}
func (q *Queries) CreateUser(ctx context.Context, arg CreateUserParams) (User, error) {
row := q.db.QueryRowContext(ctx, createUser,
arg.ID,
arg.Username,
arg.PasswordHash,
arg.RecoveryKeyHash,
)
row := q.db.QueryRowContext(ctx, createUser, arg.ID, arg.Username, arg.PasswordHash)
var i User
err := row.Scan(
&i.ID,
&i.Username,
&i.PasswordHash,
&i.CreatedAt,
&i.RecoveryKeyHash,
)
return i, err
}
@@ -499,7 +492,7 @@ func (q *Queries) GetUpcomingSeasons(ctx context.Context, userID string) ([]GetU
}
const getUser = `-- name: GetUser :one
SELECT id, username, password_hash, created_at, recovery_key_hash FROM user WHERE id = ? LIMIT 1
SELECT id, username, password_hash, created_at FROM user WHERE id = ? LIMIT 1
`
func (q *Queries) GetUser(ctx context.Context, id string) (User, error) {
@@ -510,13 +503,12 @@ func (q *Queries) GetUser(ctx context.Context, id string) (User, error) {
&i.Username,
&i.PasswordHash,
&i.CreatedAt,
&i.RecoveryKeyHash,
)
return i, err
}
const getUserByUsername = `-- name: GetUserByUsername :one
SELECT id, username, password_hash, created_at, recovery_key_hash FROM user WHERE username = ? LIMIT 1
SELECT id, username, password_hash, created_at FROM user WHERE username = ? LIMIT 1
`
func (q *Queries) GetUserByUsername(ctx context.Context, username string) (User, error) {
@@ -527,29 +519,6 @@ func (q *Queries) GetUserByUsername(ctx context.Context, username string) (User,
&i.Username,
&i.PasswordHash,
&i.CreatedAt,
&i.RecoveryKeyHash,
)
return i, err
}
const getUserByUsernameAndRecoveryKeyHash = `-- name: GetUserByUsernameAndRecoveryKeyHash :one
SELECT id, username, password_hash, created_at, recovery_key_hash FROM user WHERE username = ? AND recovery_key_hash = ? LIMIT 1
`
type GetUserByUsernameAndRecoveryKeyHashParams struct {
Username string `json:"username"`
RecoveryKeyHash string `json:"recovery_key_hash"`
}
func (q *Queries) GetUserByUsernameAndRecoveryKeyHash(ctx context.Context, arg GetUserByUsernameAndRecoveryKeyHashParams) (User, error) {
row := q.db.QueryRowContext(ctx, getUserByUsernameAndRecoveryKeyHash, arg.Username, arg.RecoveryKeyHash)
var i User
err := row.Scan(
&i.ID,
&i.Username,
&i.PasswordHash,
&i.CreatedAt,
&i.RecoveryKeyHash,
)
return i, err
}
@@ -807,23 +776,6 @@ func (q *Queries) UpdateAnimeStatus(ctx context.Context, arg UpdateAnimeStatusPa
return err
}
const updateUserPasswordAndRecoveryKeyHash = `-- name: UpdateUserPasswordAndRecoveryKeyHash :exec
UPDATE user
SET password_hash = ?, recovery_key_hash = ?
WHERE id = ?
`
type UpdateUserPasswordAndRecoveryKeyHashParams struct {
PasswordHash string `json:"password_hash"`
RecoveryKeyHash string `json:"recovery_key_hash"`
ID string `json:"id"`
}
func (q *Queries) UpdateUserPasswordAndRecoveryKeyHash(ctx context.Context, arg UpdateUserPasswordAndRecoveryKeyHashParams) error {
_, err := q.db.ExecContext(ctx, updateUserPasswordAndRecoveryKeyHash, arg.PasswordHash, arg.RecoveryKeyHash, arg.ID)
return err
}
const upsertAnime = `-- name: UpsertAnime :one
INSERT INTO anime (id, title_original, title_english, title_japanese, image_url, airing)
VALUES (?, ?, ?, ?, ?, ?)

View File

@@ -3,18 +3,14 @@ package auth
import (
"context"
"crypto/rand"
"crypto/sha256"
"database/sql"
"encoding/base64"
"encoding/hex"
"errors"
"fmt"
"net/http"
"os"
"time"
"unicode"
"github.com/google/uuid"
"golang.org/x/crypto/bcrypt"
"mal/internal/database"
@@ -22,10 +18,7 @@ import (
var (
ErrInvalidCredentials = errors.New("invalid username or password")
ErrUserExists = errors.New("username already exists")
ErrNotAuthenticated = errors.New("not authenticated")
ErrInvalidPassword = errors.New("password does not meet security requirements")
ErrInvalidRecoveryKey = errors.New("invalid recovery details")
)
const bcryptCost = 12
@@ -50,171 +43,6 @@ func generateSessionToken() (string, error) {
return generateToken(32)
}
func generateRecoveryKey() (string, error) {
return generateToken(24)
}
func hashRecoveryKey(recoveryKey string) string {
sum := sha256.Sum256([]byte(recoveryKey))
return hex.EncodeToString(sum[:])
}
func ValidatePassword(password string) error {
if len(password) < 12 {
return fmt.Errorf("password must be at least 12 characters long")
}
var hasUpper, hasLower, hasNumber, hasSpecial bool
for _, c := range password {
switch {
case unicode.IsNumber(c):
hasNumber = true
case unicode.IsUpper(c):
hasUpper = true
case unicode.IsLower(c):
hasLower = true
case unicode.IsPunct(c) || unicode.IsSymbol(c) || !unicode.IsLetter(c) && !unicode.IsNumber(c):
hasSpecial = true
}
}
if !hasUpper || !hasLower || !hasNumber || !hasSpecial {
return fmt.Errorf("password must contain at least one uppercase letter, one lowercase letter, one number, and one special character")
}
return nil
}
func (s *Service) RegisterUser(ctx context.Context, username, password string) (*database.User, string, error) {
if err := ValidatePassword(password); err != nil {
return nil, "", fmt.Errorf("%w: %v", ErrInvalidPassword, err)
}
hash, err := bcrypt.GenerateFromPassword([]byte(password), bcryptCost)
if err != nil {
return nil, "", fmt.Errorf("failed to hash password: %w", err)
}
recoveryKey, err := generateRecoveryKey()
if err != nil {
return nil, "", fmt.Errorf("failed to generate recovery key: %w", err)
}
id := uuid.New().String()
user, err := s.db.CreateUser(ctx, database.CreateUserParams{
ID: id,
Username: username,
PasswordHash: string(hash),
RecoveryKeyHash: hashRecoveryKey(recoveryKey),
})
if err != nil {
// Assuming unique constraint failure for username
return nil, "", ErrUserExists
}
return &user, recoveryKey, nil
}
func (s *Service) RecoverAccount(ctx context.Context, username, recoveryKey, newPassword string) (string, error) {
if err := ValidatePassword(newPassword); err != nil {
return "", fmt.Errorf("%w: %v", ErrInvalidPassword, err)
}
user, err := s.db.GetUserByUsernameAndRecoveryKeyHash(ctx, database.GetUserByUsernameAndRecoveryKeyHashParams{
Username: username,
RecoveryKeyHash: hashRecoveryKey(recoveryKey),
})
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return "", ErrInvalidRecoveryKey
}
return "", fmt.Errorf("failed to lookup user for recovery: %w", err)
}
newPasswordHash, err := bcrypt.GenerateFromPassword([]byte(newPassword), bcryptCost)
if err != nil {
return "", fmt.Errorf("failed to hash new password: %w", err)
}
newRecoveryKey, err := generateRecoveryKey()
if err != nil {
return "", fmt.Errorf("failed to generate new recovery key: %w", err)
}
err = s.db.UpdateUserPasswordAndRecoveryKeyHash(ctx, database.UpdateUserPasswordAndRecoveryKeyHashParams{
ID: user.ID,
PasswordHash: string(newPasswordHash),
RecoveryKeyHash: hashRecoveryKey(newRecoveryKey),
})
if err != nil {
return "", fmt.Errorf("failed to update recovered account: %w", err)
}
err = s.db.DeleteUserSessions(ctx, user.ID)
if err != nil {
return "", fmt.Errorf("failed to clear existing sessions: %w", err)
}
return newRecoveryKey, nil
}
func (s *Service) ChangePassword(ctx context.Context, userID, currentPassword, newPassword string) error {
if err := ValidatePassword(newPassword); err != nil {
return fmt.Errorf("%w: %v", ErrInvalidPassword, err)
}
user, err := s.db.GetUser(ctx, userID)
if err != nil {
return fmt.Errorf("failed to lookup user: %w", err)
}
if err := bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(currentPassword)); err != nil {
return ErrInvalidCredentials
}
newPasswordHash, err := bcrypt.GenerateFromPassword([]byte(newPassword), bcryptCost)
if err != nil {
return fmt.Errorf("failed to hash new password: %w", err)
}
err = s.db.UpdateUserPasswordAndRecoveryKeyHash(ctx, database.UpdateUserPasswordAndRecoveryKeyHashParams{
ID: user.ID,
PasswordHash: string(newPasswordHash),
RecoveryKeyHash: user.RecoveryKeyHash,
})
if err != nil {
return fmt.Errorf("failed to update password: %w", err)
}
return nil
}
func (s *Service) RegenerateRecoveryKey(ctx context.Context, userID, password string) (string, error) {
user, err := s.db.GetUser(ctx, userID)
if err != nil {
return "", fmt.Errorf("failed to lookup user: %w", err)
}
if err := bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(password)); err != nil {
return "", ErrInvalidCredentials
}
newRecoveryKey, err := generateRecoveryKey()
if err != nil {
return "", fmt.Errorf("failed to generate new recovery key: %w", err)
}
err = s.db.UpdateUserPasswordAndRecoveryKeyHash(ctx, database.UpdateUserPasswordAndRecoveryKeyHashParams{
ID: user.ID,
PasswordHash: user.PasswordHash,
RecoveryKeyHash: hashRecoveryKey(newRecoveryKey),
})
if err != nil {
return "", fmt.Errorf("failed to rotate recovery key: %w", err)
}
return newRecoveryKey, nil
}
func (s *Service) Login(ctx context.Context, username, password string) (*database.Session, error) {
user, err := s.db.GetUserByUsername(ctx, username)
if err != nil {
@@ -246,10 +74,6 @@ func (s *Service) Login(ctx context.Context, username, password string) (*databa
return &session, nil
}
func (s *Service) Logout(ctx context.Context, sessionID string) error {
return s.db.DeleteSession(ctx, sessionID)
}
func (s *Service) ValidateSession(ctx context.Context, sessionID string) (*database.User, error) {
session, err := s.db.GetSession(ctx, sessionID)
if err != nil {
@@ -284,16 +108,3 @@ func SetSessionCookie(w http.ResponseWriter, sessionID string, expiresAt time.Ti
Path: "/",
})
}
func ClearSessionCookie(w http.ResponseWriter) {
isProd := os.Getenv("ENV") == "production"
http.SetCookie(w, &http.Cookie{
Name: "session_id",
Value: "",
Expires: time.Unix(0, 0),
HttpOnly: true,
Secure: isProd,
SameSite: http.SameSiteStrictMode,
Path: "/",
})
}

View File

@@ -51,17 +51,6 @@ func (h *Handler) HandleLogin(w http.ResponseWriter, r *http.Request) {
http.Redirect(w, r, "/", http.StatusFound)
}
func (h *Handler) HandleLogout(w http.ResponseWriter, r *http.Request) {
cookie, err := r.Cookie("session_id")
if err == nil {
_ = h.authService.Logout(r.Context(), cookie.Value)
}
ClearSessionCookie(w)
w.Header().Set("HX-Redirect", "/")
http.Redirect(w, r, "/", http.StatusFound)
}
func (h *Handler) HandleLoginPage(w http.ResponseWriter, r *http.Request) {
templates.Login(rateLimitErrorFromQuery(r), "").Render(r.Context(), w)
}

View File

@@ -71,9 +71,6 @@ func NewRouter(cfg Config) http.Handler {
middleware.RateLimitAuth(middleware.VerifyOrigin(http.HandlerFunc(authHandler.HandleLogin))).ServeHTTP(w, r)
}
})
mux.HandleFunc("/logout", func(w http.ResponseWriter, r *http.Request) {
middleware.VerifyOrigin(http.HandlerFunc(authHandler.HandleLogout)).ServeHTTP(w, r)
})
// Watchlist Endpoints
mux.HandleFunc("/api/watchlist/export", watchlistHandler.HandleExportWatchlist)

View File

@@ -28,8 +28,6 @@ func Auth(authService *auth.Service) func(http.Handler) http.Handler {
user, err := authService.ValidateSession(r.Context(), cookie.Value)
if err != nil {
// Invalid session, proceed as unauthenticated
// Might also want to clear the invalid cookie here
auth.ClearSessionCookie(w)
next.ServeHTTP(w, r)
return
}