Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

AWS Lambda: Update Lambda Response Error Handling #6

Merged
merged 3 commits into from
May 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 39 additions & 18 deletions pkg/middlewares/awslambda/aws_lambda.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ import (
"strings"

"github.com/aws/aws-lambda-go/events"
"github.com/aws/aws-lambda-go/lambda/messages"
"github.com/aws/aws-sdk-go-v2/aws"
awsconfig "github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/credentials"
Expand All @@ -67,8 +68,7 @@ type awsLambda struct {

// New builds a new AwsLambda middleware.
func New(ctx context.Context, next http.Handler, config dynamic.AWSLambda, name string) (http.Handler, error) {
logger := log.FromContext(middlewares.GetLoggerCtx(ctx, name, typeName))
logger.Debug("Creating middleware")
log.FromContext(middlewares.GetLoggerCtx(ctx, name, typeName)).Debug("Creating middleware")

if len(config.FunctionArn) == 0 {
return nil, fmt.Errorf("function arn cannot be empty")
Expand Down Expand Up @@ -131,16 +131,15 @@ func (a *awsLambda) GetTracingInformation() (string, ext.SpanKindEnum) {

// ServeHTTP is the AWS Lambda middleware that takes a request, converts
// it to an APIGatewayProxyRequest and invokes lambda. It should come at
// the end of a middlware chain as it does routing internally.
// the end of a middleware chain as it does routing internally.
// NOTE: While this could implement the same code as Lambda Invoke
// (ie: Construct request object, modify request to POST
// .../functions/..., sign request) no request middleware could be used
// afterwards as it would break the signature.
func (a *awsLambda) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
ctx := middlewares.GetLoggerCtx(req.Context(), a.name, typeName)
logger := log.FromContext(ctx)
logger := log.FromContext(middlewares.GetLoggerCtx(req.Context(), a.name, typeName))

base64Encoded, body, err := bodyToBase64(req)
base64Encoded, reqBody, err := bodyToBase64(req)
if err != nil {
msg := fmt.Sprintf("Error encoding Lambda request body: %v", err)
logger.Error(msg)
Expand All @@ -161,7 +160,7 @@ func (a *awsLambda) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
MultiValueQueryStringParameters: valuesToMultiMap(req.URL.Query()),
Headers: headersToMap(req.Header),
MultiValueHeaders: headersToMultiMap(req.Header),
Body: body,
Body: reqBody,
IsBase64Encoded: base64Encoded,
RequestContext: events.APIGatewayProxyRequestContext{
Authorizer: make(map[string]interface{}),
Expand All @@ -172,18 +171,28 @@ func (a *awsLambda) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
logger.Error(msg)
tracing.SetErrorWithEvent(req, msg)

rw.WriteHeader(http.StatusInternalServerError)
statusCode := http.StatusInternalServerError
// If there's an error invoking the lambda, a response and error
// will be returned. Use the statuscode of the error response (502)
// to indicate a lambda error.
if resp != nil {
statusCode = resp.StatusCode
}

tracing.LogResponseCode(tracing.GetSpan(req), statusCode)
rw.WriteHeader(statusCode)
return
}

body = resp.Body
body := resp.Body
if resp.IsBase64Encoded {
buf, err := base64.StdEncoding.DecodeString(body)
if err != nil {
msg := fmt.Sprintf("Failed to base64 decode body: %s: %v", body, err)
logger.Error(msg)
tracing.SetErrorWithEvent(req, msg)

tracing.LogResponseCode(tracing.GetSpan(req), http.StatusInternalServerError)
rw.WriteHeader(http.StatusInternalServerError)
return
}
Expand All @@ -207,6 +216,7 @@ func (a *awsLambda) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
logger.Error(msg)
tracing.SetErrorWithEvent(req, msg)

tracing.LogResponseCode(tracing.GetSpan(req), http.StatusInternalServerError)
rw.WriteHeader(http.StatusInternalServerError)
return
}
Expand All @@ -218,6 +228,7 @@ func (a *awsLambda) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
msg := fmt.Sprintf("Failed to write response body %s: %v", body, err)
logger.Error(msg)
tracing.SetErrorWithEvent(req, msg)
tracing.LogResponseCode(tracing.GetSpan(req), http.StatusInternalServerError)
return
}
}
Expand Down Expand Up @@ -269,32 +280,42 @@ func bodyToBase64(req *http.Request) (bool, string, error) {
return base64Encoded, body, nil
}

func (a *awsLambda) invokeFunction(ctx context.Context, request events.APIGatewayProxyRequest) (events.APIGatewayProxyResponse, error) {
var resp events.APIGatewayProxyResponse

func (a *awsLambda) invokeFunction(ctx context.Context, request events.APIGatewayProxyRequest) (*events.APIGatewayProxyResponse, error) {
payload, err := json.Marshal(request)
if err != nil {
return resp, fmt.Errorf("failed to marshal request: %w", err)
return nil, fmt.Errorf("failed to marshal request: %w", err)
}

result, err := a.client.Invoke(ctx, &lambda.InvokeInput{
FunctionName: aws.String(a.functionArn),
Payload: payload,
})
if err != nil {
return resp, err
return nil, err
}
if result == nil {
return nil, fmt.Errorf("Nil lambda result when calling %s", a.functionArn)
}

if result.StatusCode >= 300 {
return resp, fmt.Errorf("call to lambda failed with: HTTP %d", result.StatusCode)
var resp events.APIGatewayProxyResponse
// If invoking the lambda resulted in an error, return a 502 and
// set the response body to the lambda error payload
if result.FunctionError != nil || (result.StatusCode >= 300 || result.StatusCode < 200) {
resp.StatusCode = http.StatusBadGateway
var errResp messages.InvokeResponse_Error
err = json.Unmarshal(result.Payload, &errResp)
if err != nil {
return nil, fmt.Errorf("Failed to parse lambda error: %w", err)
}
return &resp, fmt.Errorf("%s: %s", errResp.Message, errResp.Type)
}

err = json.Unmarshal(result.Payload, &resp)
if err != nil {
return resp, fmt.Errorf("failed to unmarshal response: %s, %w", result.Payload, err)
return nil, fmt.Errorf("failed to unmarshal response: %s, %w", result.Payload, err)
}

return resp, nil
return &resp, nil
}

func headersToMap(h http.Header) map[string]string {
Expand Down
79 changes: 74 additions & 5 deletions pkg/middlewares/awslambda/aws_lambda_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,76 @@ import (
"github.com/traefik/traefik/v2/pkg/tracing"
)

// Setup provides a mockserver (lambda handler), and should be closed
// with 'defer func() { mockserver.Close() }()' once received
func setup(t *testing.T, response string) (*httptest.Server, http.Handler, *http.Request) {
t.Helper()
mockserver := httptest.NewServer(http.HandlerFunc(func(res http.ResponseWriter, req *http.Request) {
var buf bytes.Buffer
_, err := io.Copy(&buf, req.Body)
if err != nil {
t.Fatal(err)
}

var lReq events.APIGatewayProxyRequest
err = json.Unmarshal(buf.Bytes(), &lReq)
if err != nil {
t.Fatal(err)
}

res.WriteHeader(http.StatusOK)
_, err = res.Write([]byte(response))
if err != nil {
t.Fatal(err)
}
}))

cfg := dynamic.AWSLambda{
AccessKey: "aws-key",
Region: "us-west-2",
SecretKey: "@@not-a-key",
FunctionArn: "arn:aws:lambda:us-west-2:000000000000:function:xxx:1",
Endpoint: mockserver.URL,
}

ctx := context.Background()
next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {})

handler, err := New(ctx, next, cfg, "traefik-aws-lambda-middleware")
if err != nil {
t.Fatal(err)
}

var buf bytes.Buffer
b := []byte("This is the body")
buf.Write(b)

req, err := http.NewRequestWithContext(ctx, http.MethodGet, fmt.Sprintf("%s/%s", mockserver.URL, "test/example/path?a=1&b=2&c=3&c=4&d[]=5&d[]=6"), &buf)
if err != nil {
t.Fatal(err)
}
req.Header.Set("Content-Type", "text/plain")
req.Header.Add("X-Test", "foo")
req.Header.Add("X-Test", "foobar")

return mockserver, handler, req
}

func Test_AWSLambdaMiddleware_Invoke(t *testing.T) {
mockserver, handler, req := setup(t, "{\"statusCode\": 0, \"body\":\"response_body\"}")
defer func() { mockserver.Close() }()

recorder := httptest.NewRecorder()

handler.ServeHTTP(recorder, req)
resp := recorder.Result()
rBody, _ := io.ReadAll(resp.Body)

assert.Equal(t, []byte{}, rBody)
assert.Equal(t, 500, resp.StatusCode)
}

func Test_AWSLambdaMiddleware_InvokeBasic(t *testing.T) {
mockserver := httptest.NewServer(http.HandlerFunc(func(res http.ResponseWriter, req *http.Request) {
assert.Equal(t, http.MethodPost, req.Method)
assert.Equal(t, "/2015-03-31/functions/arn%3Aaws%3Alambda%3Aus-west-2%3A000000000000%3Afunction%3Axxx%3A1/invocations", req.URL.RawPath)
Expand All @@ -44,7 +113,7 @@ func Test_AWSLambdaMiddleware_Invoke(t *testing.T) {
assert.Equal(t, "This is the body", lReq.Body)

res.WriteHeader(http.StatusOK)
_, err = res.Write([]byte("{\"statusCode\": 418, \"body\":\"response_body\"}"))
_, err = res.Write([]byte("{\"statusCode\": 200, \"body\":\"response_body\"}"))
if err != nil {
t.Fatal(err)
}
Expand All @@ -67,8 +136,6 @@ func Test_AWSLambdaMiddleware_Invoke(t *testing.T) {
t.Fatal(err)
}

recorder := httptest.NewRecorder()

var buf bytes.Buffer
b := []byte("This is the body")
buf.Write(b)
Expand All @@ -81,12 +148,14 @@ func Test_AWSLambdaMiddleware_Invoke(t *testing.T) {
req.Header.Add("X-Test", "foo")
req.Header.Add("X-Test", "foobar")

recorder := httptest.NewRecorder()

handler.ServeHTTP(recorder, req)
resp := recorder.Result()
rBody, _ := io.ReadAll(resp.Body)

assert.Equal(t, []byte("response_body"), rBody)
assert.Equal(t, http.StatusTeapot, resp.StatusCode)
assert.Equal(t, http.StatusOK, resp.StatusCode)
}

// Test_AWSLambdaMiddleware_GetTracingInformation tests that the
Expand Down Expand Up @@ -129,7 +198,7 @@ func Test_AWSLambdaMiddleware_bodyToBase64_notEncodedJSON(t *testing.T) {
require.NoError(t, err)
}

// Test_AWSLambdaMiddleware_bodyToBase64_notEncodedJSON
// Test_AWSLambdaMiddleware_bodyToBase64_withcontent
func Test_AWSLambdaMiddleware_bodyToBase64_withcontent(t *testing.T) {
// application/zip
expected := "UEsDBA=="
Expand Down
Loading