Skip to content

Commit

Permalink
Merge branch 'main' into fix/app-render-warning
Browse files Browse the repository at this point in the history
  • Loading branch information
hexart authored Jan 22, 2025
2 parents e9af9aa + 85c98ea commit b25ffec
Show file tree
Hide file tree
Showing 14 changed files with 377 additions and 122 deletions.
2 changes: 1 addition & 1 deletion .github/actions/pnpm-node-install/action.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ inputs:
node-version:
description: Node.js version
required: true
default: '22.7.0'
default: '23.3.0'
pnpm-version:
description: pnpm version
required: true
Expand Down
16 changes: 14 additions & 2 deletions backend/chainlit/auth/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,12 @@
from chainlit.logger import logger
from chainlit.oauth_providers import get_configured_oauth_providers

from .cookie import OAuth2PasswordBearerWithCookie
from .cookie import (
OAuth2PasswordBearerWithCookie,
clear_auth_cookie,
get_token_from_cookies,
set_auth_cookie,
)
from .jwt import create_jwt, decode_jwt, get_jwt_secret

reuseable_oauth = OAuth2PasswordBearerWithCookie(tokenUrl="/login", auto_error=False)
Expand Down Expand Up @@ -80,4 +85,11 @@ async def get_current_user(token: str = Depends(reuseable_oauth)):
return await authenticate_user(token)


