Skip to content

Commit

Permalink
fix: redact headers in trace input (wundergraph#684)
Browse files Browse the repository at this point in the history
  • Loading branch information
pvormste committed Feb 26, 2024
1 parent 920806a commit 50349c2
Show file tree
Hide file tree
Showing 4 changed files with 173 additions and 1 deletion.
20 changes: 20 additions & 0 deletions v2/pkg/engine/datasource/httpclient/httpclient_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"bytes"
"compress/gzip"
"context"
"github.com/tidwall/sjson"
"io"
"net/http"
"net/http/httptest"
Expand Down Expand Up @@ -219,4 +220,23 @@ func TestHttpClientDo(t *testing.T) {
})
})

t.Run("redact sensitive headers", func(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, err := httputil.DumpRequest(r, true)
assert.NoError(t, err)
w.Header().Set("Authorization", "test")
_, err = w.Write([]byte(`{"extensions": {"trace": {}}"}`))
assert.NoError(t, err)
}))
defer server.Close()
var input []byte
input = SetInputMethod(input, []byte("GET"))
input = SetInputURL(input, []byte(server.URL))
input, err := sjson.SetBytes(input, TRACE, true)
assert.NoError(t, err)
out := &bytes.Buffer{}
err = Do(http.DefaultClient, context.Background(), input, out)
assert.NoError(t, err)
assert.Contains(t, out.String(), `"Authorization":["****"]`)
})
}
3 changes: 2 additions & 1 deletion v2/pkg/engine/datasource/httpclient/nethttpclient.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"golang.org/x/exp/slices"
"io"
"net/http"
"strings"
"time"

"github.com/andybalholm/brotli"
Expand Down Expand Up @@ -183,7 +184,7 @@ var headersToRedact = []string{
func redactHeaders(headers http.Header) http.Header {
redactedHeaders := make(http.Header)
for key, values := range headers {
if slices.Contains(headersToRedact, key) {
if slices.Contains(headersToRedact, strings.ToLower(key)) {
redactedHeaders[key] = []string{"****"}
} else {
redactedHeaders[key] = values
Expand Down
44 changes: 44 additions & 0 deletions v2/pkg/engine/resolve/loader.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@ import (
"bytes"
"context"
"crypto/tls"
"encoding/json"
"fmt"
"golang.org/x/exp/slices"
"io"
"net/http/httptrace"
"runtime"
Expand Down Expand Up @@ -718,6 +720,43 @@ WithNextItem:
return nil
}

func redactHeaders(rawJSON json.RawMessage) (json.RawMessage, error) {
var obj map[string]interface{}

sensitiveHeaders := []string{
"authorization",
"www-authenticate",
"proxy-authenticate",
"proxy-authorization",
"cookie",
"set-cookie",
}

err := json.Unmarshal(rawJSON, &obj)
if err != nil {
return nil, err
}

if headers, ok := obj["header"]; ok {
if headerMap, isMap := headers.(map[string]interface{}); isMap {
for key, values := range headerMap {
if slices.Contains(sensitiveHeaders, strings.ToLower(key)) {
headerMap[key] = []string{"****"}
} else {
headerMap[key] = values
}
}
}
}

redactedJSON, err := json.Marshal(obj)
if err != nil {
return nil, err
}

return json.RawMessage(redactedJSON), nil
}

func (l *Loader) executeSourceLoad(ctx context.Context, disallowSingleFlight bool, source DataSource, input []byte, out io.Writer, trace *DataSourceLoadTrace) (err error) {
if l.ctx.Extensions != nil {
input, err = jsonparser.Set(input, l.ctx.Extensions, "body", "extensions")
Expand All @@ -730,6 +769,11 @@ func (l *Loader) executeSourceLoad(ctx context.Context, disallowSingleFlight boo
if !l.traceOptions.ExcludeInput {
trace.Input = make([]byte, len(input))
copy(trace.Input, input) // copy input explicitly, omit __trace__ field
redactedInput, err := redactHeaders(trace.Input)
if err != nil {
return err
}
trace.Input = redactedInput
}
if gjson.ValidBytes(input) {
inputCopy := make([]byte, len(input))
Expand Down
107 changes: 107 additions & 0 deletions v2/pkg/engine/resolve/loader_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ package resolve
import (
"bytes"
"context"
"encoding/json"
"net/http"
"testing"

"github.com/golang/mock/gomock"
Expand Down Expand Up @@ -881,3 +883,108 @@ func BenchmarkLoader_LoadGraphQLResponseData(b *testing.B) {
}
}
}

func TestLoader_RedactHeaders(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()

productsService := mockedDS(t, ctrl,
`{"method":"POST","url":"http://products","header":{"Authorization":"value"},"body":{"query":"query{topProducts{name __typename upc}}"},"__trace__":true}`,
`{"topProducts":[{"name":"Table","__typename":"Product","upc":"1"},{"name":"Couch","__typename":"Product","upc":"2"},{"name":"Chair","__typename":"Product","upc":"3"}]}`)

response := &GraphQLResponse{
Data: &Object{
Fetch: &SingleFetch{
InputTemplate: InputTemplate{
Segments: []TemplateSegment{
{
Data: []byte(`{"method":"POST","url":"http://products","header":{"Authorization":"`),
SegmentType: StaticSegmentType,
},
{
SegmentType: VariableSegmentType,
VariableKind: HeaderVariableKind,
VariableSourcePath: []string{"Authorization"},
},
{
Data: []byte(`"},"body":{"query":"query{topProducts{name __typename upc}}"},"__trace__":true}`),
SegmentType: StaticSegmentType,
},
},
},
FetchConfiguration: FetchConfiguration{
DataSource: productsService,
PostProcessing: PostProcessingConfiguration{
SelectResponseDataPath: []string{"data"},
},
},
},
Fields: []*Field{
{
Name: []byte("topProducts"),
Value: &Array{
Path: []string{"topProducts"},
Item: &Object{
Fields: []*Field{
{
Name: []byte("name"),
Value: &String{
Path: []string{"name"},
},
},
{
Name: []byte("__typename"),
Value: &String{
Path: []string{"__typename"},
},
},
{
Name: []byte("upc"),
Value: &String{
Path: []string{"upc"},
},
},
},
},
},
},
},
},
}

ctx := &Context{
ctx: context.Background(),
Request: Request{
Header: http.Header{"Authorization": []string{"value"}},
},
}
resolvable := &Resolvable{
storage: &astjson.JSON{},
requestTraceOptions: RequestTraceOptions{Enable: true},
}
loader := &Loader{}

err := resolvable.Init(ctx, nil, ast.OperationTypeQuery)
assert.NoError(t, err)

err = loader.LoadGraphQLResponseData(ctx, response, resolvable)
assert.NoError(t, err)

var input struct {
Header map[string][]string
}

fetch := response.Data.Fetch
switch f := fetch.(type) {
case *SingleFetch:
{
_ = json.Unmarshal(f.Trace.Input, &input)
authHeader := input.Header["Authorization"]
assert.Equal(t, []string{"****"}, authHeader)
}
default:
{
t.Errorf("Incorrect fetch type")
}
}
}

0 comments on commit 50349c2

Please sign in to comment.