diff --git a/internal/pkg/http/client.go b/internal/pkg/http/client.go index d79b7a9..0c67e98 100644 --- a/internal/pkg/http/client.go +++ b/internal/pkg/http/client.go @@ -15,6 +15,7 @@ package http import ( + "bytes" "crypto/tls" "fmt" "io" @@ -73,17 +74,22 @@ func NewClient(host string, insecure bool, timeoutMilliseconds int, protocol Pro } // SendRequest sends a request to the HTTP server and wraps useful information into a Response object. -func (c Client) SendRequest(method, path string, headers map[string]string, body io.Reader) response.Response { +func (c Client) SendRequest(method, path string, headers map[string]string, requestBody *string) response.Response { const respType = "http" - + var body io.Reader + if requestBody != nil { + body = bytes.NewBufferString(*requestBody) + } + url := fmt.Sprintf("%s/%s", c.host, strings.TrimLeft(path, "/")) req, err := http.NewRequest(method, url, body) - if err != nil { log.Printf("Failed to create request: %s %s: %v", method, url, err) return response.Response{Duration: time.Duration(0), Err: err, Type: respType} } - + if req.Body != nil { + defer req.Body.Close() + } for k, v := range headers { if strings.EqualFold(k, "Host") { req.Host = v @@ -94,14 +100,8 @@ func (c Client) SendRequest(method, path string, headers map[string]string, body req.Header.Add(k, interpolatedHeaderValue) } - if err != nil { - defer req.Body.Close() - } startTime := time.Now() resp, err := c.httpClient.Do(req) - if resp != nil { - defer resp.Body.Close() - } endTime := time.Now() if err != nil { return response.Response{Duration: endTime.Sub(startTime), Err: err, Type: respType} diff --git a/internal/pkg/http/client_test.go b/internal/pkg/http/client_test.go index 64315f4..23c8928 100644 --- a/internal/pkg/http/client_test.go +++ b/internal/pkg/http/client_test.go @@ -19,7 +19,6 @@ import ( "fmt" "mittens/fixture" "net/http" - "strings" "testing" "github.com/stretchr/testify/assert" @@ -40,24 +39,21 @@ func TestMain(m *testing.M) { func TestRequestSuccessHTTP1(t *testing.T) { c := NewClient(serverUrl, false, 10000, HTTP1) reqBody := "" - reader := strings.NewReader(reqBody) - resp := c.SendRequest("GET", WorkingPath, make(map[string]string), reader) + resp := c.SendRequest("GET", WorkingPath, make(map[string]string), &reqBody) assert.Nil(t, resp.Err) } func TestRequestSuccessH2C(t *testing.T) { c := NewClient(serverUrl, false, 10000, H2C) reqBody := "" - reader := strings.NewReader(reqBody) - resp := c.SendRequest("GET", WorkingPath, make(map[string]string), reader) + resp := c.SendRequest("GET", WorkingPath, make(map[string]string), &reqBody) assert.Nil(t, resp.Err) } func TestHttpErrorHTTP1(t *testing.T) { c := NewClient(serverUrl, false, 10000, HTTP1) reqBody := "" - reader := strings.NewReader(reqBody) - resp := c.SendRequest("GET", "/", make(map[string]string), reader) + resp := c.SendRequest("GET", "/", make(map[string]string), &reqBody) assert.Nil(t, resp.Err) assert.Equal(t, resp.StatusCode, 404) } @@ -65,8 +61,7 @@ func TestHttpErrorHTTP1(t *testing.T) { func TestHttpErrorH2C(t *testing.T) { c := NewClient(serverUrl, false, 10000, H2C) reqBody := "" - reader := strings.NewReader(reqBody) - resp := c.SendRequest("GET", "/", make(map[string]string), reader) + resp := c.SendRequest("GET", "/", make(map[string]string), &reqBody) assert.Nil(t, resp.Err) assert.Equal(t, resp.StatusCode, 404) } @@ -74,16 +69,14 @@ func TestHttpErrorH2C(t *testing.T) { func TestConnectionErrorHTTP1(t *testing.T) { c := NewClient("http://localhost:9999", false, 10000, HTTP1) reqBody := "" - reader := strings.NewReader(reqBody) - resp := c.SendRequest("GET", "/potato", make(map[string]string), reader) + resp := c.SendRequest("GET", "/potato", make(map[string]string), &reqBody) assert.NotNil(t, resp.Err) } func TestConnectionErrorH2C(t *testing.T) { c := NewClient("http://localhost:9999", false, 10000, H2C) reqBody := "" - reader := strings.NewReader(reqBody) - resp := c.SendRequest("GET", "/potato", make(map[string]string), reader) + resp := c.SendRequest("GET", "/potato", make(map[string]string), &reqBody) assert.NotNil(t, resp.Err) } diff --git a/internal/pkg/http/utils.go b/internal/pkg/http/utils.go index 7386c1e..216b1e7 100644 --- a/internal/pkg/http/utils.go +++ b/internal/pkg/http/utils.go @@ -31,7 +31,7 @@ type Request struct { Method string Headers map[string]string Path string - Body io.Reader + Body *string } type CompressionType string @@ -90,15 +90,23 @@ func ToHTTPRequest(requestString string, compression CompressionType) (Request, var reader io.Reader switch compression { case COMPRESSION_GZIP: - reader = compressGzip([]byte(body)) + reader, err = compressGzip([]byte(body)) case COMPRESSION_BROTLI: - reader = compressBrotli([]byte(body)) + reader, err = compressBrotli([]byte(body)) case COMPRESSION_DEFLATE: - reader = compressFlate([]byte(body)) + reader, err = compressFlate([]byte(body)) default: reader = bytes.NewBufferString(body) } + if err != nil { + return Request{}, fmt.Errorf("unable to compress body for request: %s", parts[2]) + } + + buf := new(bytes.Buffer) + buf.ReadFrom(reader) + compressedBody := buf.String() + headers := make(map[string]string) if compression != COMPRESSION_NONE { encoding := "" @@ -117,33 +125,42 @@ func ToHTTPRequest(requestString string, compression CompressionType) (Request, Method: method, Headers: headers, Path: path, - Body: reader, + Body: &compressedBody, }, nil } -func compressGzip(data []byte) io.Reader { - pr, pw := io.Pipe() - go func() { - gz := gzip.NewWriter(pw) - _, err := gz.Write(data) - gz.Close() - pw.CloseWithError(err) - }() - return pr +func compressGzip(data []byte) (io.Reader, error) { + var b bytes.Buffer + w := gzip.NewWriter(&b) + _, err := w.Write(data); if err != nil { + return nil, err + } + err = w.Close(); if err != nil { + return nil, err + } + return &b, nil } -func compressFlate(data []byte) *bytes.Buffer { +func compressFlate(data []byte) (io.Reader, error) { var b bytes.Buffer w, _ := flate.NewWriter(&b, 9) - w.Write(data) - w.Close() - return &b + _, err := w.Write(data); if err != nil { + return nil, err + } + err = w.Close(); if err != nil { + return nil, err + } + return &b, nil } -func compressBrotli(data []byte) *bytes.Buffer { +func compressBrotli(data []byte) (io.Reader, error) { var b bytes.Buffer w := brotli.NewWriterLevel(&b, brotli.BestCompression) - w.Write(data) - w.Close() - return &b + _, err := w.Write(data); if err != nil { + return nil, err + } + err = w.Close(); if err != nil { + return nil, err + } + return &b, nil } diff --git a/internal/pkg/http/utils_test.go b/internal/pkg/http/utils_test.go index 5ea3e14..38e7c0b 100644 --- a/internal/pkg/http/utils_test.go +++ b/internal/pkg/http/utils_test.go @@ -15,7 +15,6 @@ package http import ( - "bytes" "net/http" "os" "regexp" @@ -34,9 +33,7 @@ func TestHttp_FlagToHttpRequest(t *testing.T) { assert.Equal(t, http.MethodPost, request.Method) assert.Equal(t, "/db", request.Path) - body := new(bytes.Buffer) - body.ReadFrom(request.Body) - assert.Equal(t, `{"db": "true"}`, body.String()) + assert.Equal(t, `{"db": "true"}`, *request.Body) } func TestHttp_CompressGzip(t *testing.T) { @@ -49,11 +46,9 @@ func TestHttp_CompressGzip(t *testing.T) { assert.Equal(t, map[string]string{"Content-Encoding": "gzip"}, request.Headers) - body := new(bytes.Buffer) - body.ReadFrom(request.Body) - expected := &bytes.Buffer{} - expected.Write([]byte{0x1f, 0x8b, 0x8, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0xff, 0xaa, 0x56, 0x4a, 0x49, 0x52, 0xb2, 0x52, 0x50, 0x2a, 0x29, 0x2a, 0x4d, 0x55, 0xaa, 0x5, 0x4, 0x0, 0x0, 0xff, 0xff, 0xa1, 0x4a, 0x9b, 0x5d, 0xe, 0x0, 0x0, 0x0}) - assert.Equal(t, expected, body) + expected := []byte{0x1f, 0x8b, 0x8, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0xff, 0xaa, 0x56, 0x4a, 0x49, 0x52, 0xb2, 0x52, 0x50, 0x2a, 0x29, 0x2a, 0x4d, 0x55, 0xaa, 0x5, 0x4, 0x0, 0x0, 0xff, 0xff, 0xa1, 0x4a, 0x9b, 0x5d, 0xe, 0x0, 0x0, 0x0} + actual := []byte(*request.Body) + assert.Equal(t, expected, actual) } func TestHttp_CompressBrotli(t *testing.T) { @@ -66,11 +61,9 @@ func TestHttp_CompressBrotli(t *testing.T) { assert.Equal(t, map[string]string{"Content-Encoding": "br"}, request.Headers) - body := new(bytes.Buffer) - body.ReadFrom(request.Body) - expected := &bytes.Buffer{} - expected.Write([]byte{0x8b, 0x6, 0x80, 0x7b, 0x22, 0x64, 0x62, 0x22, 0x3a, 0x20, 0x22, 0x74, 0x72, 0x75, 0x65, 0x22, 0x7d, 0x3}) - assert.Equal(t, expected, body) + expected := []byte{0x8b, 0x6, 0x80, 0x7b, 0x22, 0x64, 0x62, 0x22, 0x3a, 0x20, 0x22, 0x74, 0x72, 0x75, 0x65, 0x22, 0x7d, 0x3} + actual := []byte(*request.Body) + assert.Equal(t, expected, actual) } func TestHttp_CompressDeflate(t *testing.T) { @@ -81,11 +74,9 @@ func TestHttp_CompressDeflate(t *testing.T) { assert.Equal(t, http.MethodPost, request.Method) assert.Equal(t, "/db", request.Path) assert.Equal(t, map[string]string{"Content-Encoding": "deflate"}, request.Headers) - body := new(bytes.Buffer) - body.ReadFrom(request.Body) - expected := &bytes.Buffer{} - expected.Write([]byte{0xaa, 0x56, 0x4a, 0x49, 0x52, 0xb2, 0x52, 0x50, 0x2a, 0x29, 0x2a, 0x4d, 0x55, 0xaa, 0x5, 0x4, 0x0, 0x0, 0xff, 0xff}) - assert.Equal(t, expected, body) + expected := []byte{0xaa, 0x56, 0x4a, 0x49, 0x52, 0xb2, 0x52, 0x50, 0x2a, 0x29, 0x2a, 0x4d, 0x55, 0xaa, 0x5, 0x4, 0x0, 0x0, 0xff, 0xff} + actual := []byte(*request.Body) + assert.Equal(t, expected, actual) } func TestBodyFromFile(t *testing.T) { @@ -100,9 +91,7 @@ func TestBodyFromFile(t *testing.T) { assert.Equal(t, http.MethodPost, request.Method) assert.Equal(t, "/db", request.Path) - buf := new(bytes.Buffer) - buf.ReadFrom(request.Body) - assert.Equal(t, `{"foo": "bar"}`, buf.String()) + assert.Equal(t, `{"foo": "bar"}`, *request.Body) } func TestHttp_FlagWithoutBodyToHttpRequest(t *testing.T) { @@ -130,14 +119,12 @@ func TestHttp_TimestampInterpolation(t *testing.T) { var numbersRegex = regexp.MustCompile("\\d+") matchPath := numbersRegex.MatchString(request.Path) - body := new(bytes.Buffer) - body.ReadFrom(request.Body) - matchBody := numbersRegex.MatchString(body.String()) + matchBody := numbersRegex.MatchString(*request.Body) assert.True(t, matchPath) assert.True(t, matchBody) assert.Equal(t, len(request.Path), 19) // "path_ + 13 numbers for timestamp - assert.Equal(t, len(body.String()), 25) // { "body": 13 numbers for timestamp + assert.Equal(t, len(*request.Body), 25) // { "body": 13 numbers for timestamp } func TestHttp_Interpolation(t *testing.T) { @@ -151,9 +138,7 @@ func TestHttp_Interpolation(t *testing.T) { matchPath := pathRegex.MatchString(request.Path) var bodyRegex = regexp.MustCompile("{\"body\": \"(foo|bar) \\d\"}") - body := new(bytes.Buffer) - body.ReadFrom(request.Body) - matchBody := bodyRegex.MatchString(body.String()) + matchBody := bodyRegex.MatchString(*request.Body) assert.True(t, matchPath) assert.True(t, matchBody) diff --git a/test/root_test.go b/test/root_test.go index 3d94eee..2d3afed 100644 --- a/test/root_test.go +++ b/test/root_test.go @@ -15,8 +15,10 @@ package test import ( + "compress/gzip" "context" "fmt" + "io" "mittens/cmd" "mittens/fixture" "mittens/internal/pkg/probe" @@ -34,6 +36,7 @@ var mockHttpServerPort int var mockHttpServer *http.Server var mockGrpcServer *grpc.Server var httpInvocations = 0 +var decompressedBody = "" func TestMain(m *testing.M) { setup() @@ -188,6 +191,36 @@ func TestGrpcAndHttp(t *testing.T) { assert.True(t, readyFileExists) } +func TestCompressWithGZip(t *testing.T) { + t.Cleanup(func() { + cleanup() + }) + + var requestBody = "{\"payload\":{\"body\":\"abcdefghijklmnopqrstuvwxyz01\"}}" + + os.Args = []string{ + "mittens", + "-file-probe-enabled=true", + // FIXME: for some reason we need to set both ports? + fmt.Sprintf("-target-http-port=%d", mockHttpServerPort), + fmt.Sprintf("-target-readiness-port=%d", mockHttpServerPort), + "-http-requests=post:/compressed:" + requestBody, + "-http-requests-compression=gzip", + "-target-insecure=true", + "-concurrency=2", + "-exit-after-warmup=true", + "-target-readiness-http-path=/health", + "-max-duration-seconds=2", + "-concurrency-target-seconds=1", + } + + cmd.CreateConfig() + cmd.RunCmdRoot() + + assert.Greater(t, httpInvocations, 1, "Assert that we made some calls to the http service") + assert.Equal(t, requestBody, decompressedBody, "Assert that server-side decompressed body is equal to client request body") +} + func setup() { fmt.Println("Starting up http server") mockHttpServer, mockHttpServerPort = fixture.StartHttpTargetTestServer([]fixture.PathResponseHandler{ @@ -201,6 +234,32 @@ func setup() { w.WriteHeader(http.StatusOK) }, }, + { + Path: "/compressed", + PathHandlerFunc: func(w http.ResponseWriter, r *http.Request) { + // Tiny sleep to mimic a regular http call + time.Sleep(time.Millisecond * 10) + httpInvocations++ + // Record number of invocations made to this endpoint + if r.Header.Get("Content-Encoding") == "gzip" { + gr, err := gzip.NewReader(r.Body) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + return + } + b, err := io.ReadAll(gr) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + return + } + decompressedBody = string(b) + gr.Close() + w.WriteHeader(http.StatusOK) + return + } + w.WriteHeader(http.StatusOK) + }, + }, }) // FIXME: should run on a random/free port