__all__ = ["create_jwt", "get_configuration", "get_current_user"]
__all__ = [
"clear_auth_cookie",
"create_jwt",
"get_configuration",
"get_current_user",
"get_token_from_cookies",
"set_auth_cookie",
]
98 changes: 86 additions & 12 deletions backend/chainlit/auth/cookie.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def __init__(

async def __call__(self, request: Request) -> Optional[str]:
# First try to get the token from the cookie
token = request.cookies.get(_auth_cookie_name)
token = get_token_from_cookies(request.cookies)

# If no cookie, try the Authorization header as fallback
if not token:
Expand Down Expand Up @@ -76,26 +76,100 @@ async def __call__(self, request: Request) -> Optional[str]:
return token


def set_auth_cookie(response: Response, token: str):
def _get_chunked_cookie(cookies: dict[str, str], name: str) -> Optional[str]:
# Gather all auth_chunk_i cookies, sorted by their index
chunk_parts = []

i = 0
while True:
cookie_key = f"{_auth_cookie_name}_{i}"
if cookie_key not in cookies:
break

print("Reading chunk", cookie_key)
chunk_parts.append(cookies[cookie_key])
i += 1

joined = "".join(chunk_parts)

return joined if joined != "" else None


def get_token_from_cookies(cookies: dict[str, str]) -> Optional[str]:
"""
Read all chunk cookies and reconstruct the token
"""

print("Found cookies", cookies.keys())

# Default/unchunked cookies
if value := cookies.get(_auth_cookie_name):
print("Returning unchunked", _auth_cookie_name, value)
return value

return _get_chunked_cookie(cookies, _auth_cookie_name)


def set_auth_cookie(request: Request, response: Response, token: str):
"""
Helper function to set the authentication cookie with secure parameters
and remove any leftover chunks from a previously larger token.
"""

response.set_cookie(
key=_auth_cookie_name,
value=token,
httponly=True,
secure=_cookie_secure,
samesite=_cookie_samesite,
max_age=config.project.user_session_timeout,
)
_chunk_size = 3000

existing_cookies = {
k for k in request.cookies.keys() if k.startswith(_auth_cookie_name)
}

if len(token) > _chunk_size:
chunks = [token[i : i + _chunk_size] for i in range(0, len(token), _chunk_size)]

for i, chunk in enumerate(chunks):
k = f"{_auth_cookie_name}_{i}"

print("Setting", k)

def clear_auth_cookie(response: Response):
response.set_cookie(
key=k,
value=chunk,
httponly=True,
secure=_cookie_secure,
samesite=_cookie_samesite,
max_age=config.project.user_session_timeout,
)

existing_cookies.discard(k)
else:
# Default (shorter cookies)
response.set_cookie(
key=_auth_cookie_name,
value=token,
httponly=True,
secure=_cookie_secure,
samesite=_cookie_samesite,
max_age=config.project.user_session_timeout,
)

existing_cookies.discard(_auth_cookie_name)

# Delete remaining prior cookies/cookie chunks
for k in existing_cookies:
print("Deleting", k)
response.delete_cookie(key=k, path="/")


def clear_auth_cookie(request: Request, response: Response):
"""
Helper function to clear the authentication cookie
"""
response.delete_cookie(key=_auth_cookie_name, path="/")

existing_cookies = {
k for k in request.cookies.keys() if k.startswith(_auth_cookie_name)
}

for k in existing_cookies:
response.delete_cookie(key=k, path="/")


def set_oauth_state_cookie(response: Response, token: str):
Expand Down
22 changes: 13 additions & 9 deletions backend/chainlit/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,7 +455,7 @@ def _get_oauth_redirect_error(error: str) -> Response:


async def _authenticate_user(
user: Optional[User], redirect_to_callback: bool = False
request: Request, user: Optional[User], redirect_to_callback: bool = False
) -> Response:
"""Authenticate a user and return the response."""

Expand All @@ -478,13 +478,17 @@ async def _authenticate_user(

response = _get_auth_response(access_token, redirect_to_callback)

set_auth_cookie(response, access_token)
set_auth_cookie(request, response, access_token)

return response


@router.post("/login")
async def login(response: Response, form_data: OAuth2PasswordRequestForm = Depends()):
async def login(
request: Request,
response: Response,
form_data: OAuth2PasswordRequestForm = Depends(),
):
"""
Login a user using the password auth callback.
"""
Expand All @@ -497,13 +501,13 @@ async def login(response: Response, form_data: OAuth2PasswordRequestForm = Depen
form_data.username, form_data.password
)

return await _authenticate_user(user)
return await _authenticate_user(request, user)


@router.post("/logout")
async def logout(request: Request, response: Response):
"""Logout the user by calling the on_logout callback."""
clear_auth_cookie(response)
clear_auth_cookie(request, response)

if config.code.on_logout:
return await config.code.on_logout(request, response)
Expand Down Expand Up @@ -535,7 +539,7 @@ async def jwt_auth(request: Request):

try:
user = decode_jwt(token)
return await _authenticate_user(user)
return await _authenticate_user(request, user)
except InvalidTokenError:
raise HTTPException(status_code=401, detail="Invalid token")

Expand All @@ -551,7 +555,7 @@ async def header_auth(request: Request):

user = await config.code.header_auth_callback(request.headers)

return await _authenticate_user(user)
return await _authenticate_user(request, user)


@router.get("/auth/oauth/{provider_id}")
Expand Down Expand Up @@ -640,7 +644,7 @@ async def oauth_callback(
provider_id, token, raw_user_data, default_user
)

response = await _authenticate_user(user, redirect_to_callback=True)
response = await _authenticate_user(request, user, redirect_to_callback=True)

clear_oauth_state_cookie(response)

Expand Down Expand Up @@ -689,7 +693,7 @@ async def oauth_azure_hf_callback(
provider_id, token, raw_user_data, default_user, id_token
)

response = await _authenticate_user(user, redirect_to_callback=True)
response = await _authenticate_user(request, user, redirect_to_callback=True)

clear_oauth_state_cookie(response)

Expand Down
8 changes: 6 additions & 2 deletions backend/chainlit/socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,11 @@
from starlette.requests import cookie_parser
from typing_extensions import TypeAlias

from chainlit.auth import get_current_user, require_login
from chainlit.auth import (
get_current_user,
get_token_from_cookies,
require_login,
)
from chainlit.chat_context import chat_context
from chainlit.config import config
from chainlit.context import init_ws_context
Expand Down Expand Up @@ -83,7 +87,7 @@ def load_user_env(user_env):
def _get_token_from_cookie(environ: WSGIEnvironment) -> Optional[str]:
if cookie_header := environ.get("HTTP_COOKIE", None):
cookies = cookie_parser(cookie_header)
return cookies.get("access_token", None)
return get_token_from_cookies(cookies)

return None

Expand Down
Empty file added backend/tests/auth/__init__.py
Empty file.
Loading

0 comments on commit b25ffec

Please sign in to comment.