fix: replace nil context with context.TODO

This commit is contained in:
2026-06-11 17:11:47 +02:00
parent ed90b5c7aa
commit 25471e0bd5
18 changed files with 798 additions and 52 deletions

View File

@@ -0,0 +1,52 @@
package observability
import (
"go.uber.org/fx/fxevent"
)
type fxLogger struct{}
func NewFxLogger() fxevent.Logger {
return fxLogger{}
}
func (fxLogger) LogEvent(event fxevent.Event) {
switch e := event.(type) {
case *fxevent.Provided:
if e.Err != nil {
Error("fx_provide_failed", "fx", "", map[string]any{"constructor": e.ConstructorName}, e.Err)
}
case *fxevent.Invoked:
if e.Err != nil {
Error("fx_invoke_failed", "fx", "", map[string]any{"function": e.FunctionName}, e.Err)
}
case *fxevent.Run:
if e.Err != nil {
Error("fx_run_failed", "fx", "", map[string]any{"function": e.Name, "kind": e.Kind}, e.Err)
}
case *fxevent.OnStartExecuted:
if e.Err != nil {
Error("fx_on_start_failed", "fx", "", map[string]any{"caller": e.CallerName, "function": e.FunctionName, "runtime": e.Runtime}, e.Err)
}
case *fxevent.OnStopExecuted:
if e.Err != nil {
Error("fx_on_stop_failed", "fx", "", map[string]any{"caller": e.CallerName, "function": e.FunctionName, "runtime": e.Runtime}, e.Err)
}
case *fxevent.Started:
if e.Err != nil {
Error("fx_start_failed", "fx", "", nil, e.Err)
}
case *fxevent.Stopped:
if e.Err != nil {
Error("fx_stop_failed", "fx", "", nil, e.Err)
}
case *fxevent.RollingBack:
if e.StartErr != nil {
Error("fx_rollback_start", "fx", "", nil, e.StartErr)
}
case *fxevent.RolledBack:
if e.Err != nil {
Error("fx_rollback_failed", "fx", "", nil, e.Err)
}
}
}

View File

