feat: add API token authentication

This commit is contained in:
2026-05-19 02:46:47 +02:00
parent ccfb469299
commit 237b5f3004
10 changed files with 310 additions and 14 deletions

View File

@@ -20,6 +20,7 @@ func (h *AuthHandler) Register(r *gin.Engine) {
r.GET("/login", h.HandleLoginPage)
r.POST("/login", h.HandleLogin)
r.GET("/logout", h.HandleLogout)
r.POST("/api/auth/login", h.HandleAPILogin)
}
func (h *AuthHandler) HandleLoginPage(c *gin.Context) {
@@ -59,3 +60,35 @@ func (h *AuthHandler) HandleLogout(c *gin.Context) {
c.SetCookie("session_id", "", -1, "/", "", false, true)
c.Redirect(http.StatusSeeOther, "/login")
}
func (h *AuthHandler) HandleAPILogin(c *gin.Context) {
var body struct {
Username string `json:"username"`
Password string `json:"password"`
Name string `json:"name"`
}
if err := c.ShouldBindJSON(&body); err != nil || body.Username == "" || body.Password == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request"})
return
}
token, user, err := h.svc.LoginForAPIToken(c.Request.Context(), body.Username, body.Password, body.Name)
if err != nil {
c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid username or password"})
return
}
c.JSON(http.StatusOK, gin.H{
"token": token,
"user": gin.H{
"id": user.ID,
"username": user.Username,
"avatarUrl": func() string {
if user.AvatarUrl == "" {
return ""
}
return user.AvatarUrl
}(),
},
})
}

View File

