From 0bdc66bee798662358f7be53831781b835330b72 Mon Sep 17 00:00:00 2001 From: Trevor Bramwell Date: Mon, 29 Apr 2024 14:25:04 -0700 Subject: [PATCH 1/3] AWS Lambda: Update Lambda Response Error Handling Updates invokeFunction to parse lambda error responses, log the error, and return a 502. Signed-off-by: Trevor Bramwell --- pkg/middlewares/awslambda/aws_lambda.go | 56 +++++++++----- pkg/middlewares/awslambda/aws_lambda_test.go | 77 +++++++++++++++++++- 2 files changed, 112 insertions(+), 21 deletions(-) diff --git a/pkg/middlewares/awslambda/aws_lambda.go b/pkg/middlewares/awslambda/aws_lambda.go index 5609e2fe35..21661af356 100644 --- a/pkg/middlewares/awslambda/aws_lambda.go +++ b/pkg/middlewares/awslambda/aws_lambda.go @@ -41,6 +41,7 @@ import ( "strings" "github.com/aws/aws-lambda-go/events" + "github.com/aws/aws-lambda-go/lambda/messages" "github.com/aws/aws-sdk-go-v2/aws" awsconfig "github.com/aws/aws-sdk-go-v2/config" "github.com/aws/aws-sdk-go-v2/credentials" @@ -67,8 +68,7 @@ type awsLambda struct { // New builds a new AwsLambda middleware. func New(ctx context.Context, next http.Handler, config dynamic.AWSLambda, name string) (http.Handler, error) { - logger := log.FromContext(middlewares.GetLoggerCtx(ctx, name, typeName)) - logger.Debug("Creating middleware") + log.FromContext(middlewares.GetLoggerCtx(ctx, name, typeName)).Debug("Creating middleware") if len(config.FunctionArn) == 0 { return nil, fmt.Errorf("function arn cannot be empty") @@ -131,16 +131,15 @@ func (a *awsLambda) GetTracingInformation() (string, ext.SpanKindEnum) { // ServeHTTP is the AWS Lambda middleware that takes a request, converts // it to an APIGatewayProxyRequest and invokes lambda. It should come at -// the end of a middlware chain as it does routing internally. +// the end of a middleware chain as it does routing internally. // NOTE: While this could implement the same code as Lambda Invoke // (ie: Construct request object, modify request to POST // .../functions/..., sign request) no request middleware could be used // afterwards as it would break the signature. func (a *awsLambda) ServeHTTP(rw http.ResponseWriter, req *http.Request) { - ctx := middlewares.GetLoggerCtx(req.Context(), a.name, typeName) - logger := log.FromContext(ctx) + logger := log.FromContext(middlewares.GetLoggerCtx(req.Context(), a.name, typeName)) - base64Encoded, body, err := bodyToBase64(req) + base64Encoded, reqBody, err := bodyToBase64(req) if err != nil { msg := fmt.Sprintf("Error encoding Lambda request body: %v", err) logger.Error(msg) @@ -161,7 +160,7 @@ func (a *awsLambda) ServeHTTP(rw http.ResponseWriter, req *http.Request) { MultiValueQueryStringParameters: valuesToMultiMap(req.URL.Query()), Headers: headersToMap(req.Header), MultiValueHeaders: headersToMultiMap(req.Header), - Body: body, + Body: reqBody, IsBase64Encoded: base64Encoded, RequestContext: events.APIGatewayProxyRequestContext{ Authorizer: make(map[string]interface{}), @@ -172,11 +171,20 @@ func (a *awsLambda) ServeHTTP(rw http.ResponseWriter, req *http.Request) { logger.Error(msg) tracing.SetErrorWithEvent(req, msg) - rw.WriteHeader(http.StatusInternalServerError) + statusCode := http.StatusInternalServerError + // If there's an error invoking the lambda, a response and error + // will be returned. Use the statuscode of the error response (502) + // to indicate a lambda error. + if resp != nil { + statusCode = resp.StatusCode + } + + tracing.LogResponseCode(tracing.GetSpan(req), statusCode) + rw.WriteHeader(statusCode) return } - body = resp.Body + body := resp.Body if resp.IsBase64Encoded { buf, err := base64.StdEncoding.DecodeString(body) if err != nil { @@ -184,6 +192,7 @@ func (a *awsLambda) ServeHTTP(rw http.ResponseWriter, req *http.Request) { logger.Error(msg) tracing.SetErrorWithEvent(req, msg) + tracing.LogResponseCode(tracing.GetSpan(req), http.StatusInternalServerError) rw.WriteHeader(http.StatusInternalServerError) return } @@ -207,6 +216,7 @@ func (a *awsLambda) ServeHTTP(rw http.ResponseWriter, req *http.Request) { logger.Error(msg) tracing.SetErrorWithEvent(req, msg) + tracing.LogResponseCode(tracing.GetSpan(req), http.StatusInternalServerError) rw.WriteHeader(http.StatusInternalServerError) return } @@ -218,6 +228,7 @@ func (a *awsLambda) ServeHTTP(rw http.ResponseWriter, req *http.Request) { msg := fmt.Sprintf("Failed to write response body %s: %v", body, err) logger.Error(msg) tracing.SetErrorWithEvent(req, msg) + tracing.LogResponseCode(tracing.GetSpan(req), http.StatusInternalServerError) return } } @@ -269,12 +280,11 @@ func bodyToBase64(req *http.Request) (bool, string, error) { return base64Encoded, body, nil } -func (a *awsLambda) invokeFunction(ctx context.Context, request events.APIGatewayProxyRequest) (events.APIGatewayProxyResponse, error) { - var resp events.APIGatewayProxyResponse +func (a *awsLambda) invokeFunction(ctx context.Context, request events.APIGatewayProxyRequest) (*events.APIGatewayProxyResponse, error) { payload, err := json.Marshal(request) if err != nil { - return resp, fmt.Errorf("failed to marshal request: %w", err) + return nil, fmt.Errorf("failed to marshal request: %w", err) } result, err := a.client.Invoke(ctx, &lambda.InvokeInput{ @@ -282,19 +292,31 @@ func (a *awsLambda) invokeFunction(ctx context.Context, request events.APIGatewa Payload: payload, }) if err != nil { - return resp, err + return nil, err + } + if result == nil { + return nil, fmt.Errorf("Nil lambda result when calling %s", a.functionArn) } - if result.StatusCode >= 300 { - return resp, fmt.Errorf("call to lambda failed with: HTTP %d", result.StatusCode) + var resp events.APIGatewayProxyResponse + // If invoking the lambda resulted in an error, return a 502 and + // set the response body to the lambda error payload + if result.FunctionError != nil || (result.StatusCode >= 300 || result.StatusCode < 200) { + resp.StatusCode = http.StatusBadGateway + var errResp messages.InvokeResponse_Error + err = json.Unmarshal(result.Payload, &errResp) + if err != nil { + return nil, fmt.Errorf("Failed to parse lambda error: %w", err) + } + return &resp, fmt.Errorf("%s: %s", errResp.Message, errResp.Type) } err = json.Unmarshal(result.Payload, &resp) if err != nil { - return resp, fmt.Errorf("failed to unmarshal response: %s, %w", result.Payload, err) + return nil, fmt.Errorf("failed to unmarshal response: %s, %w", result.Payload, err) } - return resp, nil + return &resp, nil } func headersToMap(h http.Header) map[string]string { diff --git a/pkg/middlewares/awslambda/aws_lambda_test.go b/pkg/middlewares/awslambda/aws_lambda_test.go index 1ffca9fec1..897d7eda80 100644 --- a/pkg/middlewares/awslambda/aws_lambda_test.go +++ b/pkg/middlewares/awslambda/aws_lambda_test.go @@ -18,7 +18,76 @@ import ( "github.com/traefik/traefik/v2/pkg/tracing" ) +// Setup provides a mockserver (lambda handler), and should be closed +// with 'defer func() { mockserver.Close() }()' once recieved +func setup(t *testing.T, response string) (*httptest.Server, http.Handler, *http.Request) { + mockserver := httptest.NewServer(http.HandlerFunc(func(res http.ResponseWriter, req *http.Request) { + + var buf bytes.Buffer + _, err := io.Copy(&buf, req.Body) + if err != nil { + t.Fatal(err) + } + + var lReq events.APIGatewayProxyRequest + err = json.Unmarshal(buf.Bytes(), &lReq) + if err != nil { + t.Fatal(err) + } + + res.WriteHeader(http.StatusOK) + _, err = res.Write([]byte(response)) + if err != nil { + t.Fatal(err) + } + })) + + cfg := dynamic.AWSLambda{ + AccessKey: "aws-key", + Region: "us-west-2", + SecretKey: "@@not-a-key", + FunctionArn: "arn:aws:lambda:us-west-2:000000000000:function:xxx:1", + Endpoint: mockserver.URL, + } + + ctx := context.Background() + next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {}) + + handler, err := New(ctx, next, cfg, "traefik-aws-lambda-middleware") + if err != nil { + t.Fatal(err) + } + + var buf bytes.Buffer + b := []byte("This is the body") + buf.Write(b) + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, fmt.Sprintf("%s/%s", mockserver.URL, "test/example/path?a=1&b=2&c=3&c=4&d[]=5&d[]=6"), &buf) + if err != nil { + t.Fatal(err) + } + req.Header.Set("Content-Type", "text/plain") + req.Header.Add("X-Test", "foo") + req.Header.Add("X-Test", "foobar") + + return mockserver, handler, req +} + func Test_AWSLambdaMiddleware_Invoke(t *testing.T) { + mockserver, handler, req := setup(t, "{\"statusCode\": 0, \"body\":\"response_body\"}") + defer func() { mockserver.Close() }() + + recorder := httptest.NewRecorder() + + handler.ServeHTTP(recorder, req) + resp := recorder.Result() + rBody, _ := io.ReadAll(resp.Body) + + assert.Equal(t, []byte{}, rBody) + assert.Equal(t, 500, resp.StatusCode) +} + +func Test_AWSLambdaMiddleware_InvokeBasic(t *testing.T) { mockserver := httptest.NewServer(http.HandlerFunc(func(res http.ResponseWriter, req *http.Request) { assert.Equal(t, http.MethodPost, req.Method) assert.Equal(t, "/2015-03-31/functions/arn%3Aaws%3Alambda%3Aus-west-2%3A000000000000%3Afunction%3Axxx%3A1/invocations", req.URL.RawPath) @@ -44,7 +113,7 @@ func Test_AWSLambdaMiddleware_Invoke(t *testing.T) { assert.Equal(t, "This is the body", lReq.Body) res.WriteHeader(http.StatusOK) - _, err = res.Write([]byte("{\"statusCode\": 418, \"body\":\"response_body\"}")) + _, err = res.Write([]byte("{\"statusCode\": 200, \"body\":\"response_body\"}")) if err != nil { t.Fatal(err) } @@ -67,8 +136,6 @@ func Test_AWSLambdaMiddleware_Invoke(t *testing.T) { t.Fatal(err) } - recorder := httptest.NewRecorder() - var buf bytes.Buffer b := []byte("This is the body") buf.Write(b) @@ -81,12 +148,14 @@ func Test_AWSLambdaMiddleware_Invoke(t *testing.T) { req.Header.Add("X-Test", "foo") req.Header.Add("X-Test", "foobar") + recorder := httptest.NewRecorder() + handler.ServeHTTP(recorder, req) resp := recorder.Result() rBody, _ := io.ReadAll(resp.Body) assert.Equal(t, []byte("response_body"), rBody) - assert.Equal(t, http.StatusTeapot, resp.StatusCode) + assert.Equal(t, http.StatusOK, resp.StatusCode) } // Test_AWSLambdaMiddleware_GetTracingInformation tests that the From 506e82a287d8e0d7c12043e34d675ded2c5dc15c Mon Sep 17 00:00:00 2001 From: Trevor Bramwell Date: Wed, 1 May 2024 13:16:35 -0700 Subject: [PATCH 2/3] AWS Lambda: Fix Testing Typos, set helper function Signed-off-by: Trevor Bramwell --- pkg/middlewares/awslambda/aws_lambda_test.go | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/pkg/middlewares/awslambda/aws_lambda_test.go b/pkg/middlewares/awslambda/aws_lambda_test.go index 897d7eda80..c0875b9dbf 100644 --- a/pkg/middlewares/awslambda/aws_lambda_test.go +++ b/pkg/middlewares/awslambda/aws_lambda_test.go @@ -19,8 +19,9 @@ import ( ) // Setup provides a mockserver (lambda handler), and should be closed -// with 'defer func() { mockserver.Close() }()' once recieved +// with 'defer func() { mockserver.Close() }()' once received func setup(t *testing.T, response string) (*httptest.Server, http.Handler, *http.Request) { + t.Helper() mockserver := httptest.NewServer(http.HandlerFunc(func(res http.ResponseWriter, req *http.Request) { var buf bytes.Buffer @@ -198,7 +199,7 @@ func Test_AWSLambdaMiddleware_bodyToBase64_notEncodedJSON(t *testing.T) { require.NoError(t, err) } -// Test_AWSLambdaMiddleware_bodyToBase64_notEncodedJSON +// Test_AWSLambdaMiddleware_bodyToBase64_withcontent func Test_AWSLambdaMiddleware_bodyToBase64_withcontent(t *testing.T) { // application/zip expected := "UEsDBA==" From 9b3012b20f8dc351c3cfa45e109f056a181e9672 Mon Sep 17 00:00:00 2001 From: Trevor Bramwell Date: Wed, 1 May 2024 13:56:38 -0700 Subject: [PATCH 3/3] AWS Lambda: Fix linting issues Signed-off-by: Trevor Bramwell --- pkg/middlewares/awslambda/aws_lambda.go | 1 - pkg/middlewares/awslambda/aws_lambda_test.go | 1 - 2 files changed, 2 deletions(-) diff --git a/pkg/middlewares/awslambda/aws_lambda.go b/pkg/middlewares/awslambda/aws_lambda.go index 21661af356..6c349eb36d 100644 --- a/pkg/middlewares/awslambda/aws_lambda.go +++ b/pkg/middlewares/awslambda/aws_lambda.go @@ -281,7 +281,6 @@ func bodyToBase64(req *http.Request) (bool, string, error) { } func (a *awsLambda) invokeFunction(ctx context.Context, request events.APIGatewayProxyRequest) (*events.APIGatewayProxyResponse, error) { - payload, err := json.Marshal(request) if err != nil { return nil, fmt.Errorf("failed to marshal request: %w", err) diff --git a/pkg/middlewares/awslambda/aws_lambda_test.go b/pkg/middlewares/awslambda/aws_lambda_test.go index c0875b9dbf..dae2ad2ee3 100644 --- a/pkg/middlewares/awslambda/aws_lambda_test.go +++ b/pkg/middlewares/awslambda/aws_lambda_test.go @@ -23,7 +23,6 @@ import ( func setup(t *testing.T, response string) (*httptest.Server, http.Handler, *http.Request) { t.Helper() mockserver := httptest.NewServer(http.HandlerFunc(func(res http.ResponseWriter, req *http.Request) { - var buf bytes.Buffer _, err := io.Copy(&buf, req.Body) if err != nil {