Skip to content

Commit

Permalink
feat: Add option for propagating id_token to upstream app
Browse files Browse the repository at this point in the history
Fixes #315

Co-authored-by: tronghn <trong.huu.nguyen@nav.no>
  • Loading branch information
sindrerh2 and tronghn committed Jan 20, 2025
1 parent bc30791 commit 2feb6a3
Show file tree
Hide file tree
Showing 7 changed files with 132 additions and 66 deletions.
93 changes: 47 additions & 46 deletions docs/configuration.md

Large diffs are not rendered by default.

37 changes: 20 additions & 17 deletions pkg/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,16 +24,17 @@ type Config struct {
ShutdownWaitBeforePeriod time.Duration `json:"shutdown-wait-before-period"`
Version string `json:"version"`

AutoLogin bool `json:"auto-login"`
AutoLoginIgnorePaths []string `json:"auto-login-ignore-paths"`
Cookie Cookie `json:"cookie"`
EncryptionKey string `json:"encryption-key"`
Ingresses []string `json:"ingress"`
LegacyCookie bool `json:"legacy-cookie"`
UpstreamAccessLogs bool `json:"upstream-access-logs"`
UpstreamHost string `json:"upstream-host"`
UpstreamIP string `json:"upstream-ip"`
UpstreamPort int `json:"upstream-port"`
AutoLogin bool `json:"auto-login"`
AutoLoginIgnorePaths []string `json:"auto-login-ignore-paths"`
Cookie Cookie `json:"cookie"`
EncryptionKey string `json:"encryption-key"`
Ingresses []string `json:"ingress"`
LegacyCookie bool `json:"legacy-cookie"`
UpstreamAccessLogs bool `json:"upstream-access-logs"`
UpstreamHost string `json:"upstream-host"`
UpstreamIP string `json:"upstream-ip"`
UpstreamPort int `json:"upstream-port"`
UpstreamIncludeIdToken bool `json:"upstream-include-id-token"`

OpenTelemetry OpenTelemetry `json:"otel"`
OpenID OpenID `json:"openid"`
Expand All @@ -50,13 +51,14 @@ const (
ShutdownGracefulPeriod = "shutdown-graceful-period"
ShutdownWaitBeforePeriod = "shutdown-wait-before-period"

AutoLogin = "auto-login"
AutoLoginIgnorePaths = "auto-login-ignore-paths"
Ingress = "ingress"
UpstreamAccessLogs = "upstream-access-logs"
UpstreamHost = "upstream-host"
UpstreamIP = "upstream-ip"
UpstreamPort = "upstream-port"
AutoLogin = "auto-login"
AutoLoginIgnorePaths = "auto-login-ignore-paths"
Ingress = "ingress"
UpstreamAccessLogs = "upstream-access-logs"
UpstreamHost = "upstream-host"
UpstreamIP = "upstream-ip"
UpstreamPort = "upstream-port"
UpstreamIncludeIdToken = "upstream-include-id-token"
)

var logger = log.WithField("logger", "wonderwall.config")
Expand All @@ -78,6 +80,7 @@ func Initialize() (*Config, error) {
flag.String(UpstreamHost, "127.0.0.1:8080", "Address of upstream host.")
flag.String(UpstreamIP, "", "IP of upstream host. Overrides 'upstream-host' if set.")
flag.Int(UpstreamPort, 0, "Port of upstream host. Overrides 'upstream-host' if set.")
flag.Bool(UpstreamIncludeIdToken, false, "Include ID token in upstream requests in 'X-Wonderwall-Id-Token' header.")

cookieFlags()
openidFlags()
Expand Down
2 changes: 1 addition & 1 deletion pkg/handler/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ func NewStandalone(
Ingresses: ingresses,
Redirect: url.NewStandaloneRedirect(),
SessionManager: sessionManager,
UpstreamProxy: NewUpstreamProxy(upstream, cfg.UpstreamAccessLogs),
UpstreamProxy: NewUpstreamProxy(upstream, cfg.UpstreamAccessLogs, cfg.UpstreamIncludeIdToken),
}, nil
}

Expand Down
2 changes: 1 addition & 1 deletion pkg/handler/handler_sso_proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ func NewSSOProxy(cfg *config.Config, crypter crypto.Crypter) (*SSOProxy, error)
SSOServerURL: serverURL,
SSOServerReverseProxy: NewReverseProxy(serverURL, false),
SessionReader: sessionReader,
UpstreamProxy: NewUpstreamProxy(upstream, cfg.UpstreamAccessLogs),
UpstreamProxy: NewUpstreamProxy(upstream, cfg.UpstreamAccessLogs, cfg.UpstreamIncludeIdToken),
}, nil
}

