diff --git a/internal/auth/auth.go b/internal/auth/auth.go new file mode 100644 index 0000000..17baa95 --- /dev/null +++ b/internal/auth/auth.go @@ -0,0 +1,141 @@ +package auth + +import ( + "context" + "crypto/rand" + "database/sql" + "encoding/base64" + "errors" + "fmt" + "net/http" + "time" + + "github.com/google/uuid" + "golang.org/x/crypto/bcrypt" + + "malago/internal/database" +) + +var ( + ErrInvalidCredentials = errors.New("invalid username or password") + ErrUserExists = errors.New("username already exists") + ErrNotAuthenticated = errors.New("not authenticated") +) + +type Service struct { + db database.Querier +} + +func NewService(db database.Querier) *Service { + return &Service{db: db} +} + +// generateSessionToken generates a secure random 32-byte token. +func generateSessionToken() (string, error) { + b := make([]byte, 32) + if _, err := rand.Read(b); err != nil { + return "", err + } + return base64.URLEncoding.EncodeToString(b), nil +} + +func (s *Service) RegisterUser(ctx context.Context, username, password string) (*database.User, error) { + hash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) + if err != nil { + return nil, fmt.Errorf("failed to hash password: %w", err) + } + + id := uuid.New().String() + user, err := s.db.CreateUser(ctx, database.CreateUserParams{ + ID: id, + Username: username, + PasswordHash: string(hash), + }) + if err != nil { + // Assuming unique constraint failure for username + return nil, ErrUserExists + } + + return &user, nil +} + +func (s *Service) Login(ctx context.Context, username, password string) (*database.Session, error) { + user, err := s.db.GetUserByUsername(ctx, username) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, ErrInvalidCredentials + } + return nil, fmt.Errorf("failed to lookup user: %w", err) + } + + if err := bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(password)); err != nil { + return nil, ErrInvalidCredentials + } + + token, err := generateSessionToken() + if err != nil { + return nil, fmt.Errorf("failed to generate session token: %w", err) + } + + expiresAt := time.Now().Add(30 * 24 * time.Hour) // 30 days + session, err := s.db.CreateSession(ctx, database.CreateSessionParams{ + ID: token, + UserID: user.ID, + ExpiresAt: expiresAt, + }) + if err != nil { + return nil, fmt.Errorf("failed to create session: %w", err) + } + + return &session, nil +} + +func (s *Service) Logout(ctx context.Context, sessionID string) error { + return s.db.DeleteSession(ctx, sessionID) +} + +func (s *Service) ValidateSession(ctx context.Context, sessionID string) (*database.User, error) { + session, err := s.db.GetSession(ctx, sessionID) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, ErrNotAuthenticated + } + return nil, fmt.Errorf("failed to get session: %w", err) + } + + if time.Now().After(session.ExpiresAt) { + _ = s.db.DeleteSession(ctx, sessionID) + return nil, ErrNotAuthenticated + } + + user, err := s.db.GetUser(ctx, session.UserID) + if err != nil { + return nil, fmt.Errorf("failed to get user for session: %w", err) + } + + return &user, nil +} + +func SetSessionCookie(w http.ResponseWriter, sessionID string, expiresAt time.Time) { + http.SetCookie(w, &http.Cookie{ + Name: "session_id", + Value: sessionID, + Expires: expiresAt, + HttpOnly: true, + Secure: false, // False for local development without TLS + SameSite: http.SameSiteLaxMode, + Path: "/", + }) +} + +func ClearSessionCookie(w http.ResponseWriter) { + http.SetCookie(w, &http.Cookie{ + Name: "session_id", + Value: "", + Expires: time.Unix(0, 0), + HttpOnly: true, + Secure: false, // False for local development without TLS + SameSite: http.SameSiteLaxMode, + Path: "/", + }) +} diff --git a/internal/handlers/auth.go b/internal/handlers/auth.go new file mode 100644 index 0000000..5469bfa --- /dev/null +++ b/internal/handlers/auth.go @@ -0,0 +1,56 @@ +package handlers + +import ( + "net/http" + + "malago/internal/auth" + "malago/internal/templates" +) + +type AuthHandler struct { + authService *auth.Service +} + +func NewAuthHandler(authService *auth.Service) *AuthHandler { + return &AuthHandler{authService: authService} +} + +// Render the login/register pages here (assuming you have these templates) + +func (h *AuthHandler) HandleLogin(w http.ResponseWriter, r *http.Request) { + if err := r.ParseForm(); err != nil { + http.Error(w, "invalid request", http.StatusBadRequest) + return + } + + username := r.FormValue("username") + password := r.FormValue("password") + + session, err := h.authService.Login(r.Context(), username, password) + if err != nil { + // Just handle generically for now, perhaps via HTMX toast + http.Error(w, "invalid credentials", http.StatusUnauthorized) + return + } + + auth.SetSessionCookie(w, session.ID, session.ExpiresAt) + + // HTMX-friendly redirect to root or previous page + w.Header().Set("HX-Redirect", "/") + http.Redirect(w, r, "/", http.StatusFound) +} + +func (h *AuthHandler) HandleLogout(w http.ResponseWriter, r *http.Request) { + cookie, err := r.Cookie("session_id") + if err == nil { + _ = h.authService.Logout(r.Context(), cookie.Value) + } + + auth.ClearSessionCookie(w) + w.Header().Set("HX-Redirect", "/") + http.Redirect(w, r, "/", http.StatusFound) +} + +func (h *AuthHandler) HandleLoginPage(w http.ResponseWriter, r *http.Request) { + templates.Login().Render(r.Context(), w) +} diff --git a/internal/middleware/auth.go b/internal/middleware/auth.go new file mode 100644 index 0000000..b1a95d4 --- /dev/null +++ b/internal/middleware/auth.go @@ -0,0 +1,82 @@ +package middleware + +import ( + "context" + "net/http" + "strings" + + "malago/internal/auth" + "malago/internal/database" +) + +type contextKey string + +const ( + UserContextKey contextKey = "user" +) + +func Auth(authService *auth.Service) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + cookie, err := r.Cookie("session_id") + if err != nil { + // No session cookie, user is unauthenticated. Proceed, but not logged in. + next.ServeHTTP(w, r) + return + } + + user, err := authService.ValidateSession(r.Context(), cookie.Value) + if err != nil { + // Invalid session, proceed as unauthenticated + // Might also want to clear the invalid cookie here + auth.ClearSessionCookie(w) + next.ServeHTTP(w, r) + return + } + + // Valid session, bind user to context + ctx := context.WithValue(r.Context(), UserContextKey, user) + next.ServeHTTP(w, r.WithContext(ctx)) + }) + } +} + +// RequireAuth ensures that a valid user is in the context, otherwise unauthorized +func RequireAuth(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + user, ok := r.Context().Value(UserContextKey).(*database.User) + if !ok || user == nil { + if strings.HasPrefix(r.URL.Path, "/api/") { + w.Header().Set("HX-Redirect", "/login") + http.Error(w, "Unauthorized", http.StatusUnauthorized) + } else { + http.Redirect(w, r, "/login", http.StatusFound) + } + return + } + next.ServeHTTP(w, r) + }) +} + +// RequireGlobalAuth ensures that a valid user is in the context for all routes except login and static +func RequireGlobalAuth(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Allow unauthenticated access to login and static files + if r.URL.Path == "/login" || strings.HasPrefix(r.URL.Path, "/static/") { + next.ServeHTTP(w, r) + return + } + + user, ok := r.Context().Value(UserContextKey).(*database.User) + if !ok || user == nil { + if strings.HasPrefix(r.URL.Path, "/api/") || r.Header.Get("HX-Request") == "true" { + w.Header().Set("HX-Redirect", "/login") + http.Error(w, "Unauthorized", http.StatusUnauthorized) + } else { + http.Redirect(w, r, "/login", http.StatusFound) + } + return + } + next.ServeHTTP(w, r) + }) +} diff --git a/internal/templates/auth.templ b/internal/templates/auth.templ new file mode 100644 index 0000000..622f52c --- /dev/null +++ b/internal/templates/auth.templ @@ -0,0 +1,18 @@ +package templates + +templ Login() { + @Layout("Login") { +

