Skip to content

Commit

Permalink
Add changes for authenticating requests with bearer token
Browse files Browse the repository at this point in the history
  • Loading branch information
cant-code committed Apr 13, 2024
1 parent af386d8 commit c4c853f
Show file tree
Hide file tree
Showing 7 changed files with 170 additions and 4 deletions.
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ require (
github.com/aws/aws-sdk-go-v2/config v1.26.3
github.com/aws/aws-sdk-go-v2/service/s3 v1.48.0
github.com/go-stomp/stomp/v3 v3.0.5
github.com/golang-jwt/jwt/v5 v5.2.1
github.com/golang-migrate/migrate/v4 v4.17.0
github.com/spf13/viper v1.18.2
)
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ github.com/go-stomp/stomp/v3 v3.0.5 h1:yOORvXLqSu0qF4loJjfWrcVE1o0+9cFudclcP0an3
github.com/go-stomp/stomp/v3 v3.0.5/go.mod h1:ztzZej6T2W4Y6FlD+Tb5n7HQP3/O5UNQiuC169pIp10=
github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q=
github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q=
github.com/golang-jwt/jwt/v5 v5.2.1 h1:OuVbFODueb089Lh128TAcimifWaLhJwVflnrgM17wHk=
github.com/golang-jwt/jwt/v5 v5.2.1/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
github.com/golang-migrate/migrate/v4 v4.17.0 h1:rd40H3QXU0AA4IoLllFcEAEo9dYKRHYND2gB4p7xcaU=
github.com/golang-migrate/migrate/v4 v4.17.0/go.mod h1:+Cp2mtLP4/aXDTKb9wmXYitdrNx2HGs45rbWAo6OsKM=
github.com/golang/mock v1.6.0 h1:ErTB+efbowRARo13NNdxyJji2egdxLGQhRaY+DUumQc=
Expand Down
15 changes: 15 additions & 0 deletions internal/auth/authMiddleware.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package auth

import (
"log"
"net/http"
)

func HandleJwtAuthMiddleware() func(http.Handler) http.Handler {
set, err := getJWKSet("http://localhost:8900/realms/yt-clone/protocol/openid-connect/certs")
if err != nil {
log.Printf("Error fetching jwk-sets: %v\n", err)
}

return jwtMiddleware(set)
}
90 changes: 90 additions & 0 deletions internal/auth/jwkFetcher.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
package auth

import (
"crypto/rsa"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"log"
"math/big"
"net/http"
"strings"
)

func getJWKSet(url string) (map[string]*rsa.PublicKey, error) {
// Make the GET request
response, err := http.Get(url)
if err != nil {
return nil, fmt.Errorf("error making GET request: %v", err)
}
defer func(Body io.ReadCloser) {
err := Body.Close()
if err != nil {
log.Println("Error closing body:", err)
}
}(response.Body)

// Decode the JSON response
var jwkSet struct {
Keys []struct {
Kid string `json:"kid"`
Alg string `json:"alg"`
N string `json:"n"`
E string `json:"e"`
X5C []string `json:"x5c"`
} `json:"keys"`
}
decoder := json.NewDecoder(response.Body)
if err := decoder.Decode(&jwkSet); err != nil {
return nil, fmt.Errorf("error decoding JSON: %v", err)
}

// Create a map to store RSA public keys
jwkMap := make(map[string]*rsa.PublicKey)

// Iterate through each key in the JWK set
for _, key := range jwkSet.Keys {
// Decode base64url-encoded modulus (N) and exponent (E)
modulus, err := decodeBase64URL(key.N)
if err != nil {
return nil, fmt.Errorf("error decoding modulus: %v", err)
}

exponent, err := decodeBase64URL(key.E)
if err != nil {
return nil, fmt.Errorf("error decoding exponent: %v", err)
}

// Create RSA public key
pubKey := &rsa.PublicKey{
N: modulus,
E: int(exponent.Int64()),
}

// Store the public key in the map using the key ID (Kid)
jwkMap[key.Alg] = pubKey
}

return jwkMap, nil
}

func decodeBase64URL(input string) (*big.Int, error) {
base64Str := strings.ReplaceAll(input, "-", "+")
base64Str = strings.ReplaceAll(base64Str, "_", "/")

switch len(base64Str) % 4 {
case 2:
base64Str += "=="
case 3:
base64Str += "="
}

data, err := base64.StdEncoding.DecodeString(base64Str)
if err != nil {
return nil, err
}

result := new(big.Int).SetBytes(data)
return result, nil
}
55 changes: 55 additions & 0 deletions internal/auth/tokenHandler.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
package auth

import (
"crypto/rsa"
"fmt"
"github.com/golang-jwt/jwt/v5"
"log"
"net/http"
"strings"
)

func jwtMiddleware(jwkSet map[string]*rsa.PublicKey) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
authHeader := r.Header.Get("Authorization")
if authHeader == "" || !strings.HasPrefix(authHeader, "Bearer ") {
http.Error(w, "Authorization header is required", http.StatusUnauthorized)
return
}

tokenString := strings.TrimPrefix(authHeader, "Bearer ")
token, err := jwt.Parse(tokenString, parseToken(jwkSet))

if err != nil || !token.Valid {
log.Println("error validating token:", err)
http.Error(w, "", http.StatusUnauthorized)
return
}

issuer, err := token.Claims.GetIssuer()
if err != nil || issuer != "http://localhost:8900/realms/yt-clone" {
log.Println("error validating issuer:", err)
http.Error(w, "", http.StatusUnauthorized)
return
}

next.ServeHTTP(w, r)
})
}
}

func parseToken(jwkSet map[string]*rsa.PublicKey) func(token *jwt.Token) (interface{}, error) {
return func(token *jwt.Token) (interface{}, error) {
if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok {
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
}

alg := token.Method.Alg()
publicKey, ok := jwkSet[alg]
if !ok {
return nil, fmt.Errorf("no key found for signing method: %v", alg)
}
return publicKey, nil
}
}
7 changes: 5 additions & 2 deletions internal/handlers/apiHandler.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,19 @@ package handlers
import (
"database/sql"
"net/http"
"yt-clone-video-processing/internal/auth"
)

type Dependencies struct {
DBConn *sql.DB
}

func (dependencies *Dependencies) ApiHandler() *http.ServeMux {
func (apiConfig *Dependencies) ApiHandler() *http.ServeMux {
mux := http.NewServeMux()

mux.HandleFunc("GET /videos/errors/{id}", dependencies.errorHandler)
handler := auth.HandleJwtAuthMiddleware()

mux.Handle("GET /videos/errors/{id}", handler(http.HandlerFunc(apiConfig.errorHandler)))

return mux
}
4 changes: 2 additions & 2 deletions internal/handlers/errorHandlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ const (
applicationJson = "application/json"
)

func (dependencies *Dependencies) errorHandler(w http.ResponseWriter, r *http.Request) {
exec, err := dependencies.DBConn.Query(selectErrorsUsingVid, r.PathValue(id))
func (apiConfig *Dependencies) errorHandler(w http.ResponseWriter, r *http.Request) {
exec, err := apiConfig.DBConn.Query(selectErrorsUsingVid, r.PathValue(id))
if err != nil {
w.WriteHeader(http.StatusBadRequest)
_, err := fmt.Fprintf(w, err.Error())
Expand Down

0 comments on commit c4c853f

Please sign in to comment.