Skip to content

Commit

Permalink
check access token for group attribute as well; fix #43
Browse files Browse the repository at this point in the history
  • Loading branch information
hahahannes committed Oct 24, 2024
1 parent bedb8dc commit 12e057c
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 2 deletions.
5 changes: 5 additions & 0 deletions mlflow_oidc_auth/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,11 @@ class AppConfig:
OIDC_REDIRECT_URI = os.environ.get("OIDC_REDIRECT_URI", None)
OIDC_CLIENT_ID = os.environ.get("OIDC_CLIENT_ID", None)
OIDC_CLIENT_SECRET = os.environ.get("OIDC_CLIENT_SECRET", None)
OIDC_AUDIENCE = os.environ.get("OIDC_AUDIENCE", None)
OIDC_PUBLIC_KEYS_URL = os.environ.get("OIDC_PUBLIC_KEYS_URL", None)
OIDC_USERNAME_TOKEN_ATTRIBUTE = os.environ.get("OIDC_USERNAME_TOKEN_ATTRIBUTE", "email")
OIDC_SIGNING_ALG = os.environ.get("OIDC_SIGNING_ALG", None) # ES256, EdDSA, PS256, RS256, HS256
OIDC_HS256_SECRET = os.environ.get("OIDC_HS256_SECRET", None) # ES256, EdDSA, PS256, RS256, HS256

# https://flask-session.readthedocs.io/en/latest/config.html
SESSION_TYPE = os.environ.get("SESSION_TYPE", "filesystem")
Expand Down
70 changes: 68 additions & 2 deletions mlflow_oidc_auth/views.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
import os
import re
import requests
Expand Down Expand Up @@ -96,6 +97,8 @@

from mlflow.server import app

import jwt

# Create the OAuth2 client
auth_client = WebApplicationClient(AppConfig.get_property("OIDC_CLIENT_ID"))
store = SqlAlchemyStore()
Expand Down Expand Up @@ -186,6 +189,9 @@ def _get_permission_from_store_or_default(

def authenticate_request_basic_auth() -> Union[Authorization, Response]:
username = request.authorization.username
if username == "" or username is None:
app.logger.debug("Username is not set in basic auth")
return False
password = request.authorization.password
app.logger.debug("Authenticating user %s", username)
if store.authenticate_user(username.lower(), password):
Expand All @@ -195,6 +201,61 @@ def authenticate_request_basic_auth() -> Union[Authorization, Response]:
else:
app.logger.debug("User %s not authenticated", username)
return False


def _get_public_keys():
"""
Returns:
List of RSA public keys usable by PyJWT.
"""
r = requests.get(AppConfig.get_property("OIDC_PUBLIC_KEYS_URL"))
public_keys = []
jwk_set = r.json()
for key_dict in jwk_set["keys"]:
public_key = jwt.algorithms.RSAAlgorithm.from_jwk(json.dumps(key_dict))
public_keys.append(public_key)
return public_keys


def validate_token(token, key, sign_alg):
try:
token = jwt.decode(token, key=key, audience=AppConfig.get_property("OIDC_AUDIENCE"), algorithms=[sign_alg])
except jwt.exceptions.InvalidTokenError as e:
app.logger.debug(f"Token is not valid: {token}")
raise MlflowException(f"Token is not valid: {str(e)}")
username_token_attr = AppConfig.get_property("OIDC_USERNAME_TOKEN_ATTRIBUTE")
username = token[username_token_attr]
if username == "" or username is None:
app.logger.debug(f"username from token attribute {username_token_attr} is {username}")
raise MlflowException(f"Username is not set at attribute: {username_token_attr}")
_set_username(username)


def authenticate_token():
"""
Verify the token in the request.
"""
token = request.authorization.token
sign_alg = AppConfig.get_property("OIDC_SIGNING_ALG")
token_is_valid = False
if sign_alg == "HS256":
key = AppConfig.get_property("OIDC_HS256_SECRET")
try:
validate_token(token, key, sign_alg)
token_is_valid = True
except MlflowException as e:
return False

keys = _get_public_keys()
for key in keys:
try:
validate_token(token, key, sign_alg)
token_is_valid = True
break
except MlflowException as e:
return False

return token_is_valid


def _get_username():
Expand Down Expand Up @@ -617,8 +678,13 @@ def before_request_hook():
if _is_unprotected_route(request.path):
return
if request.authorization is not None:
if not authenticate_request_basic_auth():
return make_basic_auth_response()
if not authenticate_token():
app.logger.debug("No valid token authentication found")
# TODO maybe return 401 here instead of basic auth response

if not authenticate_request_basic_auth():
app.logger.debug("No valid basic authentication found")
return make_basic_auth_response()
else:
# authentication
if not _get_username():
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ dependencies = [
"Flask-Session>=0.7.0",
"gunicorn<24; platform_system != 'Windows'",
"alembic<2,!=1.10.0",
"pyjwt[crypto]>=2.9.0,<3.0.0"
]

[project.optional-dependencies]
Expand Down

0 comments on commit 12e057c

Please sign in to comment.