Login

+
+
+ + +
+
+ + +
+ +
+ } +} diff --git a/internal/templates/auth_templ.go b/internal/templates/auth_templ.go new file mode 100644 index 0000000..5b39472 --- /dev/null +++ b/internal/templates/auth_templ.go @@ -0,0 +1,58 @@ +// Code generated by templ - DO NOT EDIT. + +// templ: version: v0.3.1001 +package templates + +//lint:file-ignore SA4006 This context is only used if a nested component is present. + +import "github.com/a-h/templ" +import templruntime "github.com/a-h/templ/runtime" + +func Login() templ.Component { + return templruntime.GeneratedTemplate(func(templ_7745c5c3_Input templruntime.GeneratedComponentInput) (templ_7745c5c3_Err error) { + templ_7745c5c3_W, ctx := templ_7745c5c3_Input.Writer, templ_7745c5c3_Input.Context + if templ_7745c5c3_CtxErr := ctx.Err(); templ_7745c5c3_CtxErr != nil { + return templ_7745c5c3_CtxErr + } + templ_7745c5c3_Buffer, templ_7745c5c3_IsBuffer := templruntime.GetBuffer(templ_7745c5c3_W) + if !templ_7745c5c3_IsBuffer { + defer func() { + templ_7745c5c3_BufErr := templruntime.ReleaseBuffer(templ_7745c5c3_Buffer) + if templ_7745c5c3_Err == nil { + templ_7745c5c3_Err = templ_7745c5c3_BufErr + } + }() + } + ctx = templ.InitializeContext(ctx) + templ_7745c5c3_Var1 := templ.GetChildren(ctx) + if templ_7745c5c3_Var1 == nil { + templ_7745c5c3_Var1 = templ.NopComponent + } + ctx = templ.ClearChildren(ctx) + templ_7745c5c3_Var2 := templruntime.GeneratedTemplate(func(templ_7745c5c3_Input templruntime.GeneratedComponentInput) (templ_7745c5c3_Err error) { + templ_7745c5c3_W, ctx := templ_7745c5c3_Input.Writer, templ_7745c5c3_Input.Context + templ_7745c5c3_Buffer, templ_7745c5c3_IsBuffer := templruntime.GetBuffer(templ_7745c5c3_W) + if !templ_7745c5c3_IsBuffer { + defer func() { + templ_7745c5c3_BufErr := templruntime.ReleaseBuffer(templ_7745c5c3_Buffer) + if templ_7745c5c3_Err == nil { + templ_7745c5c3_Err = templ_7745c5c3_BufErr + } + }() + } + ctx = templ.InitializeContext(ctx) + templ_7745c5c3_Err = templruntime.WriteString(templ_7745c5c3_Buffer, 1, "

Login

") + if templ_7745c5c3_Err != nil { + return templ_7745c5c3_Err + } + return nil + }) + templ_7745c5c3_Err = Layout("Login").Render(templ.WithChildren(ctx, templ_7745c5c3_Var2), templ_7745c5c3_Buffer) + if templ_7745c5c3_Err != nil { + return templ_7745c5c3_Err + } + return nil + }) +} + +var _ = templruntime.GeneratedTemplate