From d2a8dc8add4b3246b4c8093d5e146eb79b3c86f7 Mon Sep 17 00:00:00 2001
From: vivekweb2013 <vivekweb2013@gmail.com>
Date: Thu, 26 May 2022 15:38:28 +0530
Subject: [PATCH] Enable cors to support external client

---
 config.yaml                                |  3 +-
 go.mod                                     |  1 +
 go.sum                                     |  8 +++
 internal/config/config.go                  |  1 +
 internal/httpservice/login_handler.go      | 69 ++++++++++++++--------
 internal/httpservice/login_handler_test.go | 53 +++++++++++++++--
 internal/httpservice/router.go             | 34 ++++++++++-
 7 files changed, 138 insertions(+), 31 deletions(-)

diff --git a/config.yaml b/config.yaml
index ff7c7a4..a6d3f45 100644
--- a/config.yaml
+++ b/config.yaml
@@ -1,5 +1,6 @@
 app:
   secretKey: secret
+  clientURL: "http://localhost:3000"
 
 database:
   host: localhost
@@ -19,4 +20,4 @@ oAuth2:
   github:
     clientID: "<GITHUB_CLIENT_ID>"
     clientSecret: "<GITHUB_CLIENT_SECRET>"
-    redirectURL: "http://localhost:3000/api/v1/oauth2/github/callback"
+    redirectURL: "http://localhost:8080/api/v1/oauth2/github/callback"
diff --git a/go.mod b/go.mod
index 0a6a0c9..5cace33 100644
--- a/go.mod
+++ b/go.mod
@@ -5,6 +5,7 @@ go 1.18
 require github.com/spf13/cobra v1.4.0
 
 require (
+	github.com/gin-contrib/cors v1.3.1
 	github.com/gin-gonic/gin v1.7.7
 	github.com/go-ozzo/ozzo-validation v3.6.0+incompatible
 	github.com/google/go-github/v43 v43.0.0
diff --git a/go.sum b/go.sum
index 040239d..b638ce6 100644
--- a/go.sum
+++ b/go.sum
@@ -354,8 +354,11 @@ github.com/gabriel-vasile/mimetype v1.4.0/go.mod h1:fA8fi6KUiG7MgQQ+mEWotXoEOvmx
 github.com/garyburd/redigo v0.0.0-20150301180006-535138d7bcd7/go.mod h1:NR3MbYisc3/PwhQ00EMzDiPmrwpPxAn5GI05/YaO1SY=
 github.com/ghodss/yaml v0.0.0-20150909031657-73d445a93680/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04=
 github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04=
+github.com/gin-contrib/cors v1.3.1 h1:doAsuITavI4IOcd0Y19U4B+O0dNWihRyX//nn4sEmgA=
+github.com/gin-contrib/cors v1.3.1/go.mod h1:jjEJ4268OPZUcU7k9Pm653S7lXUGcqMADzFA61xsmDk=
 github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE=
 github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI=
+github.com/gin-gonic/gin v1.5.0/go.mod h1:Nd6IXA8m5kNZdNEHMBd93KT+mdY3+bewLgRvmCsR2Do=
 github.com/gin-gonic/gin v1.7.7 h1:3DoBmSbJbZAWqXJC3SLjAPfutPJJRN1U5pALB7EeTTs=
 github.com/gin-gonic/gin v1.7.7/go.mod h1:axIBovoeJpVj8S3BwE0uPMTeReE4+AfFtqpqaZ1qq1U=
 github.com/go-fonts/dejavu v0.1.0/go.mod h1:4Wt4I4OU2Nq9asgDCteaAaWZOV24E+0/Pwo0gppep4g=
@@ -386,8 +389,10 @@ github.com/go-ozzo/ozzo-validation v3.6.0+incompatible h1:msy24VGS42fKO9K1vLz82/
 github.com/go-ozzo/ozzo-validation v3.6.0+incompatible/go.mod h1:gsEKFIVnabGBt6mXmxK0MoFy+cZoTJY6mu5Ll3LVLBU=
 github.com/go-playground/assert/v2 v2.0.1 h1:MsBgLAaY856+nPRTKrp3/OZK38U/wa0CcBYNjji3q3A=
 github.com/go-playground/assert/v2 v2.0.1/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
+github.com/go-playground/locales v0.12.1/go.mod h1:IUMDtCfWo/w/mtMfIE/IG2K+Ey3ygWanZIBtBW0W2TM=
 github.com/go-playground/locales v0.13.0 h1:HyWk6mgj5qFqCT5fjGBuRArbVDfE4hi8+e8ceBS/t7Q=
 github.com/go-playground/locales v0.13.0/go.mod h1:taPMhCMXrRLJO55olJkUXHZBHCxTMfnGwq/HNwmWNS8=
+github.com/go-playground/universal-translator v0.16.0/go.mod h1:1AnU7NaIRDWWzGEKwgtJRd2xk99HeFyHw3yid4rvQIY=
 github.com/go-playground/universal-translator v0.17.0 h1:icxd5fm+REJzpZx7ZfpaD876Lmtgy7VtROAbHHXk8no=
 github.com/go-playground/universal-translator v0.17.0/go.mod h1:UkSxE5sNxxRwHyU+Scu5vgOQjsIJAF8j9muTVoKLVtA=
 github.com/go-playground/validator/v10 v10.4.1 h1:pH2c5ADXtd66mxoE0Zm9SUhxE20r7aM3F26W0hOn+GE=
@@ -710,6 +715,7 @@ github.com/kr/pty v1.1.8/go.mod h1:O1sed60cT9XZ5uDucP5qwvh+TE3NnUj51EiZO/lmSfw=
 github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=
 github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
 github.com/ktrysmt/go-bitbucket v0.6.4/go.mod h1:9u0v3hsd2rqCHRIpbir1oP7F58uo5dq19sBYvuMoyQ4=
+github.com/leodido/go-urn v1.1.0/go.mod h1:+cyI34gQWZcE1eQU7NVgKkkzdXDQHr1dBMtdAPozLkw=
 github.com/leodido/go-urn v1.2.0 h1:hpXL4XnriNwQ/ABnpepYM/1vCLWNDfUNts8dX3xTG6Y=
 github.com/leodido/go-urn v1.2.0/go.mod h1:+8+nEpDfqqsY+g338gtMEUOtuK+4dEMhiQEgxpxOKII=
 github.com/lib/pq v1.0.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo=
@@ -1547,6 +1553,8 @@ gopkg.in/cheggaaa/pb.v1 v1.0.25/go.mod h1:V/YB90LKu/1FcN3WVnfiiE5oMCibMjukxqG/qS
 gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI=
 gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys=
 gopkg.in/gemnasium/logrus-airbrake-hook.v2 v2.1.2/go.mod h1:Xk6kEKp8OKb+X14hQBKWaSkCsqBpgog8nAV2xsGOxlo=
+gopkg.in/go-playground/assert.v1 v1.2.1/go.mod h1:9RXL0bg/zibRAgZUYszZSwO/z8Y/a8bDuhia5mkpMnE=
+gopkg.in/go-playground/validator.v9 v9.29.1/go.mod h1:+c9/zcJMFNgbLvly1L1V+PpxWdVbfP1avr/N00E2vyQ=
 gopkg.in/inconshreveable/log15.v2 v2.0.0-20180818164646-67afb5ed74ec/go.mod h1:aPpfJ7XW+gOuirDoZ8gHhLh3kZ1B08FtV2bbmy7Jv3s=
 gopkg.in/inf.v0 v0.9.1/go.mod h1:cWUDdTG/fYaXco+Dcufb5Vnc6Gp2YChqWtbxRZE0mXw=
 gopkg.in/ini.v1 v1.51.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k=
diff --git a/internal/config/config.go b/internal/config/config.go
index 586ec90..bcbd3d1 100644
--- a/internal/config/config.go
+++ b/internal/config/config.go
@@ -3,6 +3,7 @@ package config
 // App represents configuration properties specific to the application.
 type App struct {
 	SecretKey string
+	ClientURL string
 }
 
 // Database represents configuration properties required to connect to a database.
diff --git a/internal/httpservice/login_handler.go b/internal/httpservice/login_handler.go
index 8580ea9..1cb5c94 100644
--- a/internal/httpservice/login_handler.go
+++ b/internal/httpservice/login_handler.go
@@ -15,17 +15,19 @@ import (
 
 // LoginHandler represents http handler for serving user login actions.
 type LoginHandler struct {
-	authService  auth.Service
-	githubServie github.Service
-	userService  user.Service
+	authService   auth.Service
+	githubService github.Service
+	userService   user.Service
+	clientURL     string
 }
 
 // NewLoginHandler creates and returns a new login handler.
-func NewLoginHandler(authService auth.Service, githubServie github.Service, userService user.Service) *LoginHandler {
+func NewLoginHandler(authService auth.Service, githubService github.Service, userService user.Service, clientURL string) *LoginHandler {
 	return &LoginHandler{
-		authService:  authService,
-		githubServie: githubServie,
-		userService:  userService,
+		authService:   authService,
+		githubService: githubService,
+		userService:   userService,
+		clientURL:     clientURL,
 	}
 }
 
@@ -34,7 +36,7 @@ func (l *LoginHandler) GithubLogin(c *gin.Context) {
 	state := uuid.NewString()
 	c.SetCookie("state", state, 600, "/", "", true, true)
 
-	url := l.githubServie.GetAuthCodeURL(state)
+	url := l.githubService.GetAuthCodeURL(state)
 
 	// trigger authorization code grant flow
 	c.Redirect(http.StatusTemporaryRedirect, url)
@@ -42,31 +44,30 @@ func (l *LoginHandler) GithubLogin(c *gin.Context) {
 
 // GithubOAuth2Callback processes github oauth2 callback.
 // It validates the state, fetch token and user from github, stores the user to db, generates app token.
-// A response containing app token is sent to the client.
+// The app token will be sent as token cookie with a redirect to client url.
 func (l *LoginHandler) GithubOAuth2Callback(c *gin.Context) {
 	logrus.Info("github oauth2 callback started")
 	state, _ := c.Cookie("state")
 	stateFromCallback := c.Query("state")
 	code := c.Query("code")
-	failRedirectPath := "/?login_error=true"
 
 	if stateFromCallback != state {
 		logrus.Error("invalid oauth state")
-		c.Redirect(http.StatusTemporaryRedirect, failRedirectPath)
+		c.Redirect(http.StatusTemporaryRedirect, l.clientURL+"/login?success=false&error=invalid-state")
 		return
 	}
 
-	githubToken, err := l.githubServie.GetToken(c, code)
+	githubToken, err := l.githubService.GetToken(c, code)
 	if err != nil {
 		logrus.Errorf("auth code exchange for token failed: %s", err.Error())
-		c.Redirect(http.StatusTemporaryRedirect, failRedirectPath)
+		c.Redirect(http.StatusTemporaryRedirect, l.clientURL+"/login?success=false&error=auth-code-exchange-failure")
 		return
 	}
 
-	githubUser, err := l.githubServie.GetUser(c, githubToken)
+	githubUser, err := l.githubService.GetUser(c, githubToken)
 	if err != nil {
 		logrus.Errorf("retrieving user from github failed: %s", err.Error())
-		c.Redirect(http.StatusTemporaryRedirect, failRedirectPath)
+		c.Redirect(http.StatusTemporaryRedirect, l.clientURL+"/login?success=false&error=user-retrieval-failure")
 		return
 	}
 
@@ -74,13 +75,13 @@ func (l *LoginHandler) GithubOAuth2Callback(c *gin.Context) {
 	dbUser, err := l.userService.GetByEmail(*githubUser.Email)
 	if err != nil {
 		logrus.Errorf("retrieving user from db using email failed: %s", err.Error())
-		c.Redirect(http.StatusTemporaryRedirect, failRedirectPath)
+		c.Redirect(http.StatusTemporaryRedirect, l.clientURL+"/login?success=false&error=internal-error")
 		return
 	}
 	githubTokenJSON, err := json.Marshal(githubToken)
 	if err != nil {
 		logrus.Errorf("converting github token to json failed: %s", err.Error())
-		c.Redirect(http.StatusTemporaryRedirect, failRedirectPath)
+		c.Redirect(http.StatusTemporaryRedirect, l.clientURL+"/login?success=false&error=internal-error")
 		return
 	}
 	mapUserAttributes(&dbUser, string(githubTokenJSON), githubUser)
@@ -89,26 +90,44 @@ func (l *LoginHandler) GithubOAuth2Callback(c *gin.Context) {
 	userID, err := l.userService.Save(dbUser)
 	if err != nil {
 		logrus.Errorf("saving user to db failed: %s", err.Error())
-		c.Redirect(http.StatusTemporaryRedirect, failRedirectPath)
+		c.Redirect(http.StatusTemporaryRedirect, l.clientURL+"/login?success=false&error=internal-error")
 		return
 	}
 
 	appToken, err := l.authService.GenerateToken(userID)
 	if err != nil {
 		logrus.Errorf("token generation failed: %s", err.Error())
-		c.Redirect(http.StatusTemporaryRedirect, failRedirectPath)
+		c.Redirect(http.StatusTemporaryRedirect, l.clientURL+"/login?success=false&error=internal-error")
 		return
 	}
 
-	// for security reasons, avoid using cookies to send the token to client
-	// instead use html with a script that stores the token to localstorage and redirects to homepage
-	// this is the only workaround to send token to client without using cookies
-	// since the client(frontend) can only read headers/response with ajax request, and this call is not ajax
-	c.Header("Content-Type", "text/html")
-	c.String(200, `<!DOCTYPE html><html><body><script>(function(){localStorage.setItem("token","%s");location.replace("/");}());</script></body></html>`, appToken)
+	// set the token cookie
+	c.SetSameSite(http.SameSiteNoneMode)
+	c.SetCookie("token", appToken, 60, "/", c.Request.URL.Hostname(), true, true)
+
+	// redirect to client
+	c.Redirect(http.StatusFound, l.clientURL+"/login?success=true")
+
 	logrus.Info("github oauth2 callback finished")
 }
 
+// TokenPayload reads the token from request cookie and sends it as response payload.
+// The idea is to use header based jwt auth (instead of cookie based auth) to avoid any security issues.
+func (l *LoginHandler) TokenPayload(c *gin.Context) {
+	token, err := c.Cookie("token")
+	if err != nil {
+		c.AbortWithStatus(http.StatusUnauthorized)
+		return
+	}
+
+	// delete the token cookie
+	c.SetSameSite(http.SameSiteNoneMode)
+	c.SetCookie("token", "", 0, "/", c.Request.URL.Hostname(), true, true)
+
+	// send token as response payload
+	c.String(http.StatusOK, token)
+}
+
 func mapUserAttributes(dbUser *user.User, ghToken string, githubUser gh.User) {
 	dbUser.GithubToken = ghToken
 	dbUser.Email = githubUser.GetEmail()
diff --git a/internal/httpservice/login_handler_test.go b/internal/httpservice/login_handler_test.go
index 67563bc..0e6603a 100644
--- a/internal/httpservice/login_handler_test.go
+++ b/internal/httpservice/login_handler_test.go
@@ -27,6 +27,7 @@ const (
 	name            = "John Doe"
 	location        = "New York"
 	avatarURL       = "http://example.com/avatar"
+	clientURL       = "http://example.com/ui"
 	oauth2TokenJSON = `{"access_token":"gho_token","token_type":"bearer","expiry":"0001-01-01T00:00:00Z"}`
 )
 
@@ -40,7 +41,7 @@ func TestGithubLogin(t *testing.T) {
 
 		gin.SetMode(gin.TestMode)
 		router := gin.Default()
-		handler := NewLoginHandler(nil, githubService, nil)
+		handler := NewLoginHandler(nil, githubService, nil, "")
 		githubService.EXPECT().GetAuthCodeURL(gomock.Any()).Return("/")
 
 		router.GET("/api/v1/oauth2/login/github", handler.GithubLogin)
@@ -53,7 +54,7 @@ func TestGithubLogin(t *testing.T) {
 }
 
 func TestGithubOAuth2Callback(t *testing.T) {
-	t.Run("should save the user(with token) & return token response when callback invoked", func(t *testing.T) {
+	t.Run("should save the user(with token) & redirect with token cookie when callback invoked", func(t *testing.T) {
 		ctrl := gomock.NewController(t)
 		defer ctrl.Finish()
 		authService := auth.NewMockService(ctrl)
@@ -69,7 +70,7 @@ func TestGithubOAuth2Callback(t *testing.T) {
 
 		gin.SetMode(gin.TestMode)
 		router := gin.Default()
-		handler := NewLoginHandler(authService, githubService, userService)
+		handler := NewLoginHandler(authService, githubService, userService, clientURL)
 		githubService.EXPECT().GetToken(gomock.Any(), authCode).Return(oauthToken, nil)
 		githubService.EXPECT().GetUser(gomock.Any(), oauthToken).Return(githubUser, nil)
 		userService.EXPECT().GetByEmail(email).Return(dbUser, nil)
@@ -88,9 +89,53 @@ func TestGithubOAuth2Callback(t *testing.T) {
 		req, _ := http.NewRequest(http.MethodGet, fmt.Sprintf("/oauth2/github/callback?code=%s&state=%s", authCode, state), nil)
 		req.AddCookie(&cookie)
 
+		router.ServeHTTP(response, req)
+		assert.Equal(t, http.StatusFound, response.Code)
+		assert.Contains(t, response.Header().Get("Set-Cookie"), appToken)
+		assert.Equal(t, clientURL+"/login?success=true", response.Header().Get("Location"))
+	})
+}
+
+func TestTokenPayload(t *testing.T) {
+	t.Run("should return token in response payload when request contains token cookie", func(t *testing.T) {
+		ctrl := gomock.NewController(t)
+		defer ctrl.Finish()
+
+		gin.SetMode(gin.TestMode)
+		router := gin.Default()
+		handler := NewLoginHandler(nil, nil, nil, "")
+
+		router.GET("/auth/token", handler.TokenPayload)
+		response := httptest.NewRecorder()
+		cookie := http.Cookie{
+			Name:     "token",
+			Value:    "test-token",
+			Path:     "/",
+			Expires:  time.Now().Add(1 * time.Minute),
+			HttpOnly: true,
+		}
+		req, _ := http.NewRequest(http.MethodGet, "/auth/token", nil)
+		req.AddCookie(&cookie)
+
 		router.ServeHTTP(response, req)
 		assert.Equal(t, http.StatusOK, response.Code)
-		assert.Contains(t, response.Body.String(), appToken)
+		assert.Equal(t, "test-token", response.Body.String())
+	})
+
+	t.Run("should return unauthorized error response when request doesn't contains token cookie", func(t *testing.T) {
+		ctrl := gomock.NewController(t)
+		defer ctrl.Finish()
+
+		gin.SetMode(gin.TestMode)
+		router := gin.Default()
+		handler := NewLoginHandler(nil, nil, nil, "")
+
+		router.GET("/auth/token", handler.TokenPayload)
+		response := httptest.NewRecorder()
+		req, _ := http.NewRequest(http.MethodGet, "/auth/token", nil)
+
+		router.ServeHTTP(response, req)
+		assert.Equal(t, http.StatusUnauthorized, response.Code)
 	})
 }
 
diff --git a/internal/httpservice/router.go b/internal/httpservice/router.go
index 9afa2a2..4b37f8a 100644
--- a/internal/httpservice/router.go
+++ b/internal/httpservice/router.go
@@ -1,12 +1,16 @@
 package httpservice
 
 import (
+	"fmt"
 	"net"
 	"net/http"
+	"net/url"
 	"time"
 
+	"github.com/gin-contrib/cors"
 	"github.com/gin-gonic/gin"
 	"github.com/batnoter/batnoter-api/internal/applicationconfig"
+	"github.com/sirupsen/logrus"
 )
 
 // Run starts the http server.
@@ -19,8 +23,12 @@ func Run(applicationconfig *applicationconfig.ApplicationConfig) error {
 	router := gin.Default()
 	router.UseRawPath = true
 
+	clientBaseURL := baseURL(applicationconfig.Config.App.ClientURL)
+	router.Use(cors.New(corsConfig(clientBaseURL)))
+	logrus.Infof("allowing cors for %s", clientBaseURL)
+
 	noteHandler := NewNoteHandler(applicationconfig.GithubService, applicationconfig.UserService)
-	loginHandler := NewLoginHandler(applicationconfig.AuthService, applicationconfig.GithubService, applicationconfig.UserService)
+	loginHandler := NewLoginHandler(applicationconfig.AuthService, applicationconfig.GithubService, applicationconfig.UserService, applicationconfig.Config.App.ClientURL)
 	userHandler := NewUserHandler(applicationconfig.UserService)
 	preferenceHandler := NewPreferenceHandler(applicationconfig.PreferenceService, applicationconfig.GithubService, applicationconfig.UserService)
 	authMiddleware := NewMiddleware(applicationconfig.AuthService)
@@ -38,6 +46,7 @@ func Run(applicationconfig *applicationconfig.ApplicationConfig) error {
 	v1.POST("/notes/:path", authMiddleware.AuthorizeToken(), noteHandler.SaveNote)     // create/update single note
 	v1.DELETE("/notes/:path", authMiddleware.AuthorizeToken(), noteHandler.DeleteNote) // delete single note
 
+	v1.GET("/auth/token", loginHandler.TokenPayload)
 	v1.GET("/oauth2/login/github", loginHandler.GithubLogin)
 	v1.GET("/oauth2/github/callback", loginHandler.GithubOAuth2Callback)
 
@@ -51,3 +60,26 @@ func Run(applicationconfig *applicationconfig.ApplicationConfig) error {
 	}
 	return server.ListenAndServe()
 }
+
+func corsConfig(clientBaseURL string) cors.Config {
+	return cors.Config{
+		AllowMethods:     []string{"GET", "POST", "PUT", "PATCH", "DELETE", "HEAD"},
+		AllowHeaders:     []string{"Origin", "Authorization", "Content-Length", "Content-Type"},
+		AllowCredentials: true,
+		AllowOriginFunc: func(origin string) bool {
+			return origin == clientBaseURL
+		},
+		MaxAge: 12 * time.Hour,
+	}
+}
+
+func baseURL(clientURL string) string {
+	if clientURL == "" {
+		logrus.Fatal("client url is not configured")
+	}
+	u, err := url.Parse(clientURL)
+	if err != nil {
+		logrus.WithField("client-url", clientURL).Fatal("invalid client url")
+	}
+	return fmt.Sprintf(`%s://%s`, u.Scheme, u.Host)
+}