Skip to content

Commit

Permalink
feat: add http datasource onfinished hook (#1001)
Browse files Browse the repository at this point in the history
As part of our efforts in
wundergraph/cosmo#1401, we need to have access
to the request and response from the calls to subgraphs. As a result,
we're exposing a new hook which is specialized for a HTTP datasource
```
OnHttpFinished(ctx context.Context, ds DataSourceInfo, err error, request *http.Request, response *http.Response)
```

which users can use to gain access to the sanitized request/response and
properly handle that information
  • Loading branch information
df-wg authored Dec 2, 2024
1 parent 4018291 commit 5d14a22
Show file tree
Hide file tree
Showing 3 changed files with 93 additions and 54 deletions.
10 changes: 7 additions & 3 deletions v2/pkg/engine/datasource/httpclient/nethttpclient.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,16 +70,20 @@ type responseContextKey struct{}

type ResponseContext struct {
StatusCode int
Request *http.Request
Response *http.Response
}

func InjectResponseContext(ctx context.Context) (context.Context, *ResponseContext) {
value := &ResponseContext{}
return context.WithValue(ctx, responseContextKey{}, value), value
}

func setResponseStatusCode(ctx context.Context, statusCode int) {
func setResponseStatus(ctx context.Context, request *http.Request, response *http.Response) {
if value, ok := ctx.Value(responseContextKey{}).(*ResponseContext); ok {
value.StatusCode = statusCode
value.StatusCode = response.StatusCode
value.Request = request
value.Response = response
}
}

Expand Down Expand Up @@ -191,7 +195,7 @@ func makeHTTPRequest(client *http.Client, ctx context.Context, url, method, head
}
defer response.Body.Close()

setResponseStatusCode(ctx, response.StatusCode)
setResponseStatus(ctx, request, response)

respReader, err := respBodyReader(response)
if err != nil {
Expand Down
133 changes: 84 additions & 49 deletions v2/pkg/engine/resolve/loader.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"encoding/json"
goerrors "errors"
"fmt"
"net/http"
"net/http/httptrace"
"slices"
"strconv"
Expand Down Expand Up @@ -37,7 +38,73 @@ type LoaderHooks interface {
// OnLoad is called before the fetch is executed
OnLoad(ctx context.Context, ds DataSourceInfo) context.Context
// OnFinished is called after the fetch has been executed and the response has been processed and merged
OnFinished(ctx context.Context, statusCode int, ds DataSourceInfo, err error)
OnFinished(ctx context.Context, ds DataSourceInfo, info *ResponseInfo)
}

type DataSourceInfo struct {
ID string
Name string
}

type ResponseInfo struct {
StatusCode int
Err error
// Request is the original request that was sent to the subgraph. This should only be used for reading purposes,
// in order to ensure there aren't memory conflicts, and the body will be nil, as it was read already.
Request *http.Request
// ResponseHeaders contains a clone of the headers of the response from the subgraph.
ResponseHeaders http.Header
}

func newResponseInfo(res *result, subgraphError error) *ResponseInfo {
responseInfo := &ResponseInfo{StatusCode: res.statusCode, Err: goerrors.Join(res.err, subgraphError)}
if res.httpResponseContext != nil {
// We're using the response.Request here, because the body will be nil (since the response was read) and won't
// cause a memory leak.
if res.httpResponseContext.Response != nil {
responseInfo.Request = res.httpResponseContext.Response.Request
responseInfo.ResponseHeaders = res.httpResponseContext.Response.Header.Clone()
} else {
// In cases where the request errors, the response will be nil, and so we need to get the original request
responseInfo.Request = res.httpResponseContext.Request
}
}

return responseInfo
}

type result struct {
postProcessing PostProcessingConfiguration
out *bytes.Buffer
batchStats [][]int
fetchSkipped bool
nestedMergeItems []*result

statusCode int
err error
ds DataSourceInfo

authorizationRejected bool
authorizationRejectedReasons []string

rateLimitRejected bool
rateLimitRejectedReason string

// loaderHookContext used to share data between the OnLoad and OnFinished hooks
// It should be valid even when OnLoad isn't called
loaderHookContext context.Context

httpResponseContext *httpclient.ResponseContext
}

func (r *result) init(postProcessing PostProcessingConfiguration, info *FetchInfo) {
r.postProcessing = postProcessing
if info != nil {
r.ds = DataSourceInfo{
ID: info.DataSourceID,
Name: info.DataSourceName,
}
}
}

func IsIntrospectionDataSource(dataSourceID string) bool {
Expand Down Expand Up @@ -118,16 +185,18 @@ func (l *Loader) resolveParallel(nodes []*FetchTreeNode) error {
for j := range results[i].nestedMergeItems {
err = l.mergeResult(nodes[i].Item, results[i].nestedMergeItems[j], itemsItems[i][j:j+1])
if l.ctx.LoaderHooks != nil && results[i].nestedMergeItems[j].loaderHookContext != nil {
l.ctx.LoaderHooks.OnFinished(results[i].nestedMergeItems[j].loaderHookContext, results[i].nestedMergeItems[j].statusCode, results[i].nestedMergeItems[j].ds, goerrors.Join(results[i].nestedMergeItems[j].err, l.ctx.subgraphErrors))
l.ctx.LoaderHooks.OnFinished(results[i].nestedMergeItems[j].loaderHookContext,
results[i].nestedMergeItems[j].ds,
newResponseInfo(results[i].nestedMergeItems[j], l.ctx.subgraphErrors))
}
if err != nil {
return errors.WithStack(err)
}
}
} else {
err = l.mergeResult(nodes[i].Item, results[i], itemsItems[i])
if l.ctx.LoaderHooks != nil && results[i].loaderHookContext != nil {
l.ctx.LoaderHooks.OnFinished(results[i].loaderHookContext, results[i].statusCode, results[i].ds, goerrors.Join(results[i].err, l.ctx.subgraphErrors))
if l.ctx.LoaderHooks != nil {
l.ctx.LoaderHooks.OnFinished(results[i].loaderHookContext, results[i].ds, newResponseInfo(results[i], l.ctx.subgraphErrors))
}
if err != nil {
return errors.WithStack(err)
Expand Down Expand Up @@ -162,9 +231,10 @@ func (l *Loader) resolveSingle(item *FetchItem) error {
return err
}
err = l.mergeResult(item, res, items)
if l.ctx.LoaderHooks != nil && res.loaderHookContext != nil {
l.ctx.LoaderHooks.OnFinished(res.loaderHookContext, res.statusCode, res.ds, goerrors.Join(res.err, l.ctx.subgraphErrors))
if l.ctx.LoaderHooks != nil {
l.ctx.LoaderHooks.OnFinished(res.loaderHookContext, res.ds, newResponseInfo(res, l.ctx.subgraphErrors))
}

return err
case *BatchEntityFetch:
res := &result{
Expand All @@ -175,8 +245,8 @@ func (l *Loader) resolveSingle(item *FetchItem) error {
return errors.WithStack(err)
}
err = l.mergeResult(item, res, items)
if l.ctx.LoaderHooks != nil && res.loaderHookContext != nil {
l.ctx.LoaderHooks.OnFinished(res.loaderHookContext, res.statusCode, res.ds, goerrors.Join(res.err, l.ctx.subgraphErrors))
if l.ctx.LoaderHooks != nil {
l.ctx.LoaderHooks.OnFinished(res.loaderHookContext, res.ds, newResponseInfo(res, l.ctx.subgraphErrors))
}
return err
case *EntityFetch:
Expand All @@ -188,8 +258,8 @@ func (l *Loader) resolveSingle(item *FetchItem) error {
return errors.WithStack(err)
}
err = l.mergeResult(item, res, items)
if l.ctx.LoaderHooks != nil && res.loaderHookContext != nil {
l.ctx.LoaderHooks.OnFinished(res.loaderHookContext, res.statusCode, res.ds, goerrors.Join(res.err, l.ctx.subgraphErrors))
if l.ctx.LoaderHooks != nil {
l.ctx.LoaderHooks.OnFinished(res.loaderHookContext, res.ds, newResponseInfo(res, l.ctx.subgraphErrors))
}
return err
case *ParallelListItemFetch:
Expand Down Expand Up @@ -221,8 +291,8 @@ func (l *Loader) resolveSingle(item *FetchItem) error {
}
for i := range results {
err = l.mergeResult(item, results[i], items[i:i+1])
if l.ctx.LoaderHooks != nil && results[i].loaderHookContext != nil {
l.ctx.LoaderHooks.OnFinished(results[i].loaderHookContext, results[i].statusCode, results[i].ds, goerrors.Join(results[i].err, l.ctx.subgraphErrors))
if l.ctx.LoaderHooks != nil {
l.ctx.LoaderHooks.OnFinished(results[i].loaderHookContext, results[i].ds, newResponseInfo(results[i], l.ctx.subgraphErrors))
}
if err != nil {
return errors.WithStack(err)
Expand Down Expand Up @@ -527,43 +597,6 @@ func (l *Loader) mergeResult(fetchItem *FetchItem, res *result, items []*astjson
return nil
}

type DataSourceInfo struct {
ID string
Name string
}

type result struct {
postProcessing PostProcessingConfiguration
out *bytes.Buffer
batchStats [][]int
fetchSkipped bool
nestedMergeItems []*result

statusCode int
err error
ds DataSourceInfo

authorizationRejected bool
authorizationRejectedReasons []string

rateLimitRejected bool
rateLimitRejectedReason string

// loaderHookContext used to share data between the OnLoad and OnFinished hooks
// Only set when the OnLoad is called
loaderHookContext context.Context
}

func (r *result) init(postProcessing PostProcessingConfiguration, info *FetchInfo) {
r.postProcessing = postProcessing
if info != nil {
r.ds = DataSourceInfo{
ID: info.DataSourceID,
Name: info.DataSourceName,
}
}
}

var (
errorsInvalidInputHeader = []byte(`{"errors":[{"message":"Failed to render Fetch Input","path":[`)
errorsInvalidInputFooter = []byte(`]}]}`)
Expand Down Expand Up @@ -1526,13 +1559,15 @@ func (l *Loader) executeSourceLoad(ctx context.Context, fetchItem *FetchItem, so
res.err = l.loadByContext(res.loaderHookContext, source, input, res)
} else {
res.err = l.loadByContext(ctx, source, input, res)
res.loaderHookContext = ctx // Set the context to the original context to ensure that OnFinished hook gets valid context
}

} else {
res.err = l.loadByContext(ctx, source, input, res)
}

res.statusCode = responseContext.StatusCode
res.httpResponseContext = responseContext

if l.ctx.TracingOptions.Enable {
stats := GetSingleFlightStats(ctx)
Expand Down
4 changes: 2 additions & 2 deletions v2/pkg/engine/resolve/loader_hooks_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,13 @@ func (f *TestLoaderHooks) OnLoad(ctx context.Context, ds DataSourceInfo) context
return ctx
}

func (f *TestLoaderHooks) OnFinished(ctx context.Context, statusCode int, ds DataSourceInfo, err error) {
func (f *TestLoaderHooks) OnFinished(ctx context.Context, ds DataSourceInfo, responseInfo *ResponseInfo) {
f.postFetchCalls.Add(1)

f.mu.Lock()
defer f.mu.Unlock()

f.errors = append(f.errors, err)
f.errors = append(f.errors, responseInfo.Err)
}

func TestLoaderHooks_FetchPipeline(t *testing.T) {
Expand Down

0 comments on commit 5d14a22

Please sign in to comment.