diff --git a/v2/pkg/engine/datasource/httpclient/nethttpclient.go b/v2/pkg/engine/datasource/httpclient/nethttpclient.go index 3d7adb6a6..64761d728 100644 --- a/v2/pkg/engine/datasource/httpclient/nethttpclient.go +++ b/v2/pkg/engine/datasource/httpclient/nethttpclient.go @@ -70,6 +70,8 @@ type responseContextKey struct{} type ResponseContext struct { StatusCode int + Request *http.Request + Response *http.Response } func InjectResponseContext(ctx context.Context) (context.Context, *ResponseContext) { @@ -77,9 +79,11 @@ func InjectResponseContext(ctx context.Context) (context.Context, *ResponseConte return context.WithValue(ctx, responseContextKey{}, value), value } -func setResponseStatusCode(ctx context.Context, statusCode int) { +func setResponseStatus(ctx context.Context, request *http.Request, response *http.Response) { if value, ok := ctx.Value(responseContextKey{}).(*ResponseContext); ok { - value.StatusCode = statusCode + value.StatusCode = response.StatusCode + value.Request = request + value.Response = response } } @@ -191,7 +195,7 @@ func makeHTTPRequest(client *http.Client, ctx context.Context, url, method, head } defer response.Body.Close() - setResponseStatusCode(ctx, response.StatusCode) + setResponseStatus(ctx, request, response) respReader, err := respBodyReader(response) if err != nil { diff --git a/v2/pkg/engine/resolve/loader.go b/v2/pkg/engine/resolve/loader.go index 18788ee02..44e1b2047 100644 --- a/v2/pkg/engine/resolve/loader.go +++ b/v2/pkg/engine/resolve/loader.go @@ -7,6 +7,7 @@ import ( "encoding/json" goerrors "errors" "fmt" + "net/http" "net/http/httptrace" "slices" "strconv" @@ -37,7 +38,73 @@ type LoaderHooks interface { // OnLoad is called before the fetch is executed OnLoad(ctx context.Context, ds DataSourceInfo) context.Context // OnFinished is called after the fetch has been executed and the response has been processed and merged - OnFinished(ctx context.Context, statusCode int, ds DataSourceInfo, err error) + OnFinished(ctx context.Context, ds DataSourceInfo, info *ResponseInfo) +} + +type DataSourceInfo struct { + ID string + Name string +} + +type ResponseInfo struct { + StatusCode int + Err error + // Request is the original request that was sent to the subgraph. This should only be used for reading purposes, + // in order to ensure there aren't memory conflicts, and the body will be nil, as it was read already. + Request *http.Request + // ResponseHeaders contains a clone of the headers of the response from the subgraph. + ResponseHeaders http.Header +} + +func newResponseInfo(res *result, subgraphError error) *ResponseInfo { + responseInfo := &ResponseInfo{StatusCode: res.statusCode, Err: goerrors.Join(res.err, subgraphError)} + if res.httpResponseContext != nil { + // We're using the response.Request here, because the body will be nil (since the response was read) and won't + // cause a memory leak. + if res.httpResponseContext.Response != nil { + responseInfo.Request = res.httpResponseContext.Response.Request + responseInfo.ResponseHeaders = res.httpResponseContext.Response.Header.Clone() + } else { + // In cases where the request errors, the response will be nil, and so we need to get the original request + responseInfo.Request = res.httpResponseContext.Request + } + } + + return responseInfo +} + +type result struct { + postProcessing PostProcessingConfiguration + out *bytes.Buffer + batchStats [][]int + fetchSkipped bool + nestedMergeItems []*result + + statusCode int + err error + ds DataSourceInfo + + authorizationRejected bool + authorizationRejectedReasons []string + + rateLimitRejected bool + rateLimitRejectedReason string + + // loaderHookContext used to share data between the OnLoad and OnFinished hooks + // It should be valid even when OnLoad isn't called + loaderHookContext context.Context + + httpResponseContext *httpclient.ResponseContext +} + +func (r *result) init(postProcessing PostProcessingConfiguration, info *FetchInfo) { + r.postProcessing = postProcessing + if info != nil { + r.ds = DataSourceInfo{ + ID: info.DataSourceID, + Name: info.DataSourceName, + } + } } func IsIntrospectionDataSource(dataSourceID string) bool { @@ -118,7 +185,9 @@ func (l *Loader) resolveParallel(nodes []*FetchTreeNode) error { for j := range results[i].nestedMergeItems { err = l.mergeResult(nodes[i].Item, results[i].nestedMergeItems[j], itemsItems[i][j:j+1]) if l.ctx.LoaderHooks != nil && results[i].nestedMergeItems[j].loaderHookContext != nil { - l.ctx.LoaderHooks.OnFinished(results[i].nestedMergeItems[j].loaderHookContext, results[i].nestedMergeItems[j].statusCode, results[i].nestedMergeItems[j].ds, goerrors.Join(results[i].nestedMergeItems[j].err, l.ctx.subgraphErrors)) + l.ctx.LoaderHooks.OnFinished(results[i].nestedMergeItems[j].loaderHookContext, + results[i].nestedMergeItems[j].ds, + newResponseInfo(results[i].nestedMergeItems[j], l.ctx.subgraphErrors)) } if err != nil { return errors.WithStack(err) @@ -126,8 +195,8 @@ func (l *Loader) resolveParallel(nodes []*FetchTreeNode) error { } } else { err = l.mergeResult(nodes[i].Item, results[i], itemsItems[i]) - if l.ctx.LoaderHooks != nil && results[i].loaderHookContext != nil { - l.ctx.LoaderHooks.OnFinished(results[i].loaderHookContext, results[i].statusCode, results[i].ds, goerrors.Join(results[i].err, l.ctx.subgraphErrors)) + if l.ctx.LoaderHooks != nil { + l.ctx.LoaderHooks.OnFinished(results[i].loaderHookContext, results[i].ds, newResponseInfo(results[i], l.ctx.subgraphErrors)) } if err != nil { return errors.WithStack(err) @@ -162,9 +231,10 @@ func (l *Loader) resolveSingle(item *FetchItem) error { return err } err = l.mergeResult(item, res, items) - if l.ctx.LoaderHooks != nil && res.loaderHookContext != nil { - l.ctx.LoaderHooks.OnFinished(res.loaderHookContext, res.statusCode, res.ds, goerrors.Join(res.err, l.ctx.subgraphErrors)) + if l.ctx.LoaderHooks != nil { + l.ctx.LoaderHooks.OnFinished(res.loaderHookContext, res.ds, newResponseInfo(res, l.ctx.subgraphErrors)) } + return err case *BatchEntityFetch: res := &result{ @@ -175,8 +245,8 @@ func (l *Loader) resolveSingle(item *FetchItem) error { return errors.WithStack(err) } err = l.mergeResult(item, res, items) - if l.ctx.LoaderHooks != nil && res.loaderHookContext != nil { - l.ctx.LoaderHooks.OnFinished(res.loaderHookContext, res.statusCode, res.ds, goerrors.Join(res.err, l.ctx.subgraphErrors)) + if l.ctx.LoaderHooks != nil { + l.ctx.LoaderHooks.OnFinished(res.loaderHookContext, res.ds, newResponseInfo(res, l.ctx.subgraphErrors)) } return err case *EntityFetch: @@ -188,8 +258,8 @@ func (l *Loader) resolveSingle(item *FetchItem) error { return errors.WithStack(err) } err = l.mergeResult(item, res, items) - if l.ctx.LoaderHooks != nil && res.loaderHookContext != nil { - l.ctx.LoaderHooks.OnFinished(res.loaderHookContext, res.statusCode, res.ds, goerrors.Join(res.err, l.ctx.subgraphErrors)) + if l.ctx.LoaderHooks != nil { + l.ctx.LoaderHooks.OnFinished(res.loaderHookContext, res.ds, newResponseInfo(res, l.ctx.subgraphErrors)) } return err case *ParallelListItemFetch: @@ -221,8 +291,8 @@ func (l *Loader) resolveSingle(item *FetchItem) error { } for i := range results { err = l.mergeResult(item, results[i], items[i:i+1]) - if l.ctx.LoaderHooks != nil && results[i].loaderHookContext != nil { - l.ctx.LoaderHooks.OnFinished(results[i].loaderHookContext, results[i].statusCode, results[i].ds, goerrors.Join(results[i].err, l.ctx.subgraphErrors)) + if l.ctx.LoaderHooks != nil { + l.ctx.LoaderHooks.OnFinished(results[i].loaderHookContext, results[i].ds, newResponseInfo(results[i], l.ctx.subgraphErrors)) } if err != nil { return errors.WithStack(err) @@ -527,43 +597,6 @@ func (l *Loader) mergeResult(fetchItem *FetchItem, res *result, items []*astjson return nil } -type DataSourceInfo struct { - ID string - Name string -} - -type result struct { - postProcessing PostProcessingConfiguration - out *bytes.Buffer - batchStats [][]int - fetchSkipped bool - nestedMergeItems []*result - - statusCode int - err error - ds DataSourceInfo - - authorizationRejected bool - authorizationRejectedReasons []string - - rateLimitRejected bool - rateLimitRejectedReason string - - // loaderHookContext used to share data between the OnLoad and OnFinished hooks - // Only set when the OnLoad is called - loaderHookContext context.Context -} - -func (r *result) init(postProcessing PostProcessingConfiguration, info *FetchInfo) { - r.postProcessing = postProcessing - if info != nil { - r.ds = DataSourceInfo{ - ID: info.DataSourceID, - Name: info.DataSourceName, - } - } -} - var ( errorsInvalidInputHeader = []byte(`{"errors":[{"message":"Failed to render Fetch Input","path":[`) errorsInvalidInputFooter = []byte(`]}]}`) @@ -1526,6 +1559,7 @@ func (l *Loader) executeSourceLoad(ctx context.Context, fetchItem *FetchItem, so res.err = l.loadByContext(res.loaderHookContext, source, input, res) } else { res.err = l.loadByContext(ctx, source, input, res) + res.loaderHookContext = ctx // Set the context to the original context to ensure that OnFinished hook gets valid context } } else { @@ -1533,6 +1567,7 @@ func (l *Loader) executeSourceLoad(ctx context.Context, fetchItem *FetchItem, so } res.statusCode = responseContext.StatusCode + res.httpResponseContext = responseContext if l.ctx.TracingOptions.Enable { stats := GetSingleFlightStats(ctx) diff --git a/v2/pkg/engine/resolve/loader_hooks_test.go b/v2/pkg/engine/resolve/loader_hooks_test.go index d77c61455..21aea561f 100644 --- a/v2/pkg/engine/resolve/loader_hooks_test.go +++ b/v2/pkg/engine/resolve/loader_hooks_test.go @@ -35,13 +35,13 @@ func (f *TestLoaderHooks) OnLoad(ctx context.Context, ds DataSourceInfo) context return ctx } -func (f *TestLoaderHooks) OnFinished(ctx context.Context, statusCode int, ds DataSourceInfo, err error) { +func (f *TestLoaderHooks) OnFinished(ctx context.Context, ds DataSourceInfo, responseInfo *ResponseInfo) { f.postFetchCalls.Add(1) f.mu.Lock() defer f.mu.Unlock() - f.errors = append(f.errors, err) + f.errors = append(f.errors, responseInfo.Err) } func TestLoaderHooks_FetchPipeline(t *testing.T) {