@@ -1,5 +1,7 @@
package observability
import "context"
// Small helpers to keep logging consistent and low-friction across the codebase.
func Info(event string, component string, message string, fields map[string]any) {
@@ -10,6 +12,14 @@ func Warn(event string, component string, message string, fields map[string]any,
LogJSON(LogLevelWarn, event, component, message, fields, err)
}
func WarnContext(ctx context.Context, event string, component string, message string, fields map[string]any, err error) {
LogContext(ctx, LogLevelWarn, event, component, message, fields, err)
}
func Error(event string, component string, message string, fields map[string]any, err error) {
LogJSON(LogLevelError, event, component, message, fields, err)
}
func ErrorContext(ctx context.Context, event string, component string, message string, fields map[string]any, err error) {
LogContext(ctx, LogLevelError, event, component, message, fields, err)
}

View File

@@ -0,0 +1,35 @@
package observability
import (
"errors"
"fmt"
"strings"
"testing"
)
func TestWarnEnrichesSourceAndErrorContext(t *testing.T) {
fields := enrichFields(LogLevelWarn, map[string]any{"anime_id": 123}, wrappedError())
if fields["anime_id"] != 123 {
t.Fatalf("expected existing field to survive, got %#v", fields["anime_id"])
}
source, ok := fields["source"].(string)
if !ok || source == "" {
t.Fatalf("expected source field, got %#v", fields["source"])
}
errorType, ok := fields["error_type"].(string)
if !ok || errorType == "" {
t.Fatalf("expected error_type field, got %#v", fields["error_type"])
}
chain, ok := fields["error_chain"].(string)
if !ok || !strings.Contains(chain, "query anime") || !strings.Contains(chain, "db timeout") {
t.Fatalf("expected wrapped error chain, got %#v", fields["error_chain"])
}
}
func wrappedError() error {
return fmt.Errorf("query anime: %w", errors.New("db timeout"))
}

View File

@@ -2,11 +2,30 @@
package observability
import (
"encoding/json"
"context"
"errors"
"fmt"
"log"
"net"
"os"
"path/filepath"
"reflect"
"runtime"
"sort"
"strconv"
"strings"
"time"
)
const (
ansiReset = "\x1b[0m"
ansiBlue = "\x1b[36m"
ansiYellow = "\x1b[33m"
ansiRed = "\x1b[31m"
)
var colorLogs = shouldColorLogs()
type LogLevel string
const (
@@ -25,11 +44,17 @@ type LogEvent struct {
Component string `json:"component,omitempty"`
}
func init() {
log.SetFlags(0)
}
func LogJSON(level LogLevel, event string, component string, message string, fields map[string]any, err error) {
errorValue := ""
if err != nil {
errorValue = err.Error()
}
LogContext(context.TODO(), level, event, component, message, fields, err)
}
func LogContext(ctx context.Context, level LogLevel, event string, component string, message string, fields map[string]any, err error) {
fields = enrichFields(level, fields, err)
fields = enrichRequestFields(ctx, fields)
entry := LogEvent{
TS: time.Now().UTC().Format(time.RFC3339Nano),
@@ -37,23 +62,370 @@ func LogJSON(level LogLevel, event string, component string, message string, fie
Event: event,
Message: message,
Fields: fields,
Error: errorValue,
Component: component,
}
// Best-effort. If encoding fails, fall back to a minimal line.
bytes, marshalErr := json.Marshal(entry)
if marshalErr != nil {
// Keep output JSON-only even on failures by constructing a minimal entry.
// Marshal individual strings to ensure proper escaping.
tsBytes, _ := json.Marshal(time.Now().UTC().Format(time.RFC3339Nano))
levelBytes, _ := json.Marshal(level)
eventBytes, _ := json.Marshal("log_marshal_failed")
componentBytes, _ := json.Marshal(component)
errBytes, _ := json.Marshal(marshalErr.Error())
log.Printf(`{"ts":%s,"level":%s,"event":%s,"component":%s,"error":%s}`, tsBytes, levelBytes, eventBytes, componentBytes, errBytes)
if err != nil {
entry.Error = err.Error()
}
log.Print(formatLogEntry(entry))
}
func enrichRequestFields(ctx context.Context, fields map[string]any) map[string]any {
requestContext, ok := RequestContextFromContext(ctx)
if !ok {
return fields
}
enriched := cloneFields(fields)
if enriched == nil {
enriched = make(map[string]any, 3)
}
if requestContext.ID != "" {
if _, exists := enriched["request_id"]; !exists {
enriched["request_id"] = requestContext.ID
}
}
if requestContext.Path != "" {
if _, exists := enriched["request_path"]; !exists {
enriched["request_path"] = requestContext.Path
}
}
if requestContext.Route != "" && requestContext.Route != requestContext.Path {
if _, exists := enriched["request_route"]; !exists {
enriched["request_route"] = requestContext.Route
}
}
return enriched
}
func enrichFields(level LogLevel, fields map[string]any, err error) map[string]any {
if level == LogLevelInfo {
return fields
}
enriched := cloneFields(fields)
if enriched == nil {
enriched = make(map[string]any, 3)
}
if _, exists := enriched["source"]; !exists {
if source := callerSource(); source != "" {
enriched["source"] = source
}
}
if err != nil {
if _, exists := enriched["error_type"]; !exists {
if errorType := formatErrorType(err); errorType != "" {
enriched["error_type"] = errorType
}
}
if _, exists := enriched["error_chain"]; !exists {
if chain := formatErrorChain(err); chain != "" {
enriched["error_chain"] = chain
}
}
}
return enriched
}
func callerSource() string {
pcs := make([]uintptr, 8)
n := runtime.Callers(3, pcs)
frames := runtime.CallersFrames(pcs[:n])
for {
frame, more := frames.Next()
if !strings.Contains(frame.File, "/internal/observability/") {
return filepath.Base(frame.File) + ":" + strconv.Itoa(frame.Line)
}
if !more {
return ""
}
}
}
func formatErrorType(err error) string {
errType := reflect.TypeOf(err)
if errType == nil {
return ""
}
return errType.String()
}
func formatErrorChain(err error) string {
parts := make([]string, 0, 4)
for current := err; current != nil; current = errors.Unwrap(current) {
parts = append(parts, current.Error())
if len(parts) == 4 {
break
}
}
if len(parts) <= 1 {
return ""
}
return strings.Join(parts, " -> ")
}
func formatLogEntry(entry LogEvent) string {
if entry.Event == "http_request" {
return formatHTTPRequestLog(entry)
}
parts := []string{entry.TS, formatLogLevel(entry.Level), entry.Event}
if entry.Component != "" {
parts = append(parts, "component="+entry.Component)
}
if entry.Message != "" {
parts = append(parts, quoteIfNeeded(entry.Message))
}
if len(entry.Fields) > 0 {
keys := make([]string, 0, len(entry.Fields))
for key := range entry.Fields {
keys = append(keys, key)
}
sort.Strings(keys)
for _, key := range keys {
parts = append(parts, key+"="+formatFieldValue(entry.Fields[key]))
}
}
if entry.Error != "" {
parts = append(parts, "error="+quoteIfNeeded(entry.Error))
}
return strings.Join(parts, " ")
}
func formatHTTPRequestLog(entry LogEvent) string {
fields := cloneFields(entry.Fields)
status := popField(fields, "status")
method := popField(fields, "method")
path := popField(fields, "path")
duration := popField(fields, "duration_ms")
bytes := popField(fields, "bytes")
route := popField(fields, "route")
query := popField(fields, "query")
clientIP := popField(fields, "client_ip")
parts := []string{entry.TS, formatLogLevel(entry.Level), "http"}
if status != "" {
parts = append(parts, status)
}
if method != "" || path != "" {
parts = append(parts, strings.TrimSpace(method+" "+path))
}
if duration != "" {
parts = append(parts, duration)
}
if bytes != "" {
parts = append(parts, bytes)
}
if route != "" {
parts = append(parts, "route="+route)
}
if query != "" {
parts = append(parts, "query="+quoteIfNeeded(query))
}
if clientIP != "" && !isLocalClientIP(clientIP) {
parts = append(parts, "ip="+clientIP)
}
appendSortedFields(&parts, fields)
if entry.Error != "" {
parts = append(parts, "error="+quoteIfNeeded(entry.Error))
}
return strings.Join(parts, " ")
}
func appendSortedFields(parts *[]string, fields map[string]any) {
if len(fields) == 0 {
return
}
log.Print(string(bytes))
keys := make([]string, 0, len(fields))
for key := range fields {
keys = append(keys, key)
}
sort.Strings(keys)
for _, key := range keys {
*parts = append(*parts, key+"="+formatFieldValue(fields[key]))
}
}
func cloneFields(fields map[string]any) map[string]any {
if len(fields) == 0 {
return nil
}
copyFields := make(map[string]any, len(fields))
for key, value := range fields {
copyFields[key] = value
}
return copyFields
}
func popField(fields map[string]any, key string) string {
if len(fields) == 0 {
return ""
}
value, ok := fields[key]
if !ok {
return ""
}
delete(fields, key)
return formatInlineField(key, value)
}
func formatInlineField(key string, value any) string {
switch key {
case "status":
return fmt.Sprint(value)
case "duration_ms":
return formatDurationMillis(value)
case "bytes":
return formatBytes(value)
default:
if text, ok := value.(string); ok {
return text
}
return fmt.Sprint(value)
}
}
func formatDurationMillis(value any) string {
ms, ok := toFloat64(value)
if !ok {
return fmt.Sprint(value)
}
return strconv.FormatFloat(ms, 'f', -1, 64) + "ms"
}
func formatBytes(value any) string {
bytesValue, ok := toFloat64(value)
if !ok {
return fmt.Sprint(value)
}
if bytesValue < 1024 {
return strconv.FormatFloat(bytesValue, 'f', -1, 64) + "B"
}
if bytesValue < 1024*1024 {
return strconv.FormatFloat(bytesValue/1024, 'f', 1, 64) + "KB"
}
return strconv.FormatFloat(bytesValue/(1024*1024), 'f', 1, 64) + "MB"
}
func toFloat64(value any) (float64, bool) {
switch v := value.(type) {
case int:
return float64(v), true
case int32:
return float64(v), true
case int64:
return float64(v), true
case float32:
return float64(v), true
case float64:
return v, true
default:
return 0, false
}
}
func isLocalClientIP(value string) bool {
parsed := net.ParseIP(value)
if parsed == nil {
return false
}
return parsed.IsLoopback()
}
func formatLogLevel(level LogLevel) string {
if colorLogs {
switch level {
case LogLevelWarn:
return ansiYellow + "WARN" + ansiReset
case LogLevelError:
return ansiRed + "ERROR" + ansiReset
default:
return ansiBlue + "INFO" + ansiReset
}
}
switch level {
case LogLevelWarn:
return "WARN"
case LogLevelError:
return "ERROR"
default:
return "INFO"
}
}
func shouldColorLogs() bool {
if strings.TrimSpace(os.Getenv("NO_COLOR")) != "" {
return false
}
if strings.EqualFold(strings.TrimSpace(os.Getenv("TERM")), "dumb") {
return false
}
info, err := os.Stderr.Stat()
if err != nil {
return false
}
return info.Mode()&os.ModeCharDevice != 0
}
func formatFieldValue(value any) string {
switch v := value.(type) {
case string:
return quoteIfNeeded(v)
case time.Duration:
return v.String()
case float32:
return strconv.FormatFloat(float64(v), 'f', -1, 32)
case float64:
return strconv.FormatFloat(v, 'f', -1, 64)
case fmt.Stringer:
return quoteIfNeeded(v.String())
default:
return quoteIfNeeded(fmt.Sprint(value))
}
}
func quoteIfNeeded(value string) string {
if value == "" {
return `""`
}
for _, r := range value {
if r == '=' || r == ' ' || r == '\t' || r == '\n' || r == '"' {
return strconv.Quote(value)
}
}
return value
}

View File

@@ -0,0 +1,36 @@
package observability
import (
"strings"
"testing"
)
func TestFormatLogEntryFormatsHTTPRequestCompactly(t *testing.T) {
line := formatLogEntry(LogEvent{
TS: "2026-06-11T12:57:39.557972Z",
Level: LogLevelInfo,
Event: "http_request",
Fields: map[string]any{
"bytes": 56198,
"client_ip": "127.0.0.1",
"duration_ms": 9.419,
"method": "GET",
"path": "/api/catalog/top-pick",
"status": 200,
},
})
checks := []string{
"2026-06-11T12:57:39.557972Z INFO http 200 GET /api/catalog/top-pick 9.419ms 54.9KB",
}
for _, check := range checks {
if !strings.Contains(line, check) {
t.Fatalf("line %q missing %q", line, check)
}
}
if strings.Contains(line, "client_ip=") {
t.Fatalf("line should omit loopback ip: %q", line)
}
}

View File

@@ -0,0 +1,32 @@
package observability
import "context"
type requestContextKey struct{}
type RequestContext struct {
ID string
Path string
Route string
}
func WithRequestContext(ctx context.Context, requestID string, path string, route string) context.Context {
if ctx == nil {
return nil
}
return context.WithValue(ctx, requestContextKey{}, RequestContext{
ID: requestID,
Path: path,
Route: route,
})
}
func RequestContextFromContext(ctx context.Context) (RequestContext, bool) {
if ctx == nil {
return RequestContext{}, false
}
requestContext, ok := ctx.Value(requestContextKey{}).(RequestContext)
return requestContext, ok
}