From dfe3c6b7d8ab38e1652deae28e0e00156978f894 Mon Sep 17 00:00:00 2001 From: mkelvers Date: Tue, 26 May 2026 16:14:20 +0200 Subject: [PATCH] feat: add audit service and request context middleware --- internal/audit/middleware.go | 30 ++++++++++ internal/audit/module.go | 12 ++++ internal/audit/service/service.go | 73 ++++++++++++++++++++++ internal/audit/service/service_test.go | 83 ++++++++++++++++++++++++++ 4 files changed, 198 insertions(+) create mode 100644 internal/audit/middleware.go create mode 100644 internal/audit/module.go create mode 100644 internal/audit/service/service.go create mode 100644 internal/audit/service/service_test.go diff --git a/internal/audit/middleware.go b/internal/audit/middleware.go new file mode 100644 index 0000000..3174862 --- /dev/null +++ b/internal/audit/middleware.go @@ -0,0 +1,30 @@ +package audit + +import ( + "net" + "strings" + + "github.com/gin-gonic/gin" + "mal/internal/auditctx" +) + +func ContextMiddleware() gin.HandlerFunc { + return func(c *gin.Context) { + ip := clientIP(c.ClientIP()) + userAgent := strings.TrimSpace(c.GetHeader("User-Agent")) + c.Request = c.Request.WithContext(auditctx.WithRequestInfo(c.Request.Context(), ip, userAgent)) + c.Next() + } +} + +func clientIP(ip string) string { + trimmed := strings.TrimSpace(ip) + if trimmed == "" { + return "" + } + parsed := net.ParseIP(trimmed) + if parsed == nil { + return trimmed + } + return parsed.String() +} diff --git a/internal/audit/module.go b/internal/audit/module.go new file mode 100644 index 0000000..d9a71f6 --- /dev/null +++ b/internal/audit/module.go @@ -0,0 +1,12 @@ +package audit + +import ( + "mal/internal/audit/service" + + "go.uber.org/fx" +) + +var Module = fx.Options( + fx.Provide(service.NewAuditService), +) + diff --git a/internal/audit/service/service.go b/internal/audit/service/service.go new file mode 100644 index 0000000..2c58258 --- /dev/null +++ b/internal/audit/service/service.go @@ -0,0 +1,73 @@ +package service + +import ( + "context" + "database/sql" + "encoding/json" + "errors" + "mal/internal/auditctx" + "mal/internal/db" + "mal/internal/domain" + "mal/internal/observability" + "strings" + + "github.com/google/uuid" +) + +type auditService struct { + queries *db.Queries +} + +func NewAuditService(queries *db.Queries) domain.AuditService { + return &auditService{queries: queries} +} + +func (s *auditService) Record(ctx context.Context, event domain.AuditEvent) error { + if s == nil || s.queries == nil { + return errors.New("audit service not configured") + } + action := strings.TrimSpace(event.Action) + if action == "" { + return errors.New("audit action missing") + } + + ip, userAgent := auditctx.RequestInfoFromContext(ctx) + if strings.TrimSpace(event.IP) != "" { + ip = event.IP + } + if strings.TrimSpace(event.UserAgent) != "" { + userAgent = event.UserAgent + } + + metadataJSON := event.MetadataJSON + if len(metadataJSON) == 0 { + metadataJSON = json.RawMessage("null") + } + + _, err := s.queries.CreateAuditLog(ctx, db.CreateAuditLogParams{ + ID: uuid.New().String(), + UserID: sql.NullString{String: strings.TrimSpace(event.UserID), Valid: strings.TrimSpace(event.UserID) != ""}, + Action: action, + ResourceType: sql.NullString{String: strings.TrimSpace(event.ResourceType), Valid: strings.TrimSpace(event.ResourceType) != ""}, + ResourceID: sql.NullString{String: strings.TrimSpace(event.ResourceID), Valid: strings.TrimSpace(event.ResourceID) != ""}, + Ip: sql.NullString{String: strings.TrimSpace(ip), Valid: strings.TrimSpace(ip) != ""}, + UserAgent: sql.NullString{String: strings.TrimSpace(userAgent), Valid: strings.TrimSpace(userAgent) != ""}, + MetadataJson: sql.NullString{String: string(metadataJSON), Valid: true}, + }) + if err != nil { + return err + } + + observability.Info( + "audit", + "audit", + action, + map[string]any{ + "user_id": event.UserID, + "resource_type": event.ResourceType, + "resource_id": event.ResourceID, + }, + ) + + return nil +} diff --git a/internal/audit/service/service_test.go b/internal/audit/service/service_test.go new file mode 100644 index 0000000..18b3efc --- /dev/null +++ b/internal/audit/service/service_test.go @@ -0,0 +1,83 @@ +package service_test + +import ( + "context" + "encoding/json" + "os" + "testing" + + "mal/internal/auditctx" + "mal/internal/audit/service" + "mal/internal/database" + "mal/internal/db" + "mal/internal/domain" +) + +func TestRecordInsertsAuditLog(t *testing.T) { + tmp, err := os.CreateTemp("", "mal-audit-*.db") + if err != nil { + t.Fatalf("CreateTemp: %v", err) + } + _ = tmp.Close() + t.Cleanup(func() { _ = os.Remove(tmp.Name()) }) + + sqlDB, err := db.Open(tmp.Name()) + if err != nil { + t.Fatalf("db.Open: %v", err) + } + t.Cleanup(func() { _ = sqlDB.Close() }) + + if err := database.RunMigrations(sqlDB); err != nil { + t.Fatalf("RunMigrations: %v", err) + } + + queries := db.New(sqlDB) + svc := service.NewAuditService(queries) + + if _, err := sqlDB.Exec("INSERT INTO user (id, username, password_hash) VALUES (?, ?, ?)", "user-1", "test", "hash"); err != nil { + t.Fatalf("insert user: %v", err) + } + + ctx := auditctx.WithRequestInfo(context.Background(), "127.0.0.1", "unit-test") + metadata, err := json.Marshal(struct { + Foo string `json:"foo"` + }{Foo: "bar"}) + if err != nil { + t.Fatalf("json.Marshal: %v", err) + } + + if err := svc.Record(ctx, domain.AuditEvent{ + UserID: "user-1", + Action: "test_action", + ResourceType: "thing", + ResourceID: "123", + MetadataJSON: metadata, + }); err != nil { + t.Fatalf("Record: %v", err) + } + + rows, err := sqlDB.Query("SELECT action, resource_type, resource_id, ip, user_agent, metadata_json FROM audit_log WHERE user_id = ?", "user-1") + if err != nil { + t.Fatalf("Query: %v", err) + } + defer func() { _ = rows.Close() }() + + if !rows.Next() { + t.Fatalf("expected audit row") + } + + var action, resourceType, resourceID, ip, userAgent, metadataJSON string + if err := rows.Scan(&action, &resourceType, &resourceID, &ip, &userAgent, &metadataJSON); err != nil { + t.Fatalf("Scan: %v", err) + } + + if action != "test_action" || resourceType != "thing" || resourceID != "123" { + t.Fatalf("unexpected row action=%q resourceType=%q resourceID=%q", action, resourceType, resourceID) + } + if ip != "127.0.0.1" || userAgent != "unit-test" { + t.Fatalf("unexpected request info ip=%q userAgent=%q", ip, userAgent) + } + if metadataJSON == "" || metadataJSON == "null" { + t.Fatalf("expected metadata_json, got %q", metadataJSON) + } +}