Skip to content

Commit

Permalink
AWS Lambda: Update Lambda Response Error Handling
Browse files Browse the repository at this point in the history
Updates invokeFunction to parse lambda error responses, log the error,
and return a 502.

Signed-off-by: Trevor Bramwell <tbramwell@linuxfoundation.org>
  • Loading branch information
bramwelt committed May 1, 2024
1 parent a09fe57 commit 200b2b6
Show file tree
Hide file tree
Showing 2 changed files with 113 additions and 22 deletions.
58 changes: 40 additions & 18 deletions pkg/middlewares/awslambda/aws_lambda.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ import (
"strings"

"github.com/aws/aws-lambda-go/events"
"github.com/aws/aws-lambda-go/lambda/messages"
"github.com/aws/aws-sdk-go-v2/aws"
awsconfig "github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/credentials"
Expand All @@ -67,8 +68,7 @@ type awsLambda struct {

// New builds a new AwsLambda middleware.
func New(ctx context.Context, next http.Handler, config dynamic.AWSLambda, name string) (http.Handler, error) {
logger := log.FromContext(middlewares.GetLoggerCtx(ctx, name, typeName))
logger.Debug("Creating middleware")
log.FromContext(middlewares.GetLoggerCtx(ctx, name, typeName)).Debug("Creating middleware")

if len(config.FunctionArn) == 0 {
return nil, fmt.Errorf("function arn cannot be empty")
Expand Down Expand Up @@ -131,16 +131,15 @@ func (a *awsLambda) GetTracingInformation() (string, ext.SpanKindEnum) {

// ServeHTTP is the AWS Lambda middleware that takes a request, converts
// it to an APIGatewayProxyRequest and invokes lambda. It should come at
// the end of a middlware chain as it does routing internally.
// the end of a middleware chain as it does routing internally.
// NOTE: While this could implement the same code as Lambda Invoke
// (ie: Construct request object, modify request to POST
// .../functions/..., sign request) no request middleware could be used
// afterwards as it would break the signature.
func (a *awsLambda) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
ctx := middlewares.GetLoggerCtx(req.Context(), a.name, typeName)
logger := log.FromContext(ctx)
logger := log.FromContext(middlewares.GetLoggerCtx(req.Context(), a.name, typeName))

base64Encoded, body, err := bodyToBase64(req)
base64Encoded, reqBody, err := bodyToBase64(req)
if err != nil {
msg := fmt.Sprintf("Error encoding Lambda request body: %v", err)
logger.Error(msg)
Expand All @@ -161,7 +160,7 @@ func (a *awsLambda) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
MultiValueQueryStringParameters: valuesToMultiMap(req.URL.Query()),
Headers: headersToMap(req.Header),
MultiValueHeaders: headersToMultiMap(req.Header),
Body: body,
Body: reqBody,
IsBase64Encoded: base64Encoded,
RequestContext: events.APIGatewayProxyRequestContext{
Authorizer: make(map[string]interface{}),
Expand All @@ -172,18 +171,28 @@ func (a *awsLambda) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
logger.Error(msg)
tracing.SetErrorWithEvent(req, msg)

rw.WriteHeader(http.StatusInternalServerError)
statusCode := http.StatusInternalServerError
// If there's an error invoking the lambda, a response and error
// will be returned. Use the statuscode of the error response (502)
// to indicate a lambda error.
if resp != nil {
statusCode = resp.StatusCode
}

tracing.LogResponseCode(tracing.GetSpan(req), statusCode)
rw.WriteHeader(statusCode)
return
}

body = resp.Body
body := resp.Body
if resp.IsBase64Encoded {
buf, err := base64.StdEncoding.DecodeString(body)
if err != nil {
msg := fmt.Sprintf("Failed to base64 decode body: %s: %v", body, err)
logger.Error(msg)
tracing.SetErrorWithEvent(req, msg)

tracing.LogResponseCode(tracing.GetSpan(req), http.StatusInternalServerError)
rw.WriteHeader(http.StatusInternalServerError)
return
}
Expand All @@ -203,10 +212,11 @@ func (a *awsLambda) ServeHTTP(rw http.ResponseWriter, req *http.Request) {

// Validate StatusCode before writing
if !(resp.StatusCode >= 100 && resp.StatusCode < 600) {
msg := fmt.Sprintf("Invalid response status code: %d", resp.StatusCode)
msg := fmt.Sprintf("Invalid response status Code: %d; %+v", resp.StatusCode, resp)
logger.Error(msg)
tracing.SetErrorWithEvent(req, msg)

tracing.LogResponseCode(tracing.GetSpan(req), http.StatusInternalServerError)
rw.WriteHeader(http.StatusInternalServerError)
return
}
Expand All @@ -218,6 +228,7 @@ func (a *awsLambda) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
msg := fmt.Sprintf("Failed to write response body %s: %v", body, err)
logger.Error(msg)
tracing.SetErrorWithEvent(req, msg)
tracing.LogResponseCode(tracing.GetSpan(req), http.StatusInternalServerError)
return
}
}
Expand Down Expand Up @@ -269,32 +280,43 @@ func bodyToBase64(req *http.Request) (bool, string, error) {
return base64Encoded, body, nil
}

func (a *awsLambda) invokeFunction(ctx context.Context, request events.APIGatewayProxyRequest) (events.APIGatewayProxyResponse, error) {
var resp events.APIGatewayProxyResponse
func (a *awsLambda) invokeFunction(ctx context.Context, request events.APIGatewayProxyRequest) (*events.APIGatewayProxyResponse, error) {

payload, err := json.Marshal(request)
if err != nil {
return resp, fmt.Errorf("failed to marshal request: %w", err)
return nil, fmt.Errorf("failed to marshal request: %w", err)
}

result, err := a.client.Invoke(ctx, &lambda.InvokeInput{
FunctionName: aws.String(a.functionArn),
Payload: payload,
})
if err != nil {
return resp, err
return nil, err
}
if result == nil {
return nil, fmt.Errorf("Nil lambda result when calling %s", a.functionArn)
}

if result.StatusCode >= 300 {
return resp, fmt.Errorf("call to lambda failed with: HTTP %d", result.StatusCode)
var resp events.APIGatewayProxyResponse
// If invoking the lambda resulted in an error, return a 502 and
// set the response body to the lambda error payload
if result.FunctionError != nil || (result.StatusCode >= 300 || result.StatusCode < 200) {
resp.StatusCode = http.StatusBadGateway
var errResp messages.InvokeResponse_Error
err = json.Unmarshal(result.Payload, &errResp)
if err != nil {
return nil, fmt.Errorf("Failed to parse lambda error: %w", err)
}
return &resp, fmt.Errorf("%s: %s", errResp.Message, errResp.Type)
}

err = json.Unmarshal(result.Payload, &resp)
if err != nil {
return resp, fmt.Errorf("failed to unmarshal response: %s, %w", result.Payload, err)
return nil, fmt.Errorf("failed to unmarshal response: %s, %w", result.Payload, err)
}

return resp, nil
return &resp, nil
}

func headersToMap(h http.Header) map[string]string {
Expand Down
77 changes: 73 additions & 4 deletions pkg/middlewares/awslambda/aws_lambda_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,76 @@ import (
"github.com/traefik/traefik/v2/pkg/tracing"
)

// Setup provides a mockserver (lambda handler), and should be closed
// with 'defer func() { mockserver.Close() }()' once recieved
func setup(t *testing.T, response string) (*httptest.Server, http.Handler, *http.Request) {
mockserver := httptest.NewServer(http.HandlerFunc(func(res http.ResponseWriter, req *http.Request) {

var buf bytes.Buffer
_, err := io.Copy(&buf, req.Body)
if err != nil {
t.Fatal(err)
}

var lReq events.APIGatewayProxyRequest
err = json.Unmarshal(buf.Bytes(), &lReq)
if err != nil {
t.Fatal(err)
}

res.WriteHeader(http.StatusOK)
_, err = res.Write([]byte(response))
if err != nil {
t.Fatal(err)
}
}))

cfg := dynamic.AWSLambda{
AccessKey: "aws-key",
Region: "us-west-2",
SecretKey: "@@not-a-key",
FunctionArn: "arn:aws:lambda:us-west-2:000000000000:function:xxx:1",
Endpoint: mockserver.URL,
}

ctx := context.Background()
next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {})

handler, err := New(ctx, next, cfg, "traefik-aws-lambda-middleware")
if err != nil {
t.Fatal(err)
}

var buf bytes.Buffer
b := []byte("This is the body")
buf.Write(b)

req, err := http.NewRequestWithContext(ctx, http.MethodGet, fmt.Sprintf("%s/%s", mockserver.URL, "test/example/path?a=1&b=2&c=3&c=4&d[]=5&d[]=6"), &buf)
if err != nil {
t.Fatal(err)
}
req.Header.Set("Content-Type", "text/plain")
req.Header.Add("X-Test", "foo")
req.Header.Add("X-Test", "foobar")

return mockserver, handler, req
}

func Test_AWSLambdaMiddleware_Invoke(t *testing.T) {
mockserver, handler, req := setup(t, "{\"statusCode\": 0, \"body\":\"response_body\"}")
defer func() { mockserver.Close() }()

recorder := httptest.NewRecorder()

handler.ServeHTTP(recorder, req)
resp := recorder.Result()
rBody, _ := io.ReadAll(resp.Body)

assert.Equal(t, []byte{}, rBody)
assert.Equal(t, 500, resp.StatusCode)
}

func Test_AWSLambdaMiddleware_InvokeBasic(t *testing.T) {
mockserver := httptest.NewServer(http.HandlerFunc(func(res http.ResponseWriter, req *http.Request) {
assert.Equal(t, http.MethodPost, req.Method)
assert.Equal(t, "/2015-03-31/functions/arn%3Aaws%3Alambda%3Aus-west-2%3A000000000000%3Afunction%3Axxx%3A1/invocations", req.URL.RawPath)
Expand All @@ -44,7 +113,7 @@ func Test_AWSLambdaMiddleware_Invoke(t *testing.T) {
assert.Equal(t, "This is the body", lReq.Body)

res.WriteHeader(http.StatusOK)
_, err = res.Write([]byte("{\"statusCode\": 418, \"body\":\"response_body\"}"))
_, err = res.Write([]byte("{\"statusCode\": 200, \"body\":\"response_body\"}"))
if err != nil {
t.Fatal(err)
}
Expand All @@ -67,8 +136,6 @@ func Test_AWSLambdaMiddleware_Invoke(t *testing.T) {
t.Fatal(err)
}

recorder := httptest.NewRecorder()

var buf bytes.Buffer
b := []byte("This is the body")
buf.Write(b)
Expand All @@ -81,12 +148,14 @@ func Test_AWSLambdaMiddleware_Invoke(t *testing.T) {
req.Header.Add("X-Test", "foo")
req.Header.Add("X-Test", "foobar")

recorder := httptest.NewRecorder()

handler.ServeHTTP(recorder, req)
resp := recorder.Result()
rBody, _ := io.ReadAll(resp.Body)

assert.Equal(t, []byte("response_body"), rBody)
assert.Equal(t, http.StatusTeapot, resp.StatusCode)
assert.Equal(t, http.StatusOK, resp.StatusCode)
}

// Test_AWSLambdaMiddleware_GetTracingInformation tests that the
Expand Down

0 comments on commit 200b2b6

Please sign in to comment.