From 627a7ced940f200b2daab1fda063bdfcfc8a910f Mon Sep 17 00:00:00 2001 From: Pedro Costa Date: Mon, 24 Jun 2024 11:47:14 -0300 Subject: [PATCH] feat: support file upload in router (#758) This PR makes possible to support a file upload according to spec: https://github.com/jaydenseric/graphql-multipart-request-spec With this PR merged to main and released we can discuss and work on PR to support the feature in cosmo router: https://github.com/wundergraph/cosmo/pull/652 Basically, this PR makes possible to pass a temporary file information stored in file system to structures responsible for graphql resolve operation. Additionally, introduces a method in nethttpclient.go to actually perform the multipart http request. --------- Co-authored-by: pedraumcosta Co-authored-by: thisisnithin --- v2/pkg/astvalidation/reference/main.go | 2 - .../graphql_datasource/graphql_datasource.go | 5 + .../graphql_datasource_test.go | 153 +++++++++++++ v2/pkg/engine/datasource/httpclient/file.go | 26 +++ .../datasource/httpclient/nethttpclient.go | 202 +++++++++++++++--- .../introspection_datasource/source.go | 5 + .../pubsub_datasource/pubsub_kafka.go | 5 + .../pubsub_datasource/pubsub_nats.go | 9 + .../staticdatasource/static_datasource.go | 5 + v2/pkg/engine/plan/schemausageinfo_test.go | 5 + v2/pkg/engine/resolve/context.go | 4 + v2/pkg/engine/resolve/datasource.go | 2 + v2/pkg/engine/resolve/loader.go | 14 +- v2/pkg/engine/resolve/resolve.go | 2 +- v2/pkg/engine/resolve/resolve_mock_test.go | 50 ++--- v2/pkg/engine/resolve/resolve_test.go | 16 +- .../variablesvalidation.go | 5 +- 17 files changed, 435 insertions(+), 75 deletions(-) create mode 100644 v2/pkg/engine/datasource/httpclient/file.go diff --git a/v2/pkg/astvalidation/reference/main.go b/v2/pkg/astvalidation/reference/main.go index 0580ccd66..ce69a406d 100644 --- a/v2/pkg/astvalidation/reference/main.go +++ b/v2/pkg/astvalidation/reference/main.go @@ -11,8 +11,6 @@ import ( "gopkg.in/yaml.v2" ) -//go:generate ./gen.sh - func main() { currDir, _ := os.Getwd() println(currDir) diff --git a/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource.go b/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource.go index bd340f72c..0273e9558 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource.go +++ b/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource.go @@ -1699,6 +1699,11 @@ func (s *Source) replaceEmptyObject(variables []byte) ([]byte, bool) { return variables, false } +func (s *Source) LoadWithFiles(ctx context.Context, input []byte, files []httpclient.File, writer io.Writer) (err error) { + input = s.compactAndUnNullVariables(input) + return httpclient.DoMultipartForm(s.httpClient, ctx, input, files, writer) +} + func (s *Source) Load(ctx context.Context, input []byte, writer io.Writer) (err error) { input = s.compactAndUnNullVariables(input) return httpclient.Do(s.httpClient, ctx, input, writer) diff --git a/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource_test.go b/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource_test.go index 6e8560be4..9028661bb 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource_test.go +++ b/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource_test.go @@ -6,9 +6,13 @@ import ( "encoding/json" "errors" "fmt" + "github.com/google/uuid" "io" "net/http" "net/http/httptest" + "os" + "runtime" + "strconv" "sync" "testing" "time" @@ -9152,6 +9156,155 @@ func TestSource_Load(t *testing.T) { }) } +type ExpectedFile struct { + Name string + Size int64 +} + +type ExpectedRequest struct { + Operations string + Map string + Files []ExpectedFile +} + +func verifyMultipartRequest(t *testing.T, r *http.Request, expected ExpectedRequest) { + err := r.ParseMultipartForm(10 << 20) + require.NoError(t, err) + + for key, values := range r.MultipartForm.Value { + switch key { + case "operations": + assert.Equal(t, expected.Operations, values[0]) + case "map": + assert.Equal(t, expected.Map, values[0]) + } + } + + for i, expectedFile := range expected.Files { + values, exists := r.MultipartForm.File[strconv.Itoa(i)] + if !exists { + t.Fatalf("expected file %s not found in MultipartForm.File", expectedFile.Name) + } + assert.Equal(t, values[0].Filename, expectedFile.Name) + assert.Equal(t, values[0].Size, expectedFile.Size) + } +} + +func TestLoadFiles(t *testing.T) { + if runtime.GOOS == "windows" { + t.SkipNow() + } + + t.Run("single file", func(t *testing.T) { + queryString := `mutation($file: Upload!){singleUpload(file: $file)}` + variableString := `{"file":null}` + fileName := uuid.NewString() + fileContent := "hello" + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + expectedFiles := []ExpectedFile{{ + Name: fileName, + Size: int64(len(fileContent)), + }} + verifyMultipartRequest(t, r.Clone(r.Context()), ExpectedRequest{ + Operations: fmt.Sprintf(`{"query":"%s","variables":%s}`, queryString, variableString), + Map: `{ "0" : ["variables.file"] }`, + Files: expectedFiles, + }) + body, _ := io.ReadAll(r.Body) + _, _ = fmt.Fprint(w, string(body)) + })) + defer ts.Close() + + var ( + src = &Source{httpClient: &http.Client{}} + serverUrl = ts.URL + variables = []byte(variableString) + query = []byte(queryString) + ) + + dir := t.TempDir() + f, err := os.CreateTemp(dir, fileName) + assert.NoError(t, err) + err = os.WriteFile(f.Name(), []byte(fileContent), 0644) + assert.NoError(t, err) + + var input []byte + input = httpclient.SetInputBodyWithPath(input, variables, "variables") + input = httpclient.SetInputBodyWithPath(input, query, "query") + input = httpclient.SetInputURL(input, []byte(serverUrl)) + buf := bytes.NewBuffer(nil) + + ctx := context.Background() + require.NoError(t, src.LoadWithFiles( + ctx, + input, + []httpclient.File{httpclient.NewFile(f.Name(), fileName)}, + buf, + )) + }) + + t.Run("multiple files", func(t *testing.T) { + queryString := `mutation($files: [Upload!]!) { multipleUpload(files: $files)}` + variableString := `{"files":[null,null]}` + + file1Name := uuid.NewString() + file2Name := uuid.NewString() + file1Content := "test" + file2Content := "hello" + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + expectedFiles := []ExpectedFile{{ + Name: file1Name, + Size: int64(len(file1Content)), + }, { + Name: file2Name, + Size: int64(len(file2Content)), + }} + verifyMultipartRequest(t, r.Clone(r.Context()), ExpectedRequest{ + Operations: fmt.Sprintf(`{"query":"%s","variables":%s}`, queryString, variableString), + Map: `{ "0" : ["variables.files.0"], "1" : ["variables.files.1"] }`, + Files: expectedFiles, + }) + body, _ := io.ReadAll(r.Body) + _, _ = fmt.Fprint(w, string(body)) + })) + defer ts.Close() + + var ( + src = &Source{httpClient: &http.Client{}} + serverUrl = ts.URL + variables = []byte(variableString) + query = []byte(queryString) + ) + + var input []byte + input = httpclient.SetInputBodyWithPath(input, variables, "variables") + input = httpclient.SetInputBodyWithPath(input, query, "query") + input = httpclient.SetInputURL(input, []byte(serverUrl)) + buf := bytes.NewBuffer(nil) + + dir := t.TempDir() + f1, err := os.CreateTemp(dir, file1Name) + assert.NoError(t, err) + err = os.WriteFile(f1.Name(), []byte(file1Content), 0644) + assert.NoError(t, err) + + f2, err := os.CreateTemp(dir, file2Name) + assert.NoError(t, err) + err = os.WriteFile(f2.Name(), []byte(file2Content), 0644) + assert.NoError(t, err) + + ctx := context.Background() + require.NoError(t, src.LoadWithFiles( + ctx, + input, + []httpclient.File{httpclient.NewFile(f1.Name(), file1Name), httpclient.NewFile(f2.Name(), file2Name)}, + buf, + )) + }) +} + func TestUnNullVariables(t *testing.T) { t.Run("should not unnull variables if not enabled", func(t *testing.T) { t.Run("two variables, one null", func(t *testing.T) { diff --git a/v2/pkg/engine/datasource/httpclient/file.go b/v2/pkg/engine/datasource/httpclient/file.go new file mode 100644 index 000000000..887c8bf2a --- /dev/null +++ b/v2/pkg/engine/datasource/httpclient/file.go @@ -0,0 +1,26 @@ +package httpclient + +type File interface { + Path() string + Name() string +} + +type internalFile struct { + path string + name string +} + +func NewFile(path string, name string) File { + return &internalFile{ + path: path, + name: name, + } +} + +func (f *internalFile) Path() string { + return f.path +} + +func (f *internalFile) Name() string { + return f.name +} diff --git a/v2/pkg/engine/datasource/httpclient/nethttpclient.go b/v2/pkg/engine/datasource/httpclient/nethttpclient.go index 82618e3f5..35eacaa20 100644 --- a/v2/pkg/engine/datasource/httpclient/nethttpclient.go +++ b/v2/pkg/engine/datasource/httpclient/nethttpclient.go @@ -1,13 +1,18 @@ package httpclient import ( + "bufio" "bytes" "compress/flate" "compress/gzip" "context" "encoding/json" + "errors" + "fmt" "io" + "mime/multipart" "net/http" + "os" "slices" "strings" "time" @@ -78,11 +83,40 @@ func setResponseStatusCode(ctx context.Context, statusCode int) { } } -func Do(client *http.Client, ctx context.Context, requestInput []byte, out io.Writer) (err error) { +var headersToRedact = []string{ + "authorization", + "www-authenticate", + "proxy-authenticate", + "proxy-authorization", + "cookie", + "set-cookie", +} - url, method, body, headers, queryParams, enableTrace := requestInputParams(requestInput) +func redactHeaders(headers http.Header) http.Header { + redactedHeaders := make(http.Header) + for key, values := range headers { + if slices.Contains(headersToRedact, strings.ToLower(key)) { + redactedHeaders[key] = []string{"****"} + } else { + redactedHeaders[key] = values + } + } + return redactedHeaders +} - request, err := http.NewRequestWithContext(ctx, string(method), string(url), bytes.NewReader(body)) +func respBodyReader(res *http.Response) (io.Reader, error) { + switch res.Header.Get(ContentEncodingHeader) { + case EncodingGzip: + return gzip.NewReader(res.Body) + case EncodingDeflate: + return flate.NewReader(res.Body), nil + default: + return res.Body, nil + } +} + +func makeHTTPRequest(client *http.Client, ctx context.Context, url, method, headers, queryParams []byte, body io.Reader, enableTrace bool, out io.Writer, contentType string) (err error) { + request, err := http.NewRequestWithContext(ctx, string(method), string(url), body) if err != nil { return err } @@ -136,7 +170,7 @@ func Do(client *http.Client, ctx context.Context, requestInput []byte, out io.Wr } request.Header.Add(AcceptHeader, ContentTypeJSON) - request.Header.Add(ContentTypeHeader, ContentTypeJSON) + request.Header.Add(ContentTypeHeader, contentType) request.Header.Set(AcceptEncodingHeader, EncodingGzip) request.Header.Add(AcceptEncodingHeader, EncodingDeflate) @@ -188,34 +222,150 @@ func Do(client *http.Client, ctx context.Context, requestInput []byte, out io.Wr return err } -var headersToRedact = []string{ - "authorization", - "www-authenticate", - "proxy-authenticate", - "proxy-authorization", - "cookie", - "set-cookie", +func Do(client *http.Client, ctx context.Context, requestInput []byte, out io.Writer) (err error) { + url, method, body, headers, queryParams, enableTrace := requestInputParams(requestInput) + + return makeHTTPRequest(client, ctx, url, method, headers, queryParams, bytes.NewReader(body), enableTrace, out, ContentTypeJSON) } -func redactHeaders(headers http.Header) http.Header { - redactedHeaders := make(http.Header) - for key, values := range headers { - if slices.Contains(headersToRedact, strings.ToLower(key)) { - redactedHeaders[key] = []string{"****"} +func DoMultipartForm( + client *http.Client, ctx context.Context, requestInput []byte, files []File, out io.Writer, +) (err error) { + if len(files) == 0 { + return errors.New("no files provided") + } + + url, method, body, headers, queryParams, enableTrace := requestInputParams(requestInput) + + formValues := map[string]io.Reader{ + "operations": bytes.NewReader(body), + } + + var fileMap string + var tempFiles []*os.File + for i, file := range files { + if len(fileMap) == 0 { + if len(files) == 1 { + fileMap = fmt.Sprintf(`"%d" : ["variables.file"]`, i) + } else { + fileMap = fmt.Sprintf(`"%d" : ["variables.files.%d"]`, i, i) + } } else { - redactedHeaders[key] = values + fileMap = fmt.Sprintf(`%s, "%d" : ["variables.files.%d"]`, fileMap, i, i) } + key := fmt.Sprintf("%d", i) + temporaryFile, err := os.Open(file.Path()) + tempFiles = append(tempFiles, temporaryFile) + if err != nil { + return err + } + formValues[key] = bufio.NewReader(temporaryFile) } - return redactedHeaders + formValues["map"] = strings.NewReader("{ " + fileMap + " }") + + multipartBody, contentType, err := multipartBytes(formValues, files) + if err != nil { + return err + } + + defer func() { + multipartBody.Close() + for _, file := range tempFiles { + if err := file.Close(); err != nil { + return + } + if err = os.Remove(file.Name()); err != nil { + return + } + } + }() + + return makeHTTPRequest(client, ctx, url, method, headers, queryParams, multipartBody, enableTrace, out, contentType) } -func respBodyReader(res *http.Response) (io.Reader, error) { - switch res.Header.Get(ContentEncodingHeader) { - case EncodingGzip: - return gzip.NewReader(res.Body) - case EncodingDeflate: - return flate.NewReader(res.Body), nil - default: - return res.Body, nil +func multipartBytes(values map[string]io.Reader, files []File) (*io.PipeReader, string, error) { + byteBuf := &bytes.Buffer{} + mpWriter := multipart.NewWriter(byteBuf) + contentType := mpWriter.FormDataContentType() + + // First create the fields to control the file upload + valuesInOrder := []string{"operations", "map"} + for _, key := range valuesInOrder { + r := values[key] + fw, err := mpWriter.CreateFormField(key) + if err != nil { + return nil, contentType, err + } + if _, err = io.Copy(fw, r); err != nil { + return nil, contentType, err + } } + + // Insert parts for files + boundaries := make([][]byte, 0, len(files)) + for i, file := range files { + key := fmt.Sprintf("%d", i) + _, err := mpWriter.CreateFormFile(key, file.Name()) + if err != nil { + return nil, contentType, err + } + + // We read the files using pipe later + // So we need to keep store boundaries to insert contents in the correct place + lengthOfBufferTillBoundary := byteBuf.Len() + boundary := make([]byte, lengthOfBufferTillBoundary) + if _, err = byteBuf.Read(boundary); err != nil { + return nil, contentType, err + } + boundaries = append(boundaries, boundary) + } + + err := mpWriter.Close() + if err != nil { + return nil, contentType, err + } + + rd, wr := io.Pipe() + + go func() { + defer func() { + err := wr.Close() + if err != nil { + fmt.Println("Error closing pipe: ", err) + } + }() + + // 4MB chunks + buf := make([]byte, 2048*2048) + for i, file := range files { + if _, err = wr.Write(boundaries[i]); err != nil { + return + } + + f, err := os.Open(file.Path()) + if err != nil { + return + } + + for { + n, err := f.Read(buf) + if err != nil && err == io.EOF { + break + } else if err != nil { + return + } + + if _, err = wr.Write(buf[:n]); err != nil { + return + } + } + if err := f.Close(); err != nil { + return + } + } + // Write last boundary + _, _ = wr.Write(byteBuf.Bytes()) + }() + + return rd, contentType, nil } diff --git a/v2/pkg/engine/datasource/introspection_datasource/source.go b/v2/pkg/engine/datasource/introspection_datasource/source.go index 1ae5624a4..3f6b20608 100644 --- a/v2/pkg/engine/datasource/introspection_datasource/source.go +++ b/v2/pkg/engine/datasource/introspection_datasource/source.go @@ -3,6 +3,7 @@ package introspection_datasource import ( "context" "encoding/json" + "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/httpclient" "io" "github.com/wundergraph/graphql-go-tools/v2/pkg/introspection" @@ -34,6 +35,10 @@ func (s *Source) Load(ctx context.Context, input []byte, w io.Writer) (err error return json.NewEncoder(w).Encode(s.schemaWithoutTypeInfo()) } +func (s *Source) LoadWithFiles(ctx context.Context, input []byte, files []httpclient.File, w io.Writer) (err error) { + panic("not implemented") +} + func (s *Source) schemaWithoutTypeInfo() introspection.Schema { types := make([]introspection.FullType, 0, len(s.introspectionData.Schema.Types)) diff --git a/v2/pkg/engine/datasource/pubsub_datasource/pubsub_kafka.go b/v2/pkg/engine/datasource/pubsub_datasource/pubsub_kafka.go index 5d040d2c0..9f4c4fe82 100644 --- a/v2/pkg/engine/datasource/pubsub_datasource/pubsub_kafka.go +++ b/v2/pkg/engine/datasource/pubsub_datasource/pubsub_kafka.go @@ -5,6 +5,7 @@ import ( "encoding/json" "github.com/buger/jsonparser" "github.com/cespare/xxhash/v2" + "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/httpclient" "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" "io" ) @@ -78,3 +79,7 @@ func (s *KafkaPublishDataSource) Load(ctx context.Context, input []byte, w io.Wr _, err = io.WriteString(w, `{"success": true}`) return err } + +func (s *KafkaPublishDataSource) LoadWithFiles(ctx context.Context, input []byte, files []httpclient.File, w io.Writer) (err error) { + panic("not implemented") +} diff --git a/v2/pkg/engine/datasource/pubsub_datasource/pubsub_nats.go b/v2/pkg/engine/datasource/pubsub_datasource/pubsub_nats.go index ce7e1c690..1088cd810 100644 --- a/v2/pkg/engine/datasource/pubsub_datasource/pubsub_nats.go +++ b/v2/pkg/engine/datasource/pubsub_datasource/pubsub_nats.go @@ -5,6 +5,7 @@ import ( "encoding/json" "github.com/buger/jsonparser" "github.com/cespare/xxhash/v2" + "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/httpclient" "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" "io" ) @@ -88,6 +89,10 @@ func (s *NatsPublishDataSource) Load(ctx context.Context, input []byte, w io.Wri return err } +func (s *NatsPublishDataSource) LoadWithFiles(ctx context.Context, input []byte, files []httpclient.File, w io.Writer) error { + panic("not implemented") +} + type NatsRequestDataSource struct { pubSub NatsPubSub } @@ -101,3 +106,7 @@ func (s *NatsRequestDataSource) Load(ctx context.Context, input []byte, w io.Wri return s.pubSub.Request(ctx, subscriptionConfiguration, w) } + +func (s *NatsRequestDataSource) LoadWithFiles(ctx context.Context, input []byte, files []httpclient.File, w io.Writer) error { + panic("not implemented") +} diff --git a/v2/pkg/engine/datasource/staticdatasource/static_datasource.go b/v2/pkg/engine/datasource/staticdatasource/static_datasource.go index ef64bacc3..0e4a53112 100644 --- a/v2/pkg/engine/datasource/staticdatasource/static_datasource.go +++ b/v2/pkg/engine/datasource/staticdatasource/static_datasource.go @@ -7,6 +7,7 @@ import ( "github.com/jensneuse/abstractlogger" "github.com/wundergraph/graphql-go-tools/v2/pkg/ast" + "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/httpclient" "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/plan" "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" ) @@ -69,3 +70,7 @@ func (Source) Load(ctx context.Context, input []byte, w io.Writer) (err error) { _, err = w.Write(input) return } + +func (Source) LoadWithFiles(ctx context.Context, input []byte, files []httpclient.File, w io.Writer) (err error) { + panic("not implemented") +} diff --git a/v2/pkg/engine/plan/schemausageinfo_test.go b/v2/pkg/engine/plan/schemausageinfo_test.go index 6258449b9..4e8f260bf 100644 --- a/v2/pkg/engine/plan/schemausageinfo_test.go +++ b/v2/pkg/engine/plan/schemausageinfo_test.go @@ -15,6 +15,7 @@ import ( "github.com/wundergraph/graphql-go-tools/v2/pkg/astnormalization" "github.com/wundergraph/graphql-go-tools/v2/pkg/asttransform" "github.com/wundergraph/graphql-go-tools/v2/pkg/astvalidation" + "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/httpclient" "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" "github.com/wundergraph/graphql-go-tools/v2/pkg/internal/unsafeparser" "github.com/wundergraph/graphql-go-tools/v2/pkg/operationreport" @@ -489,3 +490,7 @@ type FakeDataSource struct { func (f *FakeDataSource) Load(ctx context.Context, input []byte, w io.Writer) (err error) { return } + +func (f *FakeDataSource) LoadWithFiles(ctx context.Context, input []byte, files []httpclient.File, w io.Writer) (err error) { + return +} diff --git a/v2/pkg/engine/resolve/context.go b/v2/pkg/engine/resolve/context.go index 736e4f4fd..86b44a337 100644 --- a/v2/pkg/engine/resolve/context.go +++ b/v2/pkg/engine/resolve/context.go @@ -8,12 +8,14 @@ import ( "net/http" "time" + "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/httpclient" "go.uber.org/atomic" ) type Context struct { ctx context.Context Variables []byte + Files []httpclient.File Request Request RenameTypeNames []RenameTypeName TracingOptions TraceOptions @@ -141,6 +143,7 @@ func (c *Context) clone(ctx context.Context) *Context { cpy := *c cpy.ctx = ctx cpy.Variables = append([]byte(nil), c.Variables...) + cpy.Files = append([]httpclient.File(nil), c.Files...) cpy.Request.Header = c.Request.Header.Clone() cpy.RenameTypeNames = append([]RenameTypeName(nil), c.RenameTypeNames...) return &cpy @@ -149,6 +152,7 @@ func (c *Context) clone(ctx context.Context) *Context { func (c *Context) Free() { c.ctx = nil c.Variables = nil + c.Files = nil c.Request.Header = nil c.RenameTypeNames = nil c.TracingOptions.DisableAll() diff --git a/v2/pkg/engine/resolve/datasource.go b/v2/pkg/engine/resolve/datasource.go index 1ee3e649a..4d1c3f0d3 100644 --- a/v2/pkg/engine/resolve/datasource.go +++ b/v2/pkg/engine/resolve/datasource.go @@ -5,10 +5,12 @@ import ( "io" "github.com/cespare/xxhash/v2" + "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/httpclient" ) type DataSource interface { Load(ctx context.Context, input []byte, w io.Writer) (err error) + LoadWithFiles(ctx context.Context, input []byte, files []httpclient.File, w io.Writer) (err error) } type SubscriptionDataSource interface { diff --git a/v2/pkg/engine/resolve/loader.go b/v2/pkg/engine/resolve/loader.go index 0bcfa5088..58a8f2cfd 100644 --- a/v2/pkg/engine/resolve/loader.go +++ b/v2/pkg/engine/resolve/loader.go @@ -1271,6 +1271,14 @@ func (l *Loader) setTracingInput(input []byte, trace *DataSourceLoadTrace) { } } +func (l *Loader) loadByContext(ctx context.Context, source DataSource, input []byte, res *result) error { + if l.ctx.Files != nil { + return source.LoadWithFiles(ctx, input, l.ctx.Files, res.out) + } + + return source.Load(ctx, input, res.out) +} + func (l *Loader) executeSourceLoad(ctx context.Context, source DataSource, input []byte, res *result, trace *DataSourceLoadTrace) { if l.ctx.Extensions != nil { input, res.err = jsonparser.Set(input, l.ctx.Extensions, "body", "extensions") @@ -1395,13 +1403,13 @@ func (l *Loader) executeSourceLoad(ctx context.Context, source DataSource, input // Prevent that the context is destroyed when the loader hook return an empty context if res.loaderHookContext != nil { - res.err = source.Load(res.loaderHookContext, input, res.out) + res.err = l.loadByContext(res.loaderHookContext, source, input, res) } else { - res.err = source.Load(ctx, input, res.out) + res.err = l.loadByContext(ctx, source, input, res) } } else { - res.err = source.Load(ctx, input, res.out) + res.err = l.loadByContext(ctx, source, input, res) } res.statusCode = responseContext.StatusCode diff --git a/v2/pkg/engine/resolve/resolve.go b/v2/pkg/engine/resolve/resolve.go index 846f87022..876294ec2 100644 --- a/v2/pkg/engine/resolve/resolve.go +++ b/v2/pkg/engine/resolve/resolve.go @@ -1,4 +1,4 @@ -//go:generate mockgen --build_flags=--mod=mod -self_package=github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve -destination=resolve_mock_test.go -package=resolve . DataSource,BeforeFetchHook,AfterFetchHook +//go:generate mockgen -self_package=github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve -destination=resolve_mock_test.go -package=resolve . DataSource package resolve diff --git a/v2/pkg/engine/resolve/resolve_mock_test.go b/v2/pkg/engine/resolve/resolve_mock_test.go index 7b512808d..1c011bfda 100644 --- a/v2/pkg/engine/resolve/resolve_mock_test.go +++ b/v2/pkg/engine/resolve/resolve_mock_test.go @@ -1,5 +1,5 @@ // Code generated by MockGen. DO NOT EDIT. -// Source: github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve (interfaces: DataSource,BeforeFetchHook,AfterFetchHook) +// Source: github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve (interfaces: DataSource) // Package resolve is a generated GoMock package. package resolve @@ -10,6 +10,7 @@ import ( reflect "reflect" gomock "github.com/golang/mock/gomock" + httpclient "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/httpclient" ) // MockDataSource is a mock of DataSource interface. @@ -49,43 +50,16 @@ func (mr *MockDataSourceMockRecorder) Load(arg0, arg1, arg2 interface{}) *gomock return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Load", reflect.TypeOf((*MockDataSource)(nil).Load), arg0, arg1, arg2) } -// MockBeforeFetchHook is a mock of BeforeFetchHook interface. -type MockBeforeFetchHook struct { - ctrl *gomock.Controller - recorder *MockBeforeFetchHookMockRecorder -} - -// MockBeforeFetchHookMockRecorder is the mock recorder for MockBeforeFetchHook. -type MockBeforeFetchHookMockRecorder struct { - mock *MockBeforeFetchHook -} - -// NewMockBeforeFetchHook creates a new mock instance. -func NewMockBeforeFetchHook(ctrl *gomock.Controller) *MockBeforeFetchHook { - mock := &MockBeforeFetchHook{ctrl: ctrl} - mock.recorder = &MockBeforeFetchHookMockRecorder{mock} - return mock -} - -// MockAfterFetchHook is a mock of AfterFetchHook interface. -type MockAfterFetchHook struct { - ctrl *gomock.Controller - recorder *MockAfterFetchHookMockRecorder -} - -// MockAfterFetchHookMockRecorder is the mock recorder for MockAfterFetchHook. -type MockAfterFetchHookMockRecorder struct { - mock *MockAfterFetchHook -} - -// NewMockAfterFetchHook creates a new mock instance. -func NewMockAfterFetchHook(ctrl *gomock.Controller) *MockAfterFetchHook { - mock := &MockAfterFetchHook{ctrl: ctrl} - mock.recorder = &MockAfterFetchHookMockRecorder{mock} - return mock +// LoadWithFiles mocks base method. +func (m *MockDataSource) LoadWithFiles(arg0 context.Context, arg1 []byte, arg2 []httpclient.File, arg3 io.Writer) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "LoadWithFiles", arg0, arg1, arg2, arg3) + ret0, _ := ret[0].(error) + return ret0 } -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockAfterFetchHook) EXPECT() *MockAfterFetchHookMockRecorder { - return m.recorder +// LoadWithFiles indicates an expected call of LoadWithFiles. +func (mr *MockDataSourceMockRecorder) LoadWithFiles(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadWithFiles", reflect.TypeOf((*MockDataSource)(nil).LoadWithFiles), arg0, arg1, arg2, arg3) } diff --git a/v2/pkg/engine/resolve/resolve_test.go b/v2/pkg/engine/resolve/resolve_test.go index 8b13c9d3c..745169959 100644 --- a/v2/pkg/engine/resolve/resolve_test.go +++ b/v2/pkg/engine/resolve/resolve_test.go @@ -1,7 +1,5 @@ package resolve -// go:generate mockgen -package resolve -destination resolve_mock_test.go . DataSource,BeforeFetchHook,AfterFetchHook,DataSourceBatch,DataSourceBatchFactory - import ( "bytes" "context" @@ -20,6 +18,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/httpclient" "github.com/wundergraph/graphql-go-tools/v2/pkg/testing/flags" ) @@ -43,6 +42,19 @@ func (f *_fakeDataSource) Load(ctx context.Context, input []byte, w io.Writer) ( return } +func (f *_fakeDataSource) LoadWithFiles(ctx context.Context, input []byte, files []httpclient.File, w io.Writer) (err error) { + if f.artificialLatency != 0 { + time.Sleep(f.artificialLatency) + } + if f.input != nil { + if !bytes.Equal(f.input, input) { + require.Equal(f.t, string(f.input), string(input), "input mismatch") + } + } + _, err = w.Write(f.data) + return +} + func FakeDataSource(data string) *_fakeDataSource { return &_fakeDataSource{ data: []byte(data), diff --git a/v2/pkg/variablesvalidation/variablesvalidation.go b/v2/pkg/variablesvalidation/variablesvalidation.go index e46145400..d7258de6b 100644 --- a/v2/pkg/variablesvalidation/variablesvalidation.go +++ b/v2/pkg/variablesvalidation/variablesvalidation.go @@ -125,12 +125,13 @@ func (v *variablesVisitor) EnterVariableDefinition(ref int) { } func (v *variablesVisitor) traverseOperationType(jsonFieldRef int, operationTypeRef int) { + varTypeName := v.operation.ResolveTypeNameBytes(operationTypeRef) if v.operation.TypeIsNonNull(operationTypeRef) { if jsonFieldRef == -1 { v.renderVariableRequiredError(v.currentVariableName, operationTypeRef) return } - if v.variables.Nodes[jsonFieldRef].Kind == astjson.NodeKindNull { + if v.variables.Nodes[jsonFieldRef].Kind == astjson.NodeKindNull && varTypeName.String() != "Upload" { v.renderVariableInvalidNullError(v.currentVariableName, operationTypeRef) return } @@ -143,8 +144,6 @@ func (v *variablesVisitor) traverseOperationType(jsonFieldRef int, operationType return } - varTypeName := v.operation.ResolveTypeNameBytes(operationTypeRef) - if v.operation.TypeIsList(operationTypeRef) { if v.variables.Nodes[jsonFieldRef].Kind != astjson.NodeKindArray { v.renderVariableInvalidObjectTypeError(varTypeName, v.variables.Nodes[jsonFieldRef])