Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: token authentication for api access; fix #44 #47

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions mlflow_oidc_auth/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,11 @@ class AppConfig:
OIDC_REDIRECT_URI = os.environ.get("OIDC_REDIRECT_URI", None)
OIDC_CLIENT_ID = os.environ.get("OIDC_CLIENT_ID", None)
OIDC_CLIENT_SECRET = os.environ.get("OIDC_CLIENT_SECRET", None)
OIDC_AUDIENCE = os.environ.get("OIDC_AUDIENCE", None)
OIDC_PUBLIC_KEYS_URL = os.environ.get("OIDC_PUBLIC_KEYS_URL", None)
OIDC_USERNAME_TOKEN_ATTRIBUTE = os.environ.get("OIDC_USERNAME_TOKEN_ATTRIBUTE", "email")
OIDC_SIGNING_ALG = os.environ.get("OIDC_SIGNING_ALG", None) # ES256, EdDSA, PS256, RS256, HS256
OIDC_HS256_SECRET = os.environ.get("OIDC_HS256_SECRET", None) # ES256, EdDSA, PS256, RS256, HS256

# https://flask-session.readthedocs.io/en/latest/config.html
SESSION_TYPE = os.environ.get("SESSION_TYPE", "filesystem")
Expand Down
73 changes: 71 additions & 2 deletions mlflow_oidc_auth/views.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
import os
import re
import requests
Expand Down Expand Up @@ -96,6 +97,8 @@

from mlflow.server import app

import jwt

# Create the OAuth2 client
auth_client = WebApplicationClient(AppConfig.get_property("OIDC_CLIENT_ID"))
store = SqlAlchemyStore()
Expand Down Expand Up @@ -186,6 +189,9 @@ def _get_permission_from_store_or_default(

def authenticate_request_basic_auth() -> Union[Authorization, Response]:
username = request.authorization.username
if username == "" or username is None:
app.logger.debug("Username is not set in basic auth")
return False
password = request.authorization.password
app.logger.debug("Authenticating user %s", username)
if store.authenticate_user(username.lower(), password):
Expand All @@ -195,6 +201,64 @@ def authenticate_request_basic_auth() -> Union[Authorization, Response]:
else:
app.logger.debug("User %s not authenticated", username)
return False


def _get_public_keys():
"""
Returns:
List of RSA public keys usable by PyJWT.
"""
r = requests.get(AppConfig.get_property("OIDC_PUBLIC_KEYS_URL"))
public_keys = []
jwk_set = r.json()
for key_dict in jwk_set["keys"]:
public_key = jwt.algorithms.RSAAlgorithm.from_jwk(json.dumps(key_dict))
public_keys.append(public_key)
return public_keys


def validate_token(token, key, sign_alg):
try:
token = jwt.decode(token, key=key, audience=AppConfig.get_property("OIDC_AUDIENCE"), algorithms=[sign_alg])
except jwt.exceptions.InvalidTokenError as e:
app.logger.debug(f"Token is not valid: {token}")
raise MlflowException(f"Token is not valid: {str(e)}")
username_token_attr = AppConfig.get_property("OIDC_USERNAME_TOKEN_ATTRIBUTE")
username = token[username_token_attr]
if username == "" or username is None:
app.logger.debug(f"username from token attribute {username_token_attr} is {username}")
raise MlflowException(f"Username is not set at attribute: {username_token_attr}")
_set_username(username)


def authenticate_token():
"""
Verify the token in the request.
"""
token = request.authorization.token
if token == "" or token is None:
app.logger.debug(f"Token is not set: {token}")
return False
sign_alg = AppConfig.get_property("OIDC_SIGNING_ALG")
token_is_valid = False
if sign_alg == "HS256":
key = AppConfig.get_property("OIDC_HS256_SECRET")
try:
validate_token(token, key, sign_alg)
token_is_valid = True
except MlflowException as e:
return False

keys = _get_public_keys()
for key in keys:
try:
validate_token(token, key, sign_alg)
token_is_valid = True
break
except MlflowException as e:
return False

return token_is_valid


def _get_username():
Expand Down Expand Up @@ -617,8 +681,13 @@ def before_request_hook():
if _is_unprotected_route(request.path):
return
if request.authorization is not None:
if not authenticate_request_basic_auth():
return make_basic_auth_response()
if not authenticate_token():
app.logger.debug("No valid token authentication found")
# TODO maybe return 401 here instead of basic auth response

if not authenticate_request_basic_auth():
app.logger.debug("No valid basic authentication found")
return make_basic_auth_response()
else:
# authentication
if not _get_username():
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ dependencies = [
"Flask-Session>=0.7.0",
"gunicorn<24; platform_system != 'Windows'",
"alembic<2,!=1.10.0",
"pyjwt[crypto]>=2.9.0,<3.0.0"
]

[project.optional-dependencies]
Expand Down