diff --git a/service/middleware/clean_multipart.go b/service/middleware/clean_multipart.go new file mode 100644 index 0000000..1a9c3e2 --- /dev/null +++ b/service/middleware/clean_multipart.go @@ -0,0 +1,20 @@ +package middleware + +import ( + "log" + "net/http" +) + +func CleanMultipart(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + next.ServeHTTP(w, r) + + if r.MultipartForm == nil { + return + } + if err := r.MultipartForm.RemoveAll(); err != nil { + id, _ := GetRequestID(r) + log.Printf("%s: failed to cleanup multipart/form-data residues: %v", id, err) + } + }) +} diff --git a/service/middleware/logging.go b/service/middleware/logging.go new file mode 100644 index 0000000..9c67f38 --- /dev/null +++ b/service/middleware/logging.go @@ -0,0 +1,82 @@ +package middleware + +import ( + "bytes" + "net" + "net/http" + "os" + "strconv" +) + +type responseLogger struct { + status, n int + http.ResponseWriter +} + +func (l *responseLogger) Write(b []byte) (n int, err error) { + n, err = l.ResponseWriter.Write(b) + l.n += n + return +} + +func (l *responseLogger) WriteHeader(status int) { + l.ResponseWriter.WriteHeader(status) + l.status = status +} + +// Logging performs request logging. This method takes heavy inspiration +// from (github.com/gorilla/handlers).CombinedLoggingHandler. +func Logging(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + rl := &responseLogger{ResponseWriter: w} + url := *r.URL + + next.ServeHTTP(rl, r) + + username := "-" + if url.User != nil { + if name := url.User.Username(); name != "" { + username = name + } + } + + host, _, err := net.SplitHostPort(r.RemoteAddr) + if err != nil { + host = r.RemoteAddr + } + + uri := r.RequestURI + + // Requests using the CONNECT method over HTTP/2.0 must use + // the authority field (aka r.Host) to identify the target. + // Refer: https://httpwg.github.io/specs/rfc7540.html#CONNECT + if r.ProtoMajor == 2 && r.Method == "CONNECT" { + uri = r.Host + } + if uri == "" { + uri = url.RequestURI() + } + + id, haveID := GetRequestID(r) + + var buf bytes.Buffer + buf.Grow(len(id) + len(r.Method) + len(host) + len(username) + len(uri) + 50) // reduce allocs + if haveID { + buf.WriteString(id) + buf.WriteString(": ") + } + buf.WriteString(host) + buf.WriteString(" - ") + buf.WriteString(username) + buf.WriteString(" - ") + buf.WriteString(r.Method) + buf.WriteByte(' ') + buf.WriteString(uri) + buf.WriteString(" - ") + buf.WriteString(strconv.Itoa(rl.status)) + buf.WriteByte(' ') + buf.WriteString(strconv.Itoa(rl.n)) + buf.WriteByte('\n') + _, _ = buf.WriteTo(os.Stderr) + }) +} diff --git a/requestid/requestid.go b/service/middleware/requestid.go similarity index 68% rename from requestid/requestid.go rename to service/middleware/requestid.go index f4dcc72..71764ca 100644 --- a/requestid/requestid.go +++ b/service/middleware/requestid.go @@ -1,4 +1,4 @@ -package requestid +package middleware import ( "context" @@ -13,7 +13,7 @@ type contextKey string const ContextKey = contextKey("request-id") -func Middleware(next http.Handler) http.Handler { +func RequestID(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { id := uuid.Must(uuid.NewRandom()).String() r = r.WithContext(context.WithValue(r.Context(), ContextKey, id)) @@ -22,3 +22,8 @@ func Middleware(next http.Handler) http.Handler { next.ServeHTTP(w, r) }) } + +func GetRequestID(r *http.Request) (string, bool) { + id, ok := r.Context().Value(ContextKey).(string) + return id, ok +} diff --git a/requestid/requestid_test.go b/service/middleware/requestid_test.go similarity index 91% rename from requestid/requestid_test.go rename to service/middleware/requestid_test.go index 7d78343..4d4e8fb 100644 --- a/requestid/requestid_test.go +++ b/service/middleware/requestid_test.go @@ -1,4 +1,4 @@ -package requestid +package middleware import ( "net/http" @@ -26,7 +26,7 @@ func TestMiddleware(t *testing.T) { } w := httptest.NewRecorder() - Middleware(http.HandlerFunc(captureContextID)).ServeHTTP(w, + RequestID(http.HandlerFunc(captureContextID)).ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/", nil)) assert.Equal(t, http.StatusOK, w.Code) diff --git a/service/service.go b/service/service.go index d6414fd..b77f7e6 100644 --- a/service/service.go +++ b/service/service.go @@ -6,11 +6,10 @@ import ( "fmt" "log" "net/http" - "os" "time" "github.com/digineo/texd/exec" - "github.com/digineo/texd/requestid" + "github.com/digineo/texd/service/middleware" "github.com/digineo/texd/tex" "github.com/gorilla/handlers" "github.com/gorilla/mux" @@ -62,12 +61,14 @@ func Start(opts Options) func(context.Context) error { r.HandleFunc("/metrics", svc.HandleMetrics).Methods(http.MethodGet) // r.Use(handlers.RecoveryHandler()) - r.Use(requestid.Middleware) + r.Use(middleware.RequestID) r.Use(handlers.CompressHandler) + r.Use(middleware.Logging) + r.Use(middleware.CleanMultipart) srv := http.Server{ Addr: opts.Addr, - Handler: handlers.CombinedLoggingHandler(os.Stdout, r), + Handler: r, } go func() {