fix: replace nil context with context.TODO

This commit is contained in:
2026-06-11 17:11:47 +02:00
parent ed90b5c7aa
commit 25471e0bd5
18 changed files with 798 additions and 52 deletions

View File

@@ -201,7 +201,7 @@ func (h *AnimeHandler) HandleProducers(c *gin.Context) {
res, err := h.svc.GetProducers(c.Request.Context(), q, page, limit)
if err != nil {
observability.Warn(
observability.WarnContext(c.Request.Context(),
"producers_fetch_failed",
"anime",
"",
@@ -270,7 +270,7 @@ func (h *AnimeHandler) HandleCatalogTopPickForYou(c *gin.Context) {
data, err := h.svc.GetTopPickForYou(c.Request.Context(), userID)
if err != nil {
observability.Warn(
observability.WarnContext(c.Request.Context(),
"top_pick_for_you_fetch_failed",
"anime",
"",
@@ -321,7 +321,7 @@ func (h *AnimeHandler) HandleDiscoverTopPicksForYou(c *gin.Context) {
data, err := h.svc.GetTopPicksForYou(c.Request.Context(), userID)
if err != nil {
observability.Warn(
observability.WarnContext(c.Request.Context(),
"top_picks_for_you_fetch_failed",
"anime",
"",
@@ -375,7 +375,7 @@ func (h *AnimeHandler) renderDiscoverSection(c *gin.Context, section string) {
}
func (h *AnimeHandler) abortSectionFetch(c *gin.Context, event string, userID string, section string, err error) {
observability.Warn(
observability.WarnContext(c.Request.Context(),
event,
"anime",
"",
@@ -407,7 +407,7 @@ func (h *AnimeHandler) HandleScheduleSection(c *gin.Context) {
if err != nil {
prevYear, prevWeek := adjacentISOWeek(year, week, -1)
nextYear, nextWeek := adjacentISOWeek(year, week, 1)
observability.Warn(
observability.WarnContext(c.Request.Context(),
"animeschedule_fetch_failed",
"anime",
"",

View File

@@ -10,6 +10,7 @@ import (
"mal/internal/config"
"mal/internal/database"
"mal/internal/episodes"
"mal/internal/observability"
"mal/internal/playback"
"mal/internal/server"
"mal/internal/watchlist"
@@ -18,10 +19,14 @@ import (
"github.com/gin-gonic/gin"
"github.com/gin-gonic/gin/render"
"go.uber.org/fx"
"go.uber.org/fx/fxevent"
)
func NewApp() *fx.App {
return fx.New(
fx.WithLogger(func() fxevent.Logger {
return observability.NewFxLogger()
}),
config.Module,
database.Module,
audit.Module,

View File

@@ -38,6 +38,7 @@ func ProvideQueries(sqlDB *sql.DB) *db.Queries {
func RunMigrations(sqlDB *sql.DB) error {
goose.SetBaseFS(migrationsFS)
goose.SetLogger(goose.NopLogger())
if err := goose.SetDialect("sqlite3"); err != nil {
return fmt.Errorf("failed to set goose dialect: %w", err)
@@ -48,6 +49,13 @@ func RunMigrations(sqlDB *sql.DB) error {
return fmt.Errorf("failed to run migrations: %w", err)
}
version, err := goose.GetDBVersion(sqlDB)
if err != nil {
return fmt.Errorf("failed to get database migration version: %w", err)
}
observability.Info("db_migrations_complete", "database", "", map[string]any{"version": version})
return nil
}
func RunMigrationsAndFixes(sqlDB *sql.DB) error {

View File

@@ -0,0 +1,52 @@
package observability
import (
"go.uber.org/fx/fxevent"
)
type fxLogger struct{}
func NewFxLogger() fxevent.Logger {
return fxLogger{}
}
func (fxLogger) LogEvent(event fxevent.Event) {
switch e := event.(type) {
case *fxevent.Provided:
if e.Err != nil {
Error("fx_provide_failed", "fx", "", map[string]any{"constructor": e.ConstructorName}, e.Err)
}
case *fxevent.Invoked:
if e.Err != nil {
Error("fx_invoke_failed", "fx", "", map[string]any{"function": e.FunctionName}, e.Err)
}
case *fxevent.Run:
if e.Err != nil {
Error("fx_run_failed", "fx", "", map[string]any{"function": e.Name, "kind": e.Kind}, e.Err)
}
case *fxevent.OnStartExecuted:
if e.Err != nil {
Error("fx_on_start_failed", "fx", "", map[string]any{"caller": e.CallerName, "function": e.FunctionName, "runtime": e.Runtime}, e.Err)
}
case *fxevent.OnStopExecuted:
if e.Err != nil {
Error("fx_on_stop_failed", "fx", "", map[string]any{"caller": e.CallerName, "function": e.FunctionName, "runtime": e.Runtime}, e.Err)
}
case *fxevent.Started:
if e.Err != nil {
Error("fx_start_failed", "fx", "", nil, e.Err)
}
case *fxevent.Stopped:
if e.Err != nil {
Error("fx_stop_failed", "fx", "", nil, e.Err)
}
case *fxevent.RollingBack:
if e.StartErr != nil {
Error("fx_rollback_start", "fx", "", nil, e.StartErr)
}
case *fxevent.RolledBack:
if e.Err != nil {
Error("fx_rollback_failed", "fx", "", nil, e.Err)
}
}
}

View File

@@ -1,5 +1,7 @@
package observability
import "context"
// Small helpers to keep logging consistent and low-friction across the codebase.
func Info(event string, component string, message string, fields map[string]any) {
@@ -10,6 +12,14 @@ func Warn(event string, component string, message string, fields map[string]any,
LogJSON(LogLevelWarn, event, component, message, fields, err)
}
func WarnContext(ctx context.Context, event string, component string, message string, fields map[string]any, err error) {
LogContext(ctx, LogLevelWarn, event, component, message, fields, err)
}
func Error(event string, component string, message string, fields map[string]any, err error) {
LogJSON(LogLevelError, event, component, message, fields, err)
}
func ErrorContext(ctx context.Context, event string, component string, message string, fields map[string]any, err error) {
LogContext(ctx, LogLevelError, event, component, message, fields, err)
}

View File

@@ -0,0 +1,35 @@
package observability
import (
"errors"
"fmt"
"strings"
"testing"
)
func TestWarnEnrichesSourceAndErrorContext(t *testing.T) {
fields := enrichFields(LogLevelWarn, map[string]any{"anime_id": 123}, wrappedError())
if fields["anime_id"] != 123 {
t.Fatalf("expected existing field to survive, got %#v", fields["anime_id"])
}
source, ok := fields["source"].(string)
if !ok || source == "" {
t.Fatalf("expected source field, got %#v", fields["source"])
}
errorType, ok := fields["error_type"].(string)
if !ok || errorType == "" {
t.Fatalf("expected error_type field, got %#v", fields["error_type"])
}
chain, ok := fields["error_chain"].(string)
if !ok || !strings.Contains(chain, "query anime") || !strings.Contains(chain, "db timeout") {
t.Fatalf("expected wrapped error chain, got %#v", fields["error_chain"])
}
}
func wrappedError() error {
return fmt.Errorf("query anime: %w", errors.New("db timeout"))
}

View File

@@ -2,11 +2,30 @@
package observability
import (
"encoding/json"
"context"
"errors"
"fmt"
"log"
"net"
"os"
"path/filepath"
"reflect"
"runtime"
"sort"
"strconv"
"strings"
"time"
)
const (
ansiReset = "\x1b[0m"
ansiBlue = "\x1b[36m"
ansiYellow = "\x1b[33m"
ansiRed = "\x1b[31m"
)
var colorLogs = shouldColorLogs()
type LogLevel string
const (
@@ -25,11 +44,17 @@ type LogEvent struct {
Component string `json:"component,omitempty"`
}
func init() {
log.SetFlags(0)
}
func LogJSON(level LogLevel, event string, component string, message string, fields map[string]any, err error) {
errorValue := ""
if err != nil {
errorValue = err.Error()
}
LogContext(context.TODO(), level, event, component, message, fields, err)
}
func LogContext(ctx context.Context, level LogLevel, event string, component string, message string, fields map[string]any, err error) {
fields = enrichFields(level, fields, err)
fields = enrichRequestFields(ctx, fields)
entry := LogEvent{
TS: time.Now().UTC().Format(time.RFC3339Nano),
@@ -37,23 +62,370 @@ func LogJSON(level LogLevel, event string, component string, message string, fie
Event: event,
Message: message,
Fields: fields,
Error: errorValue,
Component: component,
}
// Best-effort. If encoding fails, fall back to a minimal line.
bytes, marshalErr := json.Marshal(entry)
if marshalErr != nil {
// Keep output JSON-only even on failures by constructing a minimal entry.
// Marshal individual strings to ensure proper escaping.
tsBytes, _ := json.Marshal(time.Now().UTC().Format(time.RFC3339Nano))
levelBytes, _ := json.Marshal(level)
eventBytes, _ := json.Marshal("log_marshal_failed")
componentBytes, _ := json.Marshal(component)
errBytes, _ := json.Marshal(marshalErr.Error())
log.Printf(`{"ts":%s,"level":%s,"event":%s,"component":%s,"error":%s}`, tsBytes, levelBytes, eventBytes, componentBytes, errBytes)
if err != nil {
entry.Error = err.Error()
}
log.Print(formatLogEntry(entry))
}
func enrichRequestFields(ctx context.Context, fields map[string]any) map[string]any {
requestContext, ok := RequestContextFromContext(ctx)
if !ok {
return fields
}
enriched := cloneFields(fields)
if enriched == nil {
enriched = make(map[string]any, 3)
}
if requestContext.ID != "" {
if _, exists := enriched["request_id"]; !exists {
enriched["request_id"] = requestContext.ID
}
}
if requestContext.Path != "" {
if _, exists := enriched["request_path"]; !exists {
enriched["request_path"] = requestContext.Path
}
}
if requestContext.Route != "" && requestContext.Route != requestContext.Path {
if _, exists := enriched["request_route"]; !exists {
enriched["request_route"] = requestContext.Route
}
}
return enriched
}
func enrichFields(level LogLevel, fields map[string]any, err error) map[string]any {
if level == LogLevelInfo {
return fields
}
enriched := cloneFields(fields)
if enriched == nil {
enriched = make(map[string]any, 3)
}
if _, exists := enriched["source"]; !exists {
if source := callerSource(); source != "" {
enriched["source"] = source
}
}
if err != nil {
if _, exists := enriched["error_type"]; !exists {
if errorType := formatErrorType(err); errorType != "" {
enriched["error_type"] = errorType
}
}
if _, exists := enriched["error_chain"]; !exists {
if chain := formatErrorChain(err); chain != "" {
enriched["error_chain"] = chain
}
}
}
return enriched
}
func callerSource() string {
pcs := make([]uintptr, 8)
n := runtime.Callers(3, pcs)
frames := runtime.CallersFrames(pcs[:n])
for {
frame, more := frames.Next()
if !strings.Contains(frame.File, "/internal/observability/") {
return filepath.Base(frame.File) + ":" + strconv.Itoa(frame.Line)
}
if !more {
return ""
}
}
}
func formatErrorType(err error) string {
errType := reflect.TypeOf(err)
if errType == nil {
return ""
}
return errType.String()
}
func formatErrorChain(err error) string {
parts := make([]string, 0, 4)
for current := err; current != nil; current = errors.Unwrap(current) {
parts = append(parts, current.Error())
if len(parts) == 4 {
break
}
}
if len(parts) <= 1 {
return ""
}
return strings.Join(parts, " -> ")
}
func formatLogEntry(entry LogEvent) string {
if entry.Event == "http_request" {
return formatHTTPRequestLog(entry)
}
parts := []string{entry.TS, formatLogLevel(entry.Level), entry.Event}
if entry.Component != "" {
parts = append(parts, "component="+entry.Component)
}
if entry.Message != "" {
parts = append(parts, quoteIfNeeded(entry.Message))
}
if len(entry.Fields) > 0 {
keys := make([]string, 0, len(entry.Fields))
for key := range entry.Fields {
keys = append(keys, key)
}
sort.Strings(keys)
for _, key := range keys {
parts = append(parts, key+"="+formatFieldValue(entry.Fields[key]))
}
}
if entry.Error != "" {
parts = append(parts, "error="+quoteIfNeeded(entry.Error))
}
return strings.Join(parts, " ")
}
func formatHTTPRequestLog(entry LogEvent) string {
fields := cloneFields(entry.Fields)
status := popField(fields, "status")
method := popField(fields, "method")
path := popField(fields, "path")
duration := popField(fields, "duration_ms")
bytes := popField(fields, "bytes")
route := popField(fields, "route")
query := popField(fields, "query")
clientIP := popField(fields, "client_ip")
parts := []string{entry.TS, formatLogLevel(entry.Level), "http"}
if status != "" {
parts = append(parts, status)
}
if method != "" || path != "" {
parts = append(parts, strings.TrimSpace(method+" "+path))
}
if duration != "" {
parts = append(parts, duration)
}
if bytes != "" {
parts = append(parts, bytes)
}
if route != "" {
parts = append(parts, "route="+route)
}
if query != "" {
parts = append(parts, "query="+quoteIfNeeded(query))
}
if clientIP != "" && !isLocalClientIP(clientIP) {
parts = append(parts, "ip="+clientIP)
}
appendSortedFields(&parts, fields)
if entry.Error != "" {
parts = append(parts, "error="+quoteIfNeeded(entry.Error))
}
return strings.Join(parts, " ")
}
func appendSortedFields(parts *[]string, fields map[string]any) {
if len(fields) == 0 {
return
}
log.Print(string(bytes))
keys := make([]string, 0, len(fields))
for key := range fields {
keys = append(keys, key)
}
sort.Strings(keys)
for _, key := range keys {
*parts = append(*parts, key+"="+formatFieldValue(fields[key]))
}
}
func cloneFields(fields map[string]any) map[string]any {
if len(fields) == 0 {
return nil
}
copyFields := make(map[string]any, len(fields))
for key, value := range fields {
copyFields[key] = value
}
return copyFields
}
func popField(fields map[string]any, key string) string {
if len(fields) == 0 {
return ""
}
value, ok := fields[key]
if !ok {
return ""
}
delete(fields, key)
return formatInlineField(key, value)
}
func formatInlineField(key string, value any) string {
switch key {
case "status":
return fmt.Sprint(value)
case "duration_ms":
return formatDurationMillis(value)
case "bytes":
return formatBytes(value)
default:
if text, ok := value.(string); ok {
return text
}
return fmt.Sprint(value)
}
}
func formatDurationMillis(value any) string {
ms, ok := toFloat64(value)
if !ok {
return fmt.Sprint(value)
}
return strconv.FormatFloat(ms, 'f', -1, 64) + "ms"
}
func formatBytes(value any) string {
bytesValue, ok := toFloat64(value)
if !ok {
return fmt.Sprint(value)
}
if bytesValue < 1024 {
return strconv.FormatFloat(bytesValue, 'f', -1, 64) + "B"
}
if bytesValue < 1024*1024 {
return strconv.FormatFloat(bytesValue/1024, 'f', 1, 64) + "KB"
}
return strconv.FormatFloat(bytesValue/(1024*1024), 'f', 1, 64) + "MB"
}
func toFloat64(value any) (float64, bool) {
switch v := value.(type) {
case int:
return float64(v), true
case int32:
return float64(v), true
case int64:
return float64(v), true
case float32:
return float64(v), true
case float64:
return v, true
default:
return 0, false
}
}
func isLocalClientIP(value string) bool {
parsed := net.ParseIP(value)
if parsed == nil {
return false
}
return parsed.IsLoopback()
}
func formatLogLevel(level LogLevel) string {
if colorLogs {
switch level {
case LogLevelWarn:
return ansiYellow + "WARN" + ansiReset
case LogLevelError:
return ansiRed + "ERROR" + ansiReset
default:
return ansiBlue + "INFO" + ansiReset
}
}
switch level {
case LogLevelWarn:
return "WARN"
case LogLevelError:
return "ERROR"
default:
return "INFO"
}
}
func shouldColorLogs() bool {
if strings.TrimSpace(os.Getenv("NO_COLOR")) != "" {
return false
}
if strings.EqualFold(strings.TrimSpace(os.Getenv("TERM")), "dumb") {
return false
}
info, err := os.Stderr.Stat()
if err != nil {
return false
}
return info.Mode()&os.ModeCharDevice != 0
}
func formatFieldValue(value any) string {
switch v := value.(type) {
case string:
return quoteIfNeeded(v)
case time.Duration:
return v.String()
case float32:
return strconv.FormatFloat(float64(v), 'f', -1, 32)
case float64:
return strconv.FormatFloat(v, 'f', -1, 64)
case fmt.Stringer:
return quoteIfNeeded(v.String())
default:
return quoteIfNeeded(fmt.Sprint(value))
}
}
func quoteIfNeeded(value string) string {
if value == "" {
return `""`
}
for _, r := range value {
if r == '=' || r == ' ' || r == '\t' || r == '\n' || r == '"' {
return strconv.Quote(value)
}
}
return value
}

View File

@@ -0,0 +1,36 @@
package observability
import (
"strings"
"testing"
)
func TestFormatLogEntryFormatsHTTPRequestCompactly(t *testing.T) {
line := formatLogEntry(LogEvent{
TS: "2026-06-11T12:57:39.557972Z",
Level: LogLevelInfo,
Event: "http_request",
Fields: map[string]any{
"bytes": 56198,
"client_ip": "127.0.0.1",
"duration_ms": 9.419,
"method": "GET",
"path": "/api/catalog/top-pick",
"status": 200,
},
})
checks := []string{
"2026-06-11T12:57:39.557972Z INFO http 200 GET /api/catalog/top-pick 9.419ms 54.9KB",
}
for _, check := range checks {
if !strings.Contains(line, check) {
t.Fatalf("line %q missing %q", line, check)
}
}
if strings.Contains(line, "client_ip=") {
t.Fatalf("line should omit loopback ip: %q", line)
}
}

View File

@@ -0,0 +1,32 @@
package observability
import "context"
type requestContextKey struct{}
type RequestContext struct {
ID string
Path string
Route string
}
func WithRequestContext(ctx context.Context, requestID string, path string, route string) context.Context {
if ctx == nil {
return nil
}
return context.WithValue(ctx, requestContextKey{}, RequestContext{
ID: requestID,
Path: path,
Route: route,
})
}
func RequestContextFromContext(ctx context.Context) (RequestContext, bool) {
if ctx == nil {
return RequestContext{}, false
}
requestContext, ok := ctx.Value(requestContextKey{}).(RequestContext)
return requestContext, ok
}

View File

@@ -6,6 +6,7 @@ import (
"fmt"
"io"
"mal/internal/domain"
"mal/internal/observability"
"mal/internal/server"
netutil "mal/pkg/net"
"net/http"
@@ -302,6 +303,8 @@ func (h *PlaybackHandler) HandleProxyStream(c *gin.Context) {
resp, err := h.streamingClient.Do(req)
if err != nil {
observability.ErrorContext(c.Request.Context(), "proxy_stream_upstream_failed", "playback", "", map[string]any{"target_url": targetURL}, err)
_ = c.Error(err).SetType(gin.ErrorTypePrivate)
c.Status(http.StatusBadGateway)
return
}
@@ -310,11 +313,15 @@ func (h *PlaybackHandler) HandleProxyStream(c *gin.Context) {
if isHLSPlaylistResponse(targetURL, resp.Header) {
body, err := io.ReadAll(io.LimitReader(resp.Body, netutil.MiB2))
if err != nil {
observability.ErrorContext(c.Request.Context(), "proxy_stream_playlist_read_failed", "playback", "", map[string]any{"target_url": targetURL}, err)
_ = c.Error(err).SetType(gin.ErrorTypePrivate)
c.Status(http.StatusBadGateway)
return
}
rewritten, err := h.rewriteHLSPlaylist(string(body), targetURL, referer)
if err != nil {
observability.ErrorContext(c.Request.Context(), "proxy_stream_playlist_rewrite_failed", "playback", "", map[string]any{"target_url": targetURL}, err)
_ = c.Error(err).SetType(gin.ErrorTypePrivate)
c.Status(http.StatusBadGateway)
return
}
@@ -484,6 +491,8 @@ func (h *PlaybackHandler) HandleProxySubtitle(c *gin.Context) {
resp, err := h.proxyClient.Do(req)
if err != nil {
observability.ErrorContext(c.Request.Context(), "proxy_subtitle_upstream_failed", "playback", "", map[string]any{"target_url": targetURL}, err)
_ = c.Error(err).SetType(gin.ErrorTypePrivate)
c.Status(http.StatusBadGateway)
return
}
@@ -491,6 +500,8 @@ func (h *PlaybackHandler) HandleProxySubtitle(c *gin.Context) {
body, err := io.ReadAll(io.LimitReader(resp.Body, netutil.MiB2))
if err != nil {
observability.ErrorContext(c.Request.Context(), "proxy_subtitle_read_failed", "playback", "", map[string]any{"target_url": targetURL}, err)
_ = c.Error(err).SetType(gin.ErrorTypePrivate)
c.Status(http.StatusBadGateway)
return
}

View File

@@ -31,23 +31,40 @@ func RequestLogger(metrics *observability.Metrics) gin.HandlerFunc {
level = observability.LogLevelWarn
}
observability.LogJSON(
fields := map[string]any{
"client_ip": c.ClientIP(),
"duration_ms": float64(duration.Microseconds()) / 1000,
"method": c.Request.Method,
"path": path,
"request_id": c.Writer.Header().Get(requestIDHeader),
"status": status,
}
privateErrors := c.Errors.ByType(gin.ErrorTypePrivate)
var logErr error
if len(privateErrors) > 0 {
logErr = privateErrors.Last().Err
}
if route != path {
fields["route"] = route
}
if query != "" {
fields["query"] = query
}
if size := c.Writer.Size(); size >= 0 {
fields["bytes"] = size
}
if errors := privateErrors.String(); errors != "" {
fields["errors"] = errors
}
observability.LogContext(
c.Request.Context(),
level,
"http_request",
"http",
"",
map[string]any{
"method": c.Request.Method,
"route": route,
"path": path,
"query": query,
"status": status,
"duration_ms": float64(duration.Microseconds()) / 1000,
"bytes": c.Writer.Size(),
"client_ip": c.ClientIP(),
"errors": c.Errors.ByType(gin.ErrorTypePrivate).String(),
},
nil,
c.Request.Method+" "+path,
fields,
logErr,
)
}
}

View File

@@ -0,0 +1,30 @@
package server
import (
"mal/internal/observability"
"strings"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
)
const requestIDHeader = "X-Request-ID"
func RequestContextMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
requestID := strings.TrimSpace(c.GetHeader(requestIDHeader))
if requestID == "" {
requestID = uuid.NewString()
}
path := c.Request.URL.Path
route := c.FullPath()
if route == "" {
route = path
}
c.Writer.Header().Set(requestIDHeader, requestID)
c.Request = c.Request.WithContext(observability.WithRequestContext(c.Request.Context(), requestID, path, route))
c.Next()
}
}

View File

@@ -27,7 +27,18 @@ func RespondError(c *gin.Context, status int, event string, component string, me
if status >= http.StatusInternalServerError {
level = observability.LogLevelError
}
observability.LogJSON(level, event, component, "", fields, err)
if fields == nil {
fields = make(map[string]any, 2)
}
if _, exists := fields["request_path"]; !exists {
fields["request_path"] = c.Request.URL.Path
}
if route := c.FullPath(); route != "" && route != c.Request.URL.Path {
if _, exists := fields["request_route"]; !exists {
fields["request_route"] = route
}
}
observability.LogContext(c.Request.Context(), level, event, component, "", fields, err)
RespondHTMLOrJSONError(c, status, message)
}

View File

@@ -27,7 +27,7 @@ func ProvideRouter(cfg config.Config, htmlRender render.HTMLRender, metrics *obs
gin.SetMode(cfg.GinMode)
}
r := gin.New()
r.Use(CORSMiddlewareWithConfig(cfg), audit.ContextMiddleware(), RequestLogger(metrics), gin.Recovery())
r.Use(CORSMiddlewareWithConfig(cfg), RequestContextMiddleware(), audit.ContextMiddleware(), RequestLogger(metrics), gin.Recovery())
r.Static("/static", "./static")
r.Static("/dist", "./dist")
r.GET("/metrics", gin.WrapH(metrics.Handler()))

View File

@@ -44,6 +44,7 @@ func TestRequestLoggerUsesMatchedRoute(t *testing.T) {
defer log.SetOutput(previousOutput)
router := gin.New()
router.Use(RequestContextMiddleware())
router.Use(RequestLogger(observability.NewMetrics()))
router.GET("/anime/:id", func(c *gin.Context) {
c.String(http.StatusOK, "ok")
@@ -59,13 +60,54 @@ func TestRequestLoggerUsesMatchedRoute(t *testing.T) {
}
logLine := string(output)
if !strings.Contains(logLine, `"event":"http_request"`) {
t.Fatalf("log line missing event: %s", logLine)
if !strings.Contains(logLine, " INFO http 200 GET /anime/1") {
t.Fatalf("log line missing compact http summary: %s", logLine)
}
if !strings.Contains(logLine, `"route":"/anime/:id"`) {
if !strings.Contains(logLine, " route=/anime/:id") {
t.Fatalf("log line missing route: %s", logLine)
}
if !strings.Contains(logLine, `"status":200`) {
t.Fatalf("log line missing status: %s", logLine)
if !strings.Contains(logLine, " request_id=") {
t.Fatalf("log line missing request id: %s", logLine)
}
if strings.Contains(logLine, `"GET /anime/1"`) {
t.Fatalf("log line should not duplicate request summary: %s", logLine)
}
if rec.Header().Get(requestIDHeader) == "" {
t.Fatalf("expected %s response header to be set", requestIDHeader)
}
}
func TestRespondErrorIncludesRequestContext(t *testing.T) {
gin.SetMode(gin.TestMode)
var logs bytes.Buffer
previousOutput := log.Writer()
log.SetOutput(&logs)
defer log.SetOutput(previousOutput)
router := gin.New()
router.Use(RequestContextMiddleware())
router.GET("/anime/:id", func(c *gin.Context) {
RespondError(c, http.StatusInternalServerError, "anime_lookup_failed", "anime", "failed", nil, context.DeadlineExceeded)
})
req := httptest.NewRequestWithContext(context.Background(), http.MethodGet, "/anime/1", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
output, err := io.ReadAll(&logs)
if err != nil {
t.Fatalf("read logs: %v", err)
}
logLine := string(output)
if !strings.Contains(logLine, " request_id=") {
t.Fatalf("log line missing request id: %s", logLine)
}
if !strings.Contains(logLine, " request_path=/anime/1") {
t.Fatalf("log line missing request path: %s", logLine)
}
if !strings.Contains(logLine, " request_route=/anime/:id") {
t.Fatalf("log line missing request route: %s", logLine)
}
}