Skip to content
This repository has been archived by the owner on Dec 9, 2022. It is now read-only.

Commit

Permalink
refactor: http.ResponseWriter 인터페이스 표준화, 어뎁터 구현체 추가
Browse files Browse the repository at this point in the history
  • Loading branch information
2rebi committed Oct 12, 2022
1 parent 20c844b commit 1ef91b6
Show file tree
Hide file tree
Showing 7 changed files with 314 additions and 350 deletions.
37 changes: 20 additions & 17 deletions context.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"context"
"encoding/json"
"encoding/xml"
"errors"
"github.com/aws/aws-lambda-go/events"
"io"
"net"
Expand Down Expand Up @@ -34,6 +35,8 @@ type Context interface {

RequestBodyBytes() []byte

SetResponse(r *Response)

Response() *Response

Scheme() string
Expand Down Expand Up @@ -88,6 +91,8 @@ type Context interface {

Write(b []byte) (int, error)

Redirect(status int, url string) error

// TODO
//Logger() log.Logger

Expand Down Expand Up @@ -180,25 +185,16 @@ func (c *contextImpl) RequestBodyBytes() []byte {
return c.requestBodyBytes
}

func (c *contextImpl) SetResponse(r *Response) {
c.response = r
}

func (c *contextImpl) Response() *Response {
return c.response
}

func (c *contextImpl) Scheme() string {
header := c.request.Header
if scheme := header.Get(HeaderXForwardedProto); scheme != "" {
return scheme
}
if scheme := header.Get(HeaderXForwardedProtocol); scheme != "" {
return scheme
}
if ssl := header.Get(HeaderXForwardedSsl); ssl == "on" {
return "https"
}
if scheme := header.Get(HeaderXUrlScheme); scheme != "" {
return scheme
}
return "http"
return getSchemeFromHeader(c.request.Header)
}

func (c *contextImpl) RealIP() string {
Expand Down Expand Up @@ -343,9 +339,6 @@ func (c *contextImpl) XMLBytes(status int, b []byte) (err error) {

func (c *contextImpl) ResultStream(status int, contentType string, reader io.Reader) (err error) {
c.writeContentType(contentType)
if len(contentType) > 0 {
c.Response().isBinary = !notBinaryTable[contentType]
}
c.Response().WriteHeader(status)
_, err = io.Copy(c.Response(), reader)
return
Expand All @@ -359,6 +352,16 @@ func (c *contextImpl) Write(b []byte) (int, error) {
return c.Response().Write(b)
}

func (c *contextImpl) Redirect(status int, url string) error {
if status < http.StatusMultipleChoices ||
status > http.StatusPermanentRedirect {
return errors.New("invalid redirect status code")
}
c.Response().Header().Set(HeaderLocation, url)
c.Response().WriteHeader(status)
return nil
}

func (c *contextImpl) Golam() *Golam {
return c.golam
}
Expand Down
210 changes: 81 additions & 129 deletions golam.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,12 @@ import (
"errors"
"github.com/aws/aws-lambda-go/events"
"github.com/aws/aws-lambda-go/lambda"
"io"
"net/http"
"net/url"
"strconv"
"strings"
)

const (
lambdaHeaderContentLength = "content-length"
lambdaHeaderXForwardedProto = "x-forwarded-proto"
lambdaHeaderHost = "host"
)

func New() (g *Golam) {
g = &Golam{
isLambdaRuntime: isLambdaRuntime(),
Expand Down Expand Up @@ -86,9 +79,9 @@ func (d *defaultHttpHandler) ServeHTTP(writer http.ResponseWriter, request *http

ctxImpl := &contextImpl{
request: request,
response: &Response{
header: make(http.Header),
},
response: NewResponse(
NewResponseHTTPAdapter(writer),
),
path: reqPath,
query: request.URL.Query(),
golam: d.golam,
Expand Down Expand Up @@ -143,97 +136,55 @@ func (d *defaultHttpHandler) ServeHTTP(writer http.ResponseWriter, request *http
// TODO error handle
switch err := ctxImpl.handler(ctxImpl); err {
case nil:
resp := ctxImpl.Response()
resp.Committed = true

for k, values := range resp.header {
for _, val := range values {
writer.Header().Add(k, val)
}
}

for _, cookie := range resp.cookies {
http.SetCookie(writer, cookie)
err = ctxImpl.Response().Commit()
if err != nil {
panic(err) // unreachable code
}

writer.WriteHeader(resp.StatusCode)
_, _ = io.Copy(writer, &resp.BodyBuffer)
default:
// TODO error handler
}
}

func (d *defaultLambdaHandler) Invoke(ctx context.Context, payload []byte) ([]byte, error) {
var request events.APIGatewayV2HTTPRequest
err := json.Unmarshal(payload, &request)
var lReq events.APIGatewayV2HTTPRequest
err := json.Unmarshal(payload, &lReq)
if err != nil {
//TODO: logging
return nil, err
}

method := request.RequestContext.HTTP.Method
var rawURL strings.Builder
rawURL.WriteString(request.Headers[lambdaHeaderXForwardedProto])
rawURL.WriteString("://")
rawURL.WriteString(request.Headers[lambdaHeaderHost])
rawURL.WriteString(request.RawPath)
if len(request.RawQueryString) > 0 {
rawURL.WriteString("?")
rawURL.WriteString(request.RawQueryString)
}

var body []byte
if request.IsBase64Encoded {
body, err = base64.StdEncoding.DecodeString(request.Body)
if err != nil {
//TODO: logging
return nil, err
}
} else {
body = []byte(request.Body)
}

req, err := http.NewRequestWithContext(ctx, method, rawURL.String(), bytes.NewReader(body))
req, err := newHTTPRequestFromAPIGatewayV2HTTPRequest(ctx, &lReq)
if err != nil {
//TODO: logging
return nil, err
}

req.ProtoMajor, req.ProtoMinor = getProtoVersion(&request)
req.RemoteAddr = request.RequestContext.HTTP.SourceIP
for k, v := range request.Headers {
if req.Header.Get(k) == v {
continue
}

req.Header.Set(k, v)
}

var response events.APIGatewayV2HTTPResponse
ctxImpl := &contextImpl{
request: req,
response: &Response{
header: make(http.Header),
},
path: request.RequestContext.HTTP.Path,
response: NewResponse(
NewResponseLambdaAdapter(&response),
),
path: lReq.RequestContext.HTTP.Path,
golam: d.golam,
primalRequestLambda: &request,
primalRequestLambda: &lReq,
}

ctxImpl.query, _ = url.ParseQuery(request.RawQueryString)
ctxImpl.query, _ = url.ParseQuery(lReq.RawQueryString)

pathKey := request.RouteKey[strings.Index(request.RouteKey, " ")+1:]
pathKey := lReq.RouteKey[strings.Index(lReq.RouteKey, " ")+1:]
r := d.golam.Router().FindRoute(pathKey)
if r != nil {
handler := r.handlers.getHandler(method)
handler := r.handlers.getHandler(req.Method)
if handler != nil {
ctxImpl.handler = handler
} else {
ctxImpl.handler = r.handlers.getHandler("")
}

if len(request.PathParameters) > 0 {
if len(lReq.PathParameters) > 0 {
ctxImpl.pathParams = make(PathParams)
for k, v := range request.PathParameters {
for k, v := range lReq.PathParameters {
ctxImpl.pathParams[k] = PathParam{
Key: k,
Value: v,
Expand All @@ -249,39 +200,13 @@ func (d *defaultLambdaHandler) Invoke(ctx context.Context, payload []byte) ([]by
ctxImpl.handler = wrapMiddleware(ctxImpl.handler, append(d.golam.preMiddleware, d.golam.middleware...)...)
}

var response events.APIGatewayV2HTTPResponse
switch err = ctxImpl.handler(ctxImpl); err {
case nil:
resp := ctxImpl.Response()
resp.Committed = true

response = events.APIGatewayV2HTTPResponse{
StatusCode: resp.StatusCode,
Headers: make(map[string]string),
MultiValueHeaders: make(map[string][]string),
IsBase64Encoded: resp.isBinary,
Body: resp.BodyBuffer.String(),
}

if response.IsBase64Encoded {
response.Body = base64.StdEncoding.EncodeToString(resp.BodyBuffer.Bytes())
} else {
response.Body = resp.BodyBuffer.String()
}

for k, v := range resp.header {
if len(v) > 1 {
response.MultiValueHeaders[k] = v
} else {
response.Headers[k] = v[0]
}
}

response.Cookies = make([]string, 0, len(resp.cookies))
for i := range resp.cookies {
response.Cookies = append(response.Cookies, resp.cookies[i].String())
err = ctxImpl.Response().Commit()
if err != nil {
//TODO: logging
return nil, err
}

default:
// TODO error handler
}
Expand All @@ -306,36 +231,6 @@ func (g *Golam) Router() Router {
return g.router
}

func getProtoVersion(request *events.APIGatewayV2HTTPRequest) (major, minor int) {
if request == nil {
return
}
rawProtocol := request.RequestContext.HTTP.Protocol
div := strings.Index(rawProtocol, "/") + 1
if div == -1 {
return
}

proto := rawProtocol[div:]
div = strings.Index(proto, ".")
if div == -1 {
return
}

major, _ = strconv.Atoi(proto[:div])
minor, _ = strconv.Atoi(proto[div+1:])
return
}

func getContentLength(request *events.APIGatewayV2HTTPRequest) (l int64) {
if request == nil {
return
}

l, _ = strconv.ParseInt(request.Headers[lambdaHeaderContentLength], 10, 0)
return
}

func (g *Golam) Pre(middleware ...MiddlewareFunc) {
g.preMiddleware = append(g.preMiddleware, middleware...)
}
Expand Down Expand Up @@ -389,3 +284,60 @@ const defaultNotFoundResponseData = "{\"message\":\"Not Found\"}"
func DefaultNotFound(ctx Context) error {
return ctx.JSONBytes(http.StatusNotFound, []byte(defaultNotFoundResponseData))
}

func newHTTPRequestFromAPIGatewayV2HTTPRequest(ctx context.Context, from *events.APIGatewayV2HTTPRequest) (req *http.Request, err error) {
rawProtocol := from.RequestContext.HTTP.Protocol
protoDiv := strings.Index(rawProtocol, "/")
if protoDiv == -1 {
return
}

proto := rawProtocol[protoDiv+1:]
protoDiv = strings.Index(proto, ".")
if protoDiv == -1 {
return
}

major, err := strconv.Atoi(proto[:protoDiv])
if err != nil {
return
}
minor, err := strconv.Atoi(proto[protoDiv+1:])
if err != nil {
return
}

var body []byte
if from.IsBase64Encoded {
body, err = base64.StdEncoding.DecodeString(from.Body)
if err != nil {
return
}
} else {
body = []byte(from.Body)
}

header := getHeaderFromAPIGatewayV2HTTPRequest(from)

var rawURL strings.Builder
rawURL.WriteString(getSchemeFromHeader(header))
rawURL.WriteString("://")
rawURL.WriteString(header.Get(HeaderHost))
rawURL.WriteString(from.RawPath)
if len(from.RawQueryString) > 0 {
rawURL.WriteRune('?')
rawURL.WriteString(from.RawQueryString)
}

req, err = http.NewRequestWithContext(ctx, from.RequestContext.HTTP.Method, rawURL.String(), bytes.NewReader(body))
if err != nil {
return
}

req.Proto = rawProtocol
req.ProtoMajor = major
req.ProtoMinor = minor
req.RemoteAddr = from.RequestContext.HTTP.SourceIP
req.Header = header
return
}
Loading

0 comments on commit 1ef91b6

Please sign in to comment.