Expand Down
13 changes: 12 additions & 1 deletion pkg/handler/reverseproxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,13 @@ type ReverseProxySource interface {
type ReverseProxy struct {
*httputil.ReverseProxy
EnableAccessLogs bool
IncludeIdToken bool
}

func NewUpstreamProxy(upstream *urllib.URL, enableAccessLogs bool) *ReverseProxy {
func NewUpstreamProxy(upstream *urllib.URL, enableAccessLogs bool, includeIdToken bool) *ReverseProxy {
rp := NewReverseProxy(upstream, true)
rp.EnableAccessLogs = enableAccessLogs
rp.IncludeIdToken = includeIdToken
return rp
}

Expand Down Expand Up @@ -68,6 +70,11 @@ func NewReverseProxy(upstream *urllib.URL, preserveInboundHostHeader bool) *Reve
if ok {
r.Out.Header.Set("authorization", "Bearer "+accessToken)
}

idToken, ok := mw.IdTokenFrom(r.In.Context())
if ok {
r.Out.Header.Set("X-Wonderwall-Id-Token", idToken)
}
},
Transport: server.DefaultTransport(),
}
Expand Down Expand Up @@ -117,6 +124,10 @@ func (rp *ReverseProxy) Handler(src ReverseProxySource, w http.ResponseWriter, r

if isAuthenticated {
ctx = mw.WithAccessToken(ctx, accessToken)
if rp.IncludeIdToken {
idToken := sess.IDToken()
ctx = mw.WithIdToken(ctx, idToken)
}

if rp.EnableAccessLogs && isRelevantAccessLog(r) {
logger.Info("default: authenticated request")
Expand Down
41 changes: 41 additions & 0 deletions pkg/handler/reverseproxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -437,4 +437,45 @@ func TestReverseProxy(t *testing.T) {
}...)
assertUpstreamOKResponse(t, resp)
})

t.Run("request should not include idToken by default", func(t *testing.T) {
cfg := mock.Config()
cfg.UpstreamHost = up.URL.Host
idp := mock.NewIdentityProvider(cfg)
defer idp.Close()

up.SetIdentityProvider(idp)
rpClient := idp.RelyingPartyClient()

// acquire session
login(t, rpClient, idp)

up.requestCallback = func(r *http.Request) {
assert.Empty(t, r.Header.Get("x-wonderwall-id-token"))
}

resp := get(t, rpClient, idp.RelyingPartyServer.URL)
assertUpstreamOKResponse(t, resp)
})

t.Run("request should include idToken", func(t *testing.T) {
cfg := mock.Config()
cfg.UpstreamHost = up.URL.Host
cfg.UpstreamIncludeIdToken = true
idp := mock.NewIdentityProvider(cfg)
defer idp.Close()

up.SetIdentityProvider(idp)
rpClient := idp.RelyingPartyClient()

// acquire session
login(t, rpClient, idp)

up.requestCallback = func(r *http.Request) {
assert.NotEmpty(t, r.Header.Get("x-wonderwall-id-token"))
}

resp := get(t, rpClient, idp.RelyingPartyServer.URL)
assertUpstreamOKResponse(t, resp)
})
}
10 changes: 10 additions & 0 deletions pkg/middleware/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ type contextKey string

const (
ctxAccessToken = contextKey("AccessToken")
ctxIdToken = contextKey("IdToken")
ctxIngress = contextKey("Ingress")
ctxPath = contextKey("Path")
)
Expand All @@ -24,6 +25,15 @@ func WithAccessToken(ctx context.Context, accessToken string) context.Context {
return context.WithValue(ctx, ctxAccessToken, accessToken)
}

func IdTokenFrom(ctx context.Context) (string, bool) {
idToken, ok := ctx.Value(ctxIdToken).(string)
return idToken, ok
}

func WithIdToken(ctx context.Context, idToken string) context.Context {
return context.WithValue(ctx, ctxIdToken, idToken)
}

func IngressFrom(ctx context.Context) (ingress.Ingress, bool) {
i, ok := ctx.Value(ctxIngress).(ingress.Ingress)
return i, ok
Expand Down

0 comments on commit 2feb6a3

Please sign in to comment.