@@ -3,32 +3,59 @@ package middleware
import (
"mal/internal/domain"
"net/http"
"strings"
"github.com/gin-gonic/gin"
)
func AuthMiddleware(svc domain.AuthService) gin.HandlerFunc {
return func(c *gin.Context) {
path := c.Request.URL.Path
// Allow access to login, logout and static assets without authentication
if c.Request.URL.Path == "/login" || c.Request.URL.Path == "/logout" ||
len(c.Request.URL.Path) >= 7 && c.Request.URL.Path[:7] == "/static" ||
len(c.Request.URL.Path) >= 5 && c.Request.URL.Path[:5] == "/dist" {
if path == "/login" || path == "/logout" ||
strings.HasPrefix(path, "/static") ||
strings.HasPrefix(path, "/dist") ||
path == "/api/auth/login" {
c.Next()
return
}
sessionID, err := c.Cookie("session_id")
if err != nil {
c.Redirect(http.StatusSeeOther, "/login")
c.Abort()
return
}
var user *domain.User
var err error
user, err := svc.ValidateSession(c.Request.Context(), sessionID)
if err != nil {
c.Redirect(http.StatusSeeOther, "/login")
c.Abort()
return
// API routes can authenticate via Bearer token OR cookie session.
if strings.HasPrefix(path, "/api/") {
authHeader := strings.TrimSpace(c.GetHeader("Authorization"))
if strings.HasPrefix(strings.ToLower(authHeader), "bearer ") {
token := strings.TrimSpace(authHeader[7:])
user, err = svc.ValidateAPIToken(c.Request.Context(), token)
} else if sessionID, cookieErr := c.Cookie("session_id"); cookieErr == nil {
user, err = svc.ValidateSession(c.Request.Context(), sessionID)
} else {
err = cookieErr
}
if err != nil || user == nil {
c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"})
c.Abort()
return
}
} else {
// Non-API routes only use cookie sessions and redirect to /login.
sessionID, cookieErr := c.Cookie("session_id")
if cookieErr != nil {
c.Redirect(http.StatusSeeOther, "/login")
c.Abort()
return
}
user, err = svc.ValidateSession(c.Request.Context(), sessionID)
if err != nil || user == nil {
c.Redirect(http.StatusSeeOther, "/login")
c.Abort()
return
}
}
c.Set("User", user)

View File

@@ -7,6 +7,8 @@ import (
"mal/internal/db"
"mal/internal/domain"
"time"
"github.com/google/uuid"
)
type authRepository struct {
@@ -65,3 +67,35 @@ func (r *authRepository) GetSession(ctx context.Context, sessionID string) (*dom
func (r *authRepository) DeleteSession(ctx context.Context, sessionID string) error {
return r.queries.DeleteSession(ctx, sessionID)
}
func (r *authRepository) CreateAPIToken(ctx context.Context, userID, tokenHash, name string) (*domain.APIToken, error) {
t, err := r.queries.CreateAPIToken(ctx, db.CreateAPITokenParams{
ID: uuid.New().String(),
UserID: userID,
TokenHash: tokenHash,
Name: name,
})
if err != nil {
return nil, err
}
return &t, nil
}
func (r *authRepository) GetAPITokenByHash(ctx context.Context, tokenHash string) (*domain.APIToken, error) {
t, err := r.queries.GetAPITokenByHash(ctx, tokenHash)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, nil
}
return nil, err
}
return &t, nil
}
func (r *authRepository) TouchAPITokenLastUsedAt(ctx context.Context, tokenID string) error {
return r.queries.TouchAPITokenLastUsedAt(ctx, tokenID)
}
func (r *authRepository) RevokeAllAPITokensForUser(ctx context.Context, userID string) error {
return r.queries.RevokeAllAPITokensForUser(ctx, userID)
}

View File

@@ -2,8 +2,13 @@ package service
import (
"context"
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"encoding/hex"
"errors"
"mal/internal/domain"
"strings"
"time"
"github.com/google/uuid"
@@ -35,6 +40,32 @@ func (s *authService) Login(ctx context.Context, username, password string) (*do
return s.repo.CreateSession(ctx, user.ID, sessionID)
}
func (s *authService) LoginForAPIToken(ctx context.Context, username, password, name string) (string, *domain.User, error) {
user, err := s.repo.GetUserByUsername(ctx, username)
if err != nil {
return "", nil, err
}
if user == nil {
return "", nil, errors.New("invalid credentials")
}
if err := bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(password)); err != nil {
return "", nil, errors.New("invalid credentials")
}
trimmedName := strings.TrimSpace(name)
if trimmedName == "" {
trimmedName = "Firefox extension"
}
rawToken, tokenHash := newOpaqueToken()
if _, err := s.repo.CreateAPIToken(ctx, user.ID, tokenHash, trimmedName); err != nil {
return "", nil, err
}
return rawToken, user, nil
}
func (s *authService) ValidateSession(ctx context.Context, sessionID string) (*domain.User, error) {
session, err := s.repo.GetSession(ctx, sessionID)
if err != nil {
@@ -52,6 +83,44 @@ func (s *authService) ValidateSession(ctx context.Context, sessionID string) (*d
return s.repo.GetUserByID(ctx, session.UserID)
}
func (s *authService) ValidateAPIToken(ctx context.Context, token string) (*domain.User, error) {
trimmed := strings.TrimSpace(token)
if trimmed == "" {
return nil, errors.New("token missing")
}
sum := sha256.Sum256([]byte(trimmed))
tokenHash := hex.EncodeToString(sum[:])
t, err := s.repo.GetAPITokenByHash(ctx, tokenHash)
if err != nil {
return nil, err
}
if t == nil {
return nil, errors.New("token not found")
}
_ = s.repo.TouchAPITokenLastUsedAt(ctx, t.ID)
return s.repo.GetUserByID(ctx, t.UserID)
}
func (s *authService) Logout(ctx context.Context, sessionID string) error {
return s.repo.DeleteSession(ctx, sessionID)
}
func (s *authService) RevokeAllAPITokensForUser(ctx context.Context, userID string) error {
if strings.TrimSpace(userID) == "" {
return errors.New("user id missing")
}
return s.repo.RevokeAllAPITokensForUser(ctx, userID)
}
func newOpaqueToken() (token string, tokenHash string) {
buf := make([]byte, 32)
_, _ = rand.Read(buf)
token = base64.RawURLEncoding.EncodeToString(buf)
sum := sha256.Sum256([]byte(token))
tokenHash = hex.EncodeToString(sum[:])
return token, tokenHash
}

View File

@@ -0,0 +1,15 @@
-- +goose Up
CREATE TABLE IF NOT EXISTS api_token (
id TEXT PRIMARY KEY,
user_id TEXT NOT NULL REFERENCES user(id) ON DELETE CASCADE,
token_hash TEXT NOT NULL UNIQUE,
name TEXT NOT NULL,
created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
last_used_at DATETIME,
revoked_at DATETIME
);
CREATE INDEX IF NOT EXISTS idx_api_token_user_id ON api_token(user_id);
-- +goose Down
DROP TABLE IF EXISTS api_token;

View File

@@ -37,6 +37,16 @@ type AnimeRelation struct {
RelationType string `json:"relation_type"`
}
type ApiToken struct {
ID string `json:"id"`
UserID string `json:"user_id"`
TokenHash string `json:"token_hash"`
Name string `json:"name"`
CreatedAt time.Time `json:"created_at"`
LastUsedAt sql.NullTime `json:"last_used_at"`
RevokedAt sql.NullTime `json:"revoked_at"`
}
type ContinueWatchingEntry struct {
ID string `json:"id"`
UserID string `json:"user_id"`

View File

@@ -10,6 +10,7 @@ import (
type Querier interface {
CountPendingAnimeFetchRetries(ctx context.Context) (int64, error)
CreateAPIToken(ctx context.Context, arg CreateAPITokenParams) (ApiToken, error)
CreateSession(ctx context.Context, arg CreateSessionParams) (Session, error)
DeleteAnimeFetchRetry(ctx context.Context, animeID int64) error
DeleteContinueWatchingEntry(ctx context.Context, arg DeleteContinueWatchingEntryParams) error
@@ -17,6 +18,7 @@ type Querier interface {
DeleteSession(ctx context.Context, id string) error
DeleteWatchListEntry(ctx context.Context, arg DeleteWatchListEntryParams) error
EnqueueAnimeFetchRetry(ctx context.Context, arg EnqueueAnimeFetchRetryParams) error
GetAPITokenByHash(ctx context.Context, tokenHash string) (ApiToken, error)
GetAnime(ctx context.Context, id int64) (Anime, error)
GetAnimeNeedingRelationSync(ctx context.Context) ([]GetAnimeNeedingRelationSyncRow, error)
GetContinueWatchingEntries(ctx context.Context, userID string) ([]GetContinueWatchingEntriesRow, error)
@@ -37,8 +39,10 @@ type Querier interface {
MarkAnimeFetchRetryFailed(ctx context.Context, arg MarkAnimeFetchRetryFailedParams) error
MarkEpisodeAvailabilityRefreshFailed(ctx context.Context, arg MarkEpisodeAvailabilityRefreshFailedParams) error
MarkRelationsSynced(ctx context.Context, id int64) error
RevokeAllAPITokensForUser(ctx context.Context, userID string) error
SaveWatchProgress(ctx context.Context, arg SaveWatchProgressParams) error
SetJikanCache(ctx context.Context, arg SetJikanCacheParams) error
TouchAPITokenLastUsedAt(ctx context.Context, id string) error
UpdateAnimeStatus(ctx context.Context, arg UpdateAnimeStatusParams) error
UpsertAnime(ctx context.Context, arg UpsertAnimeParams) (Anime, error)
UpsertAnimeRelation(ctx context.Context, arg UpsertAnimeRelationParams) error

View File

@@ -15,6 +15,26 @@ SELECT * FROM session WHERE id = ? LIMIT 1;
-- name: DeleteSession :exec
DELETE FROM session WHERE id = ?;
-- name: CreateAPIToken :one
INSERT INTO api_token (id, user_id, token_hash, name)
VALUES (?, ?, ?, ?)
RETURNING *;
-- name: GetAPITokenByHash :one
SELECT * FROM api_token
WHERE token_hash = ? AND revoked_at IS NULL
LIMIT 1;
-- name: TouchAPITokenLastUsedAt :exec
UPDATE api_token
SET last_used_at = CURRENT_TIMESTAMP
WHERE id = ?;
-- name: RevokeAllAPITokensForUser :exec
UPDATE api_token
SET revoked_at = CURRENT_TIMESTAMP
WHERE user_id = ? AND revoked_at IS NULL;
-- name: UpsertAnime :one
INSERT INTO anime (id, title_original, title_english, title_japanese, image_url, airing, duration_seconds)
VALUES (?, ?, ?, ?, ?, ?, ?)

View File

@@ -24,6 +24,39 @@ func (q *Queries) CountPendingAnimeFetchRetries(ctx context.Context) (int64, err
return count, err
}
const createAPIToken = `-- name: CreateAPIToken :one
INSERT INTO api_token (id, user_id, token_hash, name)
VALUES (?, ?, ?, ?)
RETURNING id, user_id, token_hash, name, created_at, last_used_at, revoked_at
`
type CreateAPITokenParams struct {
ID string `json:"id"`
UserID string `json:"user_id"`
TokenHash string `json:"token_hash"`
Name string `json:"name"`
}
func (q *Queries) CreateAPIToken(ctx context.Context, arg CreateAPITokenParams) (ApiToken, error) {
row := q.db.QueryRowContext(ctx, createAPIToken,
arg.ID,
arg.UserID,
arg.TokenHash,
arg.Name,
)
var i ApiToken
err := row.Scan(
&i.ID,
&i.UserID,
&i.TokenHash,
&i.Name,
&i.CreatedAt,
&i.LastUsedAt,
&i.RevokedAt,
)
return i, err
}
const createSession = `-- name: CreateSession :one
INSERT INTO session (id, user_id, expires_at)
VALUES (?, ?, ?)
@@ -128,6 +161,27 @@ func (q *Queries) EnqueueAnimeFetchRetry(ctx context.Context, arg EnqueueAnimeFe
return err
}
const getAPITokenByHash = `-- name: GetAPITokenByHash :one
SELECT id, user_id, token_hash, name, created_at, last_used_at, revoked_at FROM api_token
WHERE token_hash = ? AND revoked_at IS NULL
LIMIT 1
`
func (q *Queries) GetAPITokenByHash(ctx context.Context, tokenHash string) (ApiToken, error) {
row := q.db.QueryRowContext(ctx, getAPITokenByHash, tokenHash)
var i ApiToken
err := row.Scan(
&i.ID,
&i.UserID,
&i.TokenHash,
&i.Name,
&i.CreatedAt,
&i.LastUsedAt,
&i.RevokedAt,
)
return i, err
}
const getAnime = `-- name: GetAnime :one
SELECT id, title_original, image_url, created_at, title_english, title_japanese, airing, status, relations_synced_at, duration_seconds FROM anime WHERE id = ? LIMIT 1
`
@@ -820,6 +874,17 @@ func (q *Queries) MarkRelationsSynced(ctx context.Context, id int64) error {
return err
}
const revokeAllAPITokensForUser = `-- name: RevokeAllAPITokensForUser :exec
UPDATE api_token
SET revoked_at = CURRENT_TIMESTAMP
WHERE user_id = ? AND revoked_at IS NULL
`
func (q *Queries) RevokeAllAPITokensForUser(ctx context.Context, userID string) error {
_, err := q.db.ExecContext(ctx, revokeAllAPITokensForUser, userID)
return err
}
const saveWatchProgress = `-- name: SaveWatchProgress :exec
UPDATE watch_list_entry
SET current_episode = ?,
@@ -864,6 +929,17 @@ func (q *Queries) SetJikanCache(ctx context.Context, arg SetJikanCacheParams) er
return err
}
const touchAPITokenLastUsedAt = `-- name: TouchAPITokenLastUsedAt :exec
UPDATE api_token
SET last_used_at = CURRENT_TIMESTAMP
WHERE id = ?
`
func (q *Queries) TouchAPITokenLastUsedAt(ctx context.Context, id string) error {
_, err := q.db.ExecContext(ctx, touchAPITokenLastUsedAt, id)
return err
}
const updateAnimeStatus = `-- name: UpdateAnimeStatus :exec
UPDATE anime SET status = ? WHERE id = ?
`

View File

@@ -7,11 +7,15 @@ import (
type User = db.User
type Session = db.Session
type APIToken = db.ApiToken
type AuthService interface {
Login(ctx context.Context, username, password string) (*Session, error)
LoginForAPIToken(ctx context.Context, username, password, name string) (token string, user *User, err error)
ValidateSession(ctx context.Context, sessionID string) (*User, error)
ValidateAPIToken(ctx context.Context, token string) (*User, error)
Logout(ctx context.Context, sessionID string) error
RevokeAllAPITokensForUser(ctx context.Context, userID string) error
}
type AuthRepository interface {
@@ -20,4 +24,8 @@ type AuthRepository interface {
CreateSession(ctx context.Context, userID string, sessionID string) (*Session, error)
GetSession(ctx context.Context, sessionID string) (*Session, error)
DeleteSession(ctx context.Context, sessionID string) error
CreateAPIToken(ctx context.Context, userID, tokenHash, name string) (*APIToken, error)
GetAPITokenByHash(ctx context.Context, tokenHash string) (*APIToken, error)
TouchAPITokenLastUsedAt(ctx context.Context, tokenID string) error
RevokeAllAPITokensForUser(ctx context.Context, userID string) error
}