From c5c15cdabca88da726a2c7ed619045a20f64a871 Mon Sep 17 00:00:00 2001 From: mkelvers Date: Thu, 21 May 2026 19:09:49 +0200 Subject: [PATCH] fix: rolling session renewal --- internal/auth/handler/handler.go | 3 +-- internal/auth/middleware/middleware.go | 16 ++++++++++++++-- internal/auth/repository/repository.go | 9 ++++++++- internal/auth/service/service.go | 8 ++++++++ internal/db/queries.sql | 5 +++++ internal/db/queries.sql.go | 16 ++++++++++++++++ internal/domain/auth.go | 5 +++++ 7 files changed, 57 insertions(+), 5 deletions(-) diff --git a/internal/auth/handler/handler.go b/internal/auth/handler/handler.go index e3d0c35..a26356f 100644 --- a/internal/auth/handler/handler.go +++ b/internal/auth/handler/handler.go @@ -3,7 +3,6 @@ package handler import ( "mal/internal/domain" "net/http" - "time" "github.com/gin-gonic/gin" ) @@ -42,7 +41,7 @@ func (h *AuthHandler) HandleLogin(c *gin.Context) { return } - c.SetCookie("session_id", session.ID, int(24*time.Hour.Seconds()), "/", "", false, true) + c.SetCookie("session_id", session.ID, int(domain.SessionLifetime.Seconds()), "/", "", false, true) if c.GetHeader("HX-Request") == "true" { c.Header("HX-Redirect", "/") c.Status(http.StatusOK) diff --git a/internal/auth/middleware/middleware.go b/internal/auth/middleware/middleware.go index d5a53ef..286202b 100644 --- a/internal/auth/middleware/middleware.go +++ b/internal/auth/middleware/middleware.go @@ -23,6 +23,8 @@ func AuthMiddleware(svc domain.AuthService) gin.HandlerFunc { var user *domain.User var err error + var sessionID string + var usesCookieSession bool // API routes can authenticate via Bearer token OR cookie session. if strings.HasPrefix(path, "/api/") { @@ -30,7 +32,9 @@ func AuthMiddleware(svc domain.AuthService) gin.HandlerFunc { if strings.HasPrefix(strings.ToLower(authHeader), "bearer ") { token := strings.TrimSpace(authHeader[7:]) user, err = svc.ValidateAPIToken(c.Request.Context(), token) - } else if sessionID, cookieErr := c.Cookie("session_id"); cookieErr == nil { + } else if cookieSessionID, cookieErr := c.Cookie("session_id"); cookieErr == nil { + sessionID = cookieSessionID + usesCookieSession = true user, err = svc.ValidateSession(c.Request.Context(), sessionID) } else { err = cookieErr @@ -43,13 +47,15 @@ func AuthMiddleware(svc domain.AuthService) gin.HandlerFunc { } } else { // Non-API routes only use cookie sessions and redirect to /login. - sessionID, cookieErr := c.Cookie("session_id") + cookieSessionID, cookieErr := c.Cookie("session_id") if cookieErr != nil { c.Redirect(http.StatusSeeOther, "/login") c.Abort() return } + sessionID = cookieSessionID + usesCookieSession = true user, err = svc.ValidateSession(c.Request.Context(), sessionID) if err != nil || user == nil { c.Redirect(http.StatusSeeOther, "/login") @@ -58,6 +64,12 @@ func AuthMiddleware(svc domain.AuthService) gin.HandlerFunc { } } + if usesCookieSession { + if refreshErr := svc.RefreshSession(c.Request.Context(), sessionID); refreshErr == nil { + c.SetCookie("session_id", sessionID, int(domain.SessionLifetime.Seconds()), "/", "", false, true) + } + } + c.Set("User", user) c.Next() } diff --git a/internal/auth/repository/repository.go b/internal/auth/repository/repository.go index 10be86d..74524e4 100644 --- a/internal/auth/repository/repository.go +++ b/internal/auth/repository/repository.go @@ -45,7 +45,7 @@ func (r *authRepository) CreateSession(ctx context.Context, userID string, sessi s, err := r.queries.CreateSession(ctx, db.CreateSessionParams{ ID: sessionID, UserID: userID, - ExpiresAt: time.Now().Add(24 * time.Hour), + ExpiresAt: time.Now().Add(domain.SessionLifetime), }) if err != nil { return nil, err @@ -64,6 +64,13 @@ func (r *authRepository) GetSession(ctx context.Context, sessionID string) (*dom return &s, nil } +func (r *authRepository) RefreshSession(ctx context.Context, sessionID string, expiresAt time.Time) error { + return r.queries.RefreshSession(ctx, db.RefreshSessionParams{ + ExpiresAt: expiresAt, + ID: sessionID, + }) +} + func (r *authRepository) DeleteSession(ctx context.Context, sessionID string) error { return r.queries.DeleteSession(ctx, sessionID) } diff --git a/internal/auth/service/service.go b/internal/auth/service/service.go index e57de80..94d5354 100644 --- a/internal/auth/service/service.go +++ b/internal/auth/service/service.go @@ -83,6 +83,14 @@ func (s *authService) ValidateSession(ctx context.Context, sessionID string) (*d return s.repo.GetUserByID(ctx, session.UserID) } +func (s *authService) RefreshSession(ctx context.Context, sessionID string) error { + if strings.TrimSpace(sessionID) == "" { + return errors.New("session id missing") + } + + return s.repo.RefreshSession(ctx, sessionID, time.Now().Add(domain.SessionLifetime)) +} + func (s *authService) ValidateAPIToken(ctx context.Context, token string) (*domain.User, error) { trimmed := strings.TrimSpace(token) if trimmed == "" { diff --git a/internal/db/queries.sql b/internal/db/queries.sql index d3fbd62..07ed5f4 100644 --- a/internal/db/queries.sql +++ b/internal/db/queries.sql @@ -15,6 +15,11 @@ SELECT * FROM session WHERE id = ? LIMIT 1; -- name: DeleteSession :exec DELETE FROM session WHERE id = ?; +-- name: RefreshSession :exec +UPDATE session +SET expires_at = ? +WHERE id = ?; + -- name: CreateAPIToken :one INSERT INTO api_token (id, user_id, token_hash, name) VALUES (?, ?, ?, ?) diff --git a/internal/db/queries.sql.go b/internal/db/queries.sql.go index 61a64fe..bde28ea 100644 --- a/internal/db/queries.sql.go +++ b/internal/db/queries.sql.go @@ -124,6 +124,22 @@ func (q *Queries) DeleteSession(ctx context.Context, id string) error { return err } +const refreshSession = `-- name: RefreshSession :exec +UPDATE session +SET expires_at = ? +WHERE id = ? +` + +type RefreshSessionParams struct { + ExpiresAt time.Time `json:"expires_at"` + ID string `json:"id"` +} + +func (q *Queries) RefreshSession(ctx context.Context, arg RefreshSessionParams) error { + _, err := q.db.ExecContext(ctx, refreshSession, arg.ExpiresAt, arg.ID) + return err +} + const deleteWatchListEntry = `-- name: DeleteWatchListEntry :exec DELETE FROM watch_list_entry WHERE user_id = ? AND anime_id = ? diff --git a/internal/domain/auth.go b/internal/domain/auth.go index fc20461..71bf372 100644 --- a/internal/domain/auth.go +++ b/internal/domain/auth.go @@ -3,16 +3,20 @@ package domain import ( "context" "mal/internal/db" + "time" ) type User = db.User type Session = db.Session type APIToken = db.ApiToken +const SessionLifetime = 90 * 24 * time.Hour + type AuthService interface { Login(ctx context.Context, username, password string) (*Session, error) LoginForAPIToken(ctx context.Context, username, password, name string) (token string, user *User, err error) ValidateSession(ctx context.Context, sessionID string) (*User, error) + RefreshSession(ctx context.Context, sessionID string) error ValidateAPIToken(ctx context.Context, token string) (*User, error) Logout(ctx context.Context, sessionID string) error RevokeAllAPITokensForUser(ctx context.Context, userID string) error @@ -23,6 +27,7 @@ type AuthRepository interface { GetUserByID(ctx context.Context, id string) (*User, error) CreateSession(ctx context.Context, userID string, sessionID string) (*Session, error) GetSession(ctx context.Context, sessionID string) (*Session, error) + RefreshSession(ctx context.Context, sessionID string, expiresAt time.Time) error DeleteSession(ctx context.Context, sessionID string) error CreateAPIToken(ctx context.Context, userID, tokenHash, name string) (*APIToken, error) GetAPITokenByHash(ctx context.Context, tokenHash string) (*APIToken, error)