From ecb15782c87c0fb0bb8d3c50a7d0a60a65caef46 Mon Sep 17 00:00:00 2001 From: mkelvers Date: Fri, 10 Apr 2026 17:25:27 +0200 Subject: [PATCH] security: enforce csrf on writes --- internal/server/routes.go | 5 +- internal/shared/middleware/logging.go | 76 +++++++++++++++++++++++++++ 2 files changed, 79 insertions(+), 2 deletions(-) create mode 100644 internal/shared/middleware/logging.go diff --git a/internal/server/routes.go b/internal/server/routes.go index e729ed2..9c05764 100644 --- a/internal/server/routes.go +++ b/internal/server/routes.go @@ -73,7 +73,8 @@ func NewRouter(cfg Config) http.Handler { mux.HandleFunc("/api/watchlist/", watchlistHandler.HandleDeleteWatchlist) mux.HandleFunc("/watchlist", watchlistHandler.HandleGetWatchlist) - // Wrap mux with global auth checking, THEN auth context parsing - protectedHandler := middleware.RequireGlobalAuth(mux) + // Wrap mux with global CSRF origin verification and auth checking, + // THEN auth context parsing. + protectedHandler := middleware.RequireGlobalAuth(middleware.VerifyOrigin(mux)) return middleware.Auth(cfg.AuthService)(protectedHandler) } diff --git a/internal/shared/middleware/logging.go b/internal/shared/middleware/logging.go new file mode 100644 index 0000000..cd766f5 --- /dev/null +++ b/internal/shared/middleware/logging.go @@ -0,0 +1,76 @@ +package middleware + +import ( + "bufio" + "fmt" + "log" + "net" + "net/http" + "time" +) + +type statusRecorder struct { + http.ResponseWriter + statusCode int + wroteHeader bool +} + +func newStatusRecorder(w http.ResponseWriter) *statusRecorder { + return &statusRecorder{ + ResponseWriter: w, + statusCode: http.StatusOK, + } +} + +func (rw *statusRecorder) WriteHeader(code int) { + if rw.wroteHeader { + return + } + rw.statusCode = code + rw.wroteHeader = true + rw.ResponseWriter.WriteHeader(code) +} + +func (rw *statusRecorder) Write(b []byte) (int, error) { + if !rw.wroteHeader { + rw.WriteHeader(http.StatusOK) + } + return rw.ResponseWriter.Write(b) +} + +func (rw *statusRecorder) Flush() { + if flusher, ok := rw.ResponseWriter.(http.Flusher); ok { + flusher.Flush() + } +} + +func (rw *statusRecorder) Hijack() (net.Conn, *bufio.ReadWriter, error) { + hijacker, ok := rw.ResponseWriter.(http.Hijacker) + if !ok { + return nil, nil, fmt.Errorf("response writer does not support hijacking") + } + return hijacker.Hijack() +} + +func (rw *statusRecorder) Push(target string, opts *http.PushOptions) error { + pusher, ok := rw.ResponseWriter.(http.Pusher) + if !ok { + return http.ErrNotSupported + } + return pusher.Push(target, opts) +} + +func (rw *statusRecorder) Unwrap() http.ResponseWriter { + return rw.ResponseWriter +} + +func RequestLogger(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + start := time.Now() + recorder := newStatusRecorder(w) + + next.ServeHTTP(recorder, r) + + log.Printf("%s %s %d %s", r.Method, r.URL.Path, recorder.statusCode, time.Since(start)) + }) +}