diff --git a/internal/audit/service_test.go b/internal/audit/service_test.go index 5bb93ac..c92f2c2 100644 --- a/internal/audit/service_test.go +++ b/internal/audit/service_test.go @@ -2,6 +2,7 @@ package audit_test import ( "context" + "database/sql" "encoding/json" "os" "testing" @@ -13,29 +14,9 @@ import ( ) func TestRecordInsertsAuditLog(t *testing.T) { - tmp, err := os.CreateTemp("", "mal-audit-*.db") - if err != nil { - t.Fatalf("CreateTemp: %v", err) - } - _ = tmp.Close() - t.Cleanup(func() { _ = os.Remove(tmp.Name()) }) - - sqlDB, err := db.Open(tmp.Name()) - if err != nil { - t.Fatalf("db.Open: %v", err) - } - t.Cleanup(func() { _ = sqlDB.Close() }) - - if err := database.RunMigrations(sqlDB); err != nil { - t.Fatalf("RunMigrations: %v", err) - } - - queries := db.New(sqlDB) - svc := audit.NewAuditService(queries) - - if _, err := sqlDB.Exec("INSERT INTO user (id, username, password_hash) VALUES (?, ?, ?)", "user-1", "test", "hash"); err != nil { - t.Fatalf("insert user: %v", err) - } + sqlDB := openTestDB(t) + svc := audit.NewAuditService(db.New(sqlDB)) + insertTestUser(t, sqlDB, "user-1") ctx := audit.WithRequestInfo(context.Background(), "127.0.0.1", "unit-test") metadata, err := json.Marshal(struct { @@ -55,7 +36,54 @@ func TestRecordInsertsAuditLog(t *testing.T) { t.Fatalf("Record: %v", err) } - rows, err := sqlDB.Query("SELECT action, resource_type, resource_id, ip, user_agent, metadata_json FROM audit_log WHERE user_id = ?", "user-1") + auditRow := queryAuditRow(t, sqlDB, "user-1") + assertAuditRow(t, auditRow) +} + +type auditRow struct { + action string + resourceType string + resourceID string + ip string + userAgent string + metadataJSON string +} + +func openTestDB(t *testing.T) *sql.DB { + t.Helper() + + tmp, err := os.CreateTemp("", "mal-audit-*.db") + if err != nil { + t.Fatalf("CreateTemp: %v", err) + } + _ = tmp.Close() + t.Cleanup(func() { _ = os.Remove(tmp.Name()) }) + + sqlDB, err := db.Open(tmp.Name()) + if err != nil { + t.Fatalf("db.Open: %v", err) + } + t.Cleanup(func() { _ = sqlDB.Close() }) + + if err := database.RunMigrations(sqlDB); err != nil { + t.Fatalf("RunMigrations: %v", err) + } + + return sqlDB +} + +func insertTestUser(t *testing.T, sqlDB *sql.DB, userID string) { + t.Helper() + + if _, err := sqlDB.ExecContext(context.Background(), "INSERT INTO user (id, username, password_hash) VALUES (?, ?, ?)", userID, "test", "hash"); err != nil { + t.Fatalf("insert user: %v", err) + } +} + +func queryAuditRow(t *testing.T, sqlDB *sql.DB, userID string) auditRow { + t.Helper() + + rows, err := sqlDB.QueryContext(context.Background(), "SELECT action, resource_type, resource_id, ip, user_agent, metadata_json FROM audit_log WHERE user_id = ?", userID) if err != nil { t.Fatalf("Query: %v", err) } @@ -65,18 +93,24 @@ func TestRecordInsertsAuditLog(t *testing.T) { t.Fatalf("expected audit row") } - var action, resourceType, resourceID, ip, userAgent, metadataJSON string - if err := rows.Scan(&action, &resourceType, &resourceID, &ip, &userAgent, &metadataJSON); err != nil { + var row auditRow + if err := rows.Scan(&row.action, &row.resourceType, &row.resourceID, &row.ip, &row.userAgent, &row.metadataJSON); err != nil { t.Fatalf("Scan: %v", err) } - if action != "test_action" || resourceType != "thing" || resourceID != "123" { - t.Fatalf("unexpected row action=%q resourceType=%q resourceID=%q", action, resourceType, resourceID) + return row +} + +func assertAuditRow(t *testing.T, row auditRow) { + t.Helper() + + if row.action != "test_action" || row.resourceType != "thing" || row.resourceID != "123" { + t.Fatalf("unexpected row action=%q resourceType=%q resourceID=%q", row.action, row.resourceType, row.resourceID) } - if ip != "127.0.0.1" || userAgent != "unit-test" { - t.Fatalf("unexpected request info ip=%q userAgent=%q", ip, userAgent) + if row.ip != "127.0.0.1" || row.userAgent != "unit-test" { + t.Fatalf("unexpected request info ip=%q userAgent=%q", row.ip, row.userAgent) } - if metadataJSON == "" || metadataJSON == "null" { - t.Fatalf("expected metadata_json, got %q", metadataJSON) + if row.metadataJSON == "" || row.metadataJSON == "null" { + t.Fatalf("expected metadata_json, got %q", row.metadataJSON) } }