From 3dd5c4de71d77d6889b900bdbd9eb7b2b0f7ea56 Mon Sep 17 00:00:00 2001 From: aeneasr <3372410+aeneasr@users.noreply.github.com> Date: Wed, 14 Oct 2020 10:40:46 +0200 Subject: [PATCH] Prevent multiple `Set-Cookie` headers when calling RegenerateToken Closes #61 --- context.go | 17 +++++++++++++++++ context_legacy.go | 32 ++++++++++++++++++++++++++++++++ handler.go | 11 +++++++++++ handler_go17_test.go | 17 +++++++++++++++++ handler_legacy_test.go | 18 ++++++++++++++++++ handler_test.go | 16 ---------------- 6 files changed, 95 insertions(+), 16 deletions(-) diff --git a/context.go b/context.go index d641845..1154135 100644 --- a/context.go +++ b/context.go @@ -16,6 +16,12 @@ type csrfContext struct { token string // reason for the failure of CSRF check reason error + // wasSent is true if `Set-Cookie` was called + // for the `name=csrf_token` already. This prevents + // duplicate `Set-Cookie: csrf_token` headers. + // For more information see: + // https://github.com/justinas/nosurf/pull/61 + wasSent bool } // Token takes an HTTP request and returns @@ -53,6 +59,17 @@ func ctxSetToken(req *http.Request, token []byte) { ctx.token = b64encode(maskToken(token)) } +func ctxSetSent(req *http.Request) { + ctx := req.Context().Value(nosurfKey).(*csrfContext) + ctx.wasSent = true +} + +func ctxWasSent(req *http.Request) bool { + ctx := req.Context().Value(nosurfKey).(*csrfContext) + + return ctx.wasSent +} + func ctxSetReason(req *http.Request, reason error) { ctx := req.Context().Value(nosurfKey).(*csrfContext) if ctx.token == "" { diff --git a/context_legacy.go b/context_legacy.go index 81e1b89..6f2e3ae 100644 --- a/context_legacy.go +++ b/context_legacy.go @@ -17,6 +17,12 @@ type csrfContext struct { token string // reason for the failure of CSRF check reason error + // wasSent is true if `Set-Cookie` was called + // for the `name=csrf_token` already. This prevents + // duplicate `Set-Cookie: csrf_token` headers. + // For more information see: + // https://github.com/justinas/nosurf/pull/61 + wasSent bool } var ( @@ -79,6 +85,32 @@ func ctxSetToken(req *http.Request, token []byte) *http.Request { return req } +func ctxSetSent(req *http.Request) { + cmMutex.Lock() + defer cmMutex.Unlock() + + ctx, ok := contextMap[req] + if !ok { + ctx = new(csrfContext) + contextMap[req] = ctx + } + + ctx.wasSent = true +} + +func ctxWasSent(req *http.Request) bool { + cmMutex.RLock() + defer cmMutex.RUnlock() + + ctx, ok := contextMap[req] + + if !ok { + return false + } + + return ctx.wasSent +} + func ctxSetReason(req *http.Request, reason error) *http.Request { cmMutex.Lock() defer cmMutex.Unlock() diff --git a/handler.go b/handler.go index f0a09a2..e6c7930 100644 --- a/handler.go +++ b/handler.go @@ -195,6 +195,16 @@ func (h *CSRFHandler) handleFailure(w http.ResponseWriter, r *http.Request) { // Generates a new token, sets it on the given request and returns it func (h *CSRFHandler) RegenerateToken(w http.ResponseWriter, r *http.Request) string { + if ctxWasSent(r) { + // The CSRF Cookie was set already by an earlier call to `RegenerateToken` + // in the same request context. It therefore does not make sense to regenerate + // it again as it will lead to two or more `Set-Cookie` instructions which will in turn + // cause CSRF to fail depending on the resulting order of the `Set-Cookie` instructions. + // + // No warning is necessary as the only caller to `setTokenCookie` is `RegenerateToken`. + return Token(r) + } + token := generateToken() h.setTokenCookie(w, r, token) @@ -210,6 +220,7 @@ func (h *CSRFHandler) setTokenCookie(w http.ResponseWriter, r *http.Request, tok cookie.Value = b64encode(token) http.SetCookie(w, &cookie) + ctxSetSent(r) } diff --git a/handler_go17_test.go b/handler_go17_test.go index 09ae72b..c2b7d76 100644 --- a/handler_go17_test.go +++ b/handler_go17_test.go @@ -28,3 +28,20 @@ func TestContextIsAccessibleWithContext(t *testing.T) { hand.ServeHTTP(writer, req) } + +func TestNoDoubleCookie(t *testing.T) { + var n *CSRFHandler + n = New(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + n.RegenerateToken(w, r) + })) + + r := httptest.NewRequest("GET", "http://dummy.us", nil) + w := httptest.NewRecorder() + + n.ServeHTTP(w, r) + + count := len(w.Result().Cookies()) + if count > 1 { + t.Errorf("Expected one CSRF cookie, got %d", count) + } +} diff --git a/handler_legacy_test.go b/handler_legacy_test.go index 56909f9..963551c 100644 --- a/handler_legacy_test.go +++ b/handler_legacy_test.go @@ -5,6 +5,7 @@ package nosurf import ( "net/http" "net/http/httptest" + "strings" "testing" ) @@ -20,3 +21,20 @@ func TestClearsContextAfterTheRequest(t *testing.T) { t.Errorf("Instead, the context entry remains: %v", contextMap[req]) } } + +func TestNoDoubleCookie(t *testing.T) { + var n *CSRFHandler + n = New(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + n.RegenerateToken(w, r) + })) + + r := httptest.NewRequest("GET", "http://dummy.us", nil) + w := httptest.NewRecorder() + + n.ServeHTTP(w, r) + + count := strings.Count(w.HeaderMap.Get("Set-Cookie"), "csrf_token") + if count > 1 { + t.Errorf("Expected one CSRF cookie, got %d", count) + } +} diff --git a/handler_test.go b/handler_test.go index 087353e..128a09b 100644 --- a/handler_test.go +++ b/handler_test.go @@ -9,22 +9,6 @@ import ( "testing" ) -func TestNoDoubleCookie(t *testing.T) { - var n *CSRFHandler - n = New(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - n.RegenerateToken(w, r) - })) - - r := httptest.NewRequest("GET", "http://dummy.us", nil) - w := httptest.NewRecorder() - - n.ServeHTTP(w, r) - - if len(w.Result().Cookies()) > 1 { - t.Errorf("Expected one CSRF cookie, got %d", len(w.Result().Cookies())) - } -} - func TestDefaultFailureHandler(t *testing.T) { writer := httptest.NewRecorder() req := dummyGet()