diff --git a/pkg/net/document.go b/pkg/net/document.go index 8d664e3..b063e22 100644 --- a/pkg/net/document.go +++ b/pkg/net/document.go @@ -9,6 +9,16 @@ import ( "github.com/PuerkitoBio/goquery" ) +func responseURL(response *http.Response, fallbackRequest *http.Request) string { + if response != nil && response.Request != nil && response.Request.URL != nil { + return response.Request.URL.String() + } + if fallbackRequest != nil && fallbackRequest.URL != nil { + return fallbackRequest.URL.String() + } + return "" +} + func FetchHTMLDocument( ctx context.Context, httpClient *http.Client, @@ -37,13 +47,13 @@ func FetchHTMLDocument( if response.StatusCode != http.StatusOK { body, _ := io.ReadAll(io.LimitReader(response.Body, Bytes512)) - return nil, response.Request.URL.String(), buildStatusError(response, body) + return nil, responseURL(response, request), buildStatusError(response, body) } document, err := goquery.NewDocumentFromReader(response.Body) if err != nil { - return nil, response.Request.URL.String(), fmt.Errorf("failed to parse html: %w", err) + return nil, responseURL(response, request), fmt.Errorf("failed to parse html: %w", err) } - return document, response.Request.URL.String(), nil + return document, responseURL(response, request), nil } diff --git a/pkg/net/document_test.go b/pkg/net/document_test.go new file mode 100644 index 0000000..aef89cb --- /dev/null +++ b/pkg/net/document_test.go @@ -0,0 +1,41 @@ +package netutil + +import ( + "context" + "io" + "net/http" + "strings" + "testing" +) + +type roundTripFunc func(*http.Request) (*http.Response, error) + +func (f roundTripFunc) RoundTrip(request *http.Request) (*http.Response, error) { + return f(request) +} + +func TestFetchHTMLDocumentFallsBackToOriginalURLWhenResponseRequestMissing(t *testing.T) { + client := &http.Client{ + Transport: roundTripFunc(func(request *http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: http.StatusOK, + Header: make(http.Header), + Body: io.NopCloser(strings.NewReader("
ok
")), + }, nil + }), + } + + url := "https://example.test/watch-order" + document, finalURL, err := FetchHTMLDocument(context.Background(), client, url, nil, nil) + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + + if finalURL != url { + t.Fatalf("expected final url %q, got %q", url, finalURL) + } + + if got := strings.TrimSpace(document.Find("main").Text()); got != "ok" { + t.Fatalf("expected document text ok, got %q", got) + } +}