diff --git a/internal/server/cors.go b/internal/server/cors.go new file mode 100644 index 0000000..02e716e --- /dev/null +++ b/internal/server/cors.go @@ -0,0 +1,45 @@ +package server + +import ( + "net/http" + "os" + "strings" + + "github.com/gin-gonic/gin" +) + +func CORSMiddleware() gin.HandlerFunc { + allowAll := os.Getenv("MAL_CORS_ALLOW_ALL") == "1" + + return func(c *gin.Context) { + origin := c.GetHeader("Origin") + if origin != "" && (allowAll || isAllowedOrigin(origin)) { + c.Header("Access-Control-Allow-Origin", origin) + c.Header("Vary", "Origin") + c.Header("Access-Control-Allow-Methods", "GET,POST,DELETE,OPTIONS") + c.Header("Access-Control-Allow-Headers", "Authorization,Content-Type") + c.Header("Access-Control-Max-Age", "600") + } + + if c.Request.Method == http.MethodOptions && strings.HasPrefix(c.Request.URL.Path, "/api/") { + c.Status(http.StatusNoContent) + c.Abort() + return + } + + c.Next() + } +} + +func isAllowedOrigin(origin string) bool { + if strings.HasPrefix(origin, "moz-extension://") { + return true + } + if strings.HasPrefix(origin, "http://localhost:") || strings.HasPrefix(origin, "https://localhost:") { + return true + } + if strings.HasPrefix(origin, "http://127.0.0.1:") || strings.HasPrefix(origin, "https://127.0.0.1:") { + return true + } + return false +} diff --git a/internal/server/server.go b/internal/server/server.go index ebe38d4..341a9ff 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -22,7 +22,7 @@ func ProvideRouter(htmlRender render.HTMLRender) *gin.Engine { gin.SetMode(gin.ReleaseMode) } r := gin.New() - r.Use(gin.Logger(), gin.Recovery()) + r.Use(CORSMiddleware(), gin.Logger(), gin.Recovery()) r.Static("/static", "./static") r.Static("/dist", "./dist") r.HTMLRender = htmlRender