diff --git a/integrations/watchorder/watch_order_test.go b/integrations/watchorder/watch_order_test.go index 7c61b16..99696d3 100644 --- a/integrations/watchorder/watch_order_test.go +++ b/integrations/watchorder/watch_order_test.go @@ -3,20 +3,23 @@ package watchorder import ( "context" "errors" + "io" "net/http" - "net/http/httptest" "strings" "testing" "time" ) -func testServer(body string) *httptest.Server { - handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "text/html; charset=utf-8") - _, _ = w.Write([]byte(body)) - }) - - return httptest.NewServer(handler) +func mockResponse(status int, headers map[string]string, body string) *http.Response { + h := make(http.Header, len(headers)) + for k, v := range headers { + h.Set(k, v) + } + return &http.Response{ + StatusCode: status, + Header: h, + Body: io.NopCloser(strings.NewReader(body)), + } } func testHTMLWithMetadata() string { @@ -55,11 +58,18 @@ func testHTMLEmptyRows() string { } func TestFetchWatchOrder_OutputShape(t *testing.T) { - server := testServer(testHTMLWithMetadata()) - defer server.Close() + client := &http.Client{ + Timeout: time.Second, + Transport: roundTripFunc(func(request *http.Request) (*http.Response, error) { + if request.URL.RawQuery == "/tools/watch_order/id/442" { + return mockResponse(http.StatusOK, map[string]string{"Content-Type": "text/html; charset=utf-8"}, testHTMLWithMetadata()), nil + } + return mockResponse(http.StatusNotFound, nil, "not found"), nil + }), + } - url := server.URL + "/?/tools/watch_order/id/442" - result, err := FetchWatchOrder(context.Background(), &http.Client{Timeout: time.Second}, url) + url := "https://chiaki.site/?/tools/watch_order/id/442" + result, err := FetchWatchOrder(context.Background(), client, url) if err != nil { t.Fatalf("expected no error, got %v", err) } @@ -88,11 +98,18 @@ func TestFetchWatchOrder_OutputShape(t *testing.T) { } func TestFetchWatchOrder_NoRowsReturnsEmpty(t *testing.T) { - server := testServer(testHTMLEmptyRows()) - defer server.Close() + client := &http.Client{ + Timeout: time.Second, + Transport: roundTripFunc(func(request *http.Request) (*http.Response, error) { + if request.URL.RawQuery == "/tools/watch_order/id/1535" { + return mockResponse(http.StatusOK, map[string]string{"Content-Type": "text/html; charset=utf-8"}, testHTMLEmptyRows()), nil + } + return mockResponse(http.StatusNotFound, nil, "not found"), nil + }), + } - url := server.URL + "/?/tools/watch_order/id/1535" - result, err := FetchWatchOrder(context.Background(), &http.Client{Timeout: time.Second}, url) + url := "https://chiaki.site/?/tools/watch_order/id/1535" + result, err := FetchWatchOrder(context.Background(), client, url) if err != nil { t.Fatalf("expected no error, got %v", err) } @@ -121,37 +138,18 @@ Jujutsu Kaisen 0 Dec 24, 2021 | Movie | 1ep × 1hr. 44min. | ★8.36 | [](https://myanimelist.net/anime/48561) ` - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if strings.HasPrefix(r.URL.Path, "/http/") { - w.WriteHeader(http.StatusOK) - _, _ = w.Write([]byte(proxyPayload)) - return - } - - w.WriteHeader(http.StatusForbidden) - _, _ = w.Write([]byte("blocked")) - })) - defer server.Close() - - transport := http.DefaultTransport testClient := &http.Client{ Timeout: time.Second, Transport: roundTripFunc(func(request *http.Request) (*http.Response, error) { - if strings.HasPrefix(request.URL.Host, "r.jina.ai") { - proxyURL := server.URL + "/http/" + strings.TrimPrefix(request.URL.Path, "/") - proxyRequest, err := http.NewRequestWithContext(request.Context(), request.Method, proxyURL, nil) - if err != nil { - return nil, err - } - return transport.RoundTrip(proxyRequest) + switch { + case request.URL.Host == "chiaki.site": + return mockResponse(http.StatusForbidden, map[string]string{"Content-Type": "text/html; charset=utf-8"}, "blocked"), nil + case request.URL.Host == "r.jina.ai": + // Proxy response is plain text/markdown. + return mockResponse(http.StatusOK, map[string]string{"Content-Type": "text/plain; charset=utf-8"}, proxyPayload), nil + default: + return mockResponse(http.StatusNotFound, nil, "not found"), nil } - - blockedURL := server.URL + request.URL.Path - blockedRequest, err := http.NewRequestWithContext(request.Context(), request.Method, blockedURL, nil) - if err != nil { - return nil, err - } - return transport.RoundTrip(blockedRequest) }), } @@ -174,17 +172,19 @@ Jujutsu Kaisen 0 } func TestFetchWatchOrder_HTTPStatusErrorIncludesContext(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Server", "cloudflare") - w.Header().Set("CF-Ray", "abc123") - w.Header().Set("Content-Type", "text/html; charset=utf-8") - w.WriteHeader(http.StatusForbidden) - _, _ = w.Write([]byte("access denied")) - })) - defer server.Close() + client := &http.Client{ + Timeout: time.Second, + Transport: roundTripFunc(func(request *http.Request) (*http.Response, error) { + return mockResponse(http.StatusForbidden, map[string]string{ + "Server": "cloudflare", + "CF-Ray": "abc123", + "Content-Type": "text/html; charset=utf-8", + }, "access denied"), nil + }), + } - url := server.URL + "/?/tools/watch_order/id/1" - _, err := fetchDocument(context.Background(), &http.Client{Timeout: time.Second}, url) + url := "https://chiaki.site/?/tools/watch_order/id/1" + _, err := fetchDocument(context.Background(), client, url) if err == nil { t.Fatalf("expected error, got nil") } diff --git a/internal/server/server.go b/internal/server/server.go index 125f6ff..ebe38d4 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -5,6 +5,7 @@ import ( "log" "net/http" "os" + "time" "github.com/gin-gonic/gin" "github.com/gin-gonic/gin/render" @@ -34,28 +35,39 @@ func RunServer(lifecycle fx.Lifecycle, r *gin.Engine) { port = "3000" } - srv := &http.Server{ - Addr: ":" + port, - Handler: r, - } + srv := newHTTPServer(":"+port, r) lifecycle.Append(fx.Hook{ OnStart: func(context.Context) error { log.Printf("Starting server on http://localhost:%s", port) go func() { if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed { - log.Fatalf("listen: %s\n", err) + // Avoid exiting the process from a goroutine; let the process supervisor handle restarts. + log.Printf("server listen error: %s", err) } }() return nil }, OnStop: func(ctx context.Context) error { log.Println("Shutting down server...") + ctx, cancel := context.WithTimeout(ctx, 10*time.Second) + defer cancel() return srv.Shutdown(ctx) }, }) } +func newHTTPServer(addr string, handler http.Handler) *http.Server { + return &http.Server{ + Addr: addr, + Handler: handler, + ReadHeaderTimeout: 5 * time.Second, + ReadTimeout: 30 * time.Second, + WriteTimeout: 30 * time.Second, + IdleTimeout: 2 * time.Minute, + } +} + // RouteRegister is an interface that modules can implement to register their routes. type RouteRegister interface { Register(r *gin.Engine) diff --git a/internal/server/server_test.go b/internal/server/server_test.go new file mode 100644 index 0000000..ec923c4 --- /dev/null +++ b/internal/server/server_test.go @@ -0,0 +1,28 @@ +package server + +import ( + "net/http" + "testing" + "time" +) + +func TestNewHTTPServer_TimeoutsAndAddr(t *testing.T) { + srv := newHTTPServer(":1234", http.NewServeMux()) + + if srv.Addr != ":1234" { + t.Fatalf("Addr: got %q want %q", srv.Addr, ":1234") + } + if srv.ReadHeaderTimeout != 5*time.Second { + t.Fatalf("ReadHeaderTimeout: got %s want %s", srv.ReadHeaderTimeout, 5*time.Second) + } + if srv.ReadTimeout != 30*time.Second { + t.Fatalf("ReadTimeout: got %s want %s", srv.ReadTimeout, 30*time.Second) + } + if srv.WriteTimeout != 30*time.Second { + t.Fatalf("WriteTimeout: got %s want %s", srv.WriteTimeout, 30*time.Second) + } + if srv.IdleTimeout != 2*time.Minute { + t.Fatalf("IdleTimeout: got %s want %s", srv.IdleTimeout, 2*time.Minute) + } +} +