Skip to content

Commit

Permalink
Chunk long cookies (#1758)
Browse files Browse the repository at this point in the history
* doesn't build locally, commit from anthonym1991

* ruff formatting

* ruff formatting again

* test

* Added tests, this does very little in terms of mocking

* Ruff

* Ruff formatting.

* Move cookie tests to appropriate location.

* test: rewrite cookie tests using FastAPI TestClient for real cookie handling

* fix: update cookie endpoint and tests to use form data instead of JSON

* Fixups cookie stuff.

* fixed

* Tests passing, fallback to unchunked.

* Remove unused import.

---------

Co-authored-by: Mathijs de Bruin <mathijs@mathijsfietst.nl>
  • Loading branch information
jpolvto and dokterbob authored Jan 22, 2025
1 parent f707f8b commit 85c98ea
Show file tree
Hide file tree
Showing 6 changed files with 267 additions and 25 deletions.
16 changes: 14 additions & 2 deletions backend/chainlit/auth/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,12 @@
from chainlit.logger import logger
from chainlit.oauth_providers import get_configured_oauth_providers

from .cookie import OAuth2PasswordBearerWithCookie
from .cookie import (
OAuth2PasswordBearerWithCookie,
clear_auth_cookie,
get_token_from_cookies,
set_auth_cookie,
)
from .jwt import create_jwt, decode_jwt, get_jwt_secret

reuseable_oauth = OAuth2PasswordBearerWithCookie(tokenUrl="/login", auto_error=False)
Expand Down Expand Up @@ -80,4 +85,11 @@ async def get_current_user(token: str = Depends(reuseable_oauth)):
return await authenticate_user(token)


__all__ = ["create_jwt", "get_configuration", "get_current_user"]
__all__ = [
"clear_auth_cookie",
"create_jwt",
"get_configuration",
"get_current_user",
"get_token_from_cookies",
"set_auth_cookie",
]
98 changes: 86 additions & 12 deletions backend/chainlit/auth/cookie.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def __init__(

async def __call__(self, request: Request) -> Optional[str]:
# First try to get the token from the cookie
token = request.cookies.get(_auth_cookie_name)
token = get_token_from_cookies(request.cookies)

# If no cookie, try the Authorization header as fallback
if not token:
Expand Down Expand Up @@ -76,26 +76,100 @@ async def __call__(self, request: Request) -> Optional[str]:
return token


def set_auth_cookie(response: Response, token: str):
def _get_chunked_cookie(cookies: dict[str, str], name: str) -> Optional[str]:
# Gather all auth_chunk_i cookies, sorted by their index
chunk_parts = []

i = 0
while True:
cookie_key = f"{_auth_cookie_name}_{i}"
if cookie_key not in cookies:
break

print("Reading chunk", cookie_key)
chunk_parts.append(cookies[cookie_key])
i += 1

joined = "".join(chunk_parts)

return joined if joined != "" else None


def get_token_from_cookies(cookies: dict[str, str]) -> Optional[str]:
"""
Read all chunk cookies and reconstruct the token
"""

print("Found cookies", cookies.keys())

# Default/unchunked cookies
if value := cookies.get(_auth_cookie_name):
print("Returning unchunked", _auth_cookie_name, value)
return value

return _get_chunked_cookie(cookies, _auth_cookie_name)


def set_auth_cookie(request: Request, response: Response, token: str):
"""
Helper function to set the authentication cookie with secure parameters
and remove any leftover chunks from a previously larger token.
"""

response.set_cookie(
key=_auth_cookie_name,
value=token,
httponly=True,
secure=_cookie_secure,
samesite=_cookie_samesite,
max_age=config.project.user_session_timeout,
)
_chunk_size = 3000

existing_cookies = {
k for k in request.cookies.keys() if k.startswith(_auth_cookie_name)
}

if len(token) > _chunk_size:
chunks = [token[i : i + _chunk_size] for i in range(0, len(token), _chunk_size)]

for i, chunk in enumerate(chunks):
k = f"{_auth_cookie_name}_{i}"

print("Setting", k)

def clear_auth_cookie(response: Response):
response.set_cookie(
key=k,
value=chunk,
httponly=True,
secure=_cookie_secure,
samesite=_cookie_samesite,
max_age=config.project.user_session_timeout,
)

existing_cookies.discard(k)
else:
# Default (shorter cookies)
response.set_cookie(
key=_auth_cookie_name,
value=token,
httponly=True,
secure=_cookie_secure,
samesite=_cookie_samesite,
max_age=config.project.user_session_timeout,
)

existing_cookies.discard(_auth_cookie_name)

# Delete remaining prior cookies/cookie chunks
for k in existing_cookies:
print("Deleting", k)
response.delete_cookie(key=k, path="/")


def clear_auth_cookie(request: Request, response: Response):
"""
Helper function to clear the authentication cookie
"""
response.delete_cookie(key=_auth_cookie_name, path="/")

existing_cookies = {
k for k in request.cookies.keys() if k.startswith(_auth_cookie_name)
}

for k in existing_cookies:
response.delete_cookie(key=k, path="/")


def set_oauth_state_cookie(response: Response, token: str):
Expand Down
22 changes: 13 additions & 9 deletions backend/chainlit/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,7 +455,7 @@ def _get_oauth_redirect_error(error: str) -> Response:


async def _authenticate_user(
user: Optional[User], redirect_to_callback: bool = False
request: Request, user: Optional[User], redirect_to_callback: bool = False
) -> Response:
"""Authenticate a user and return the response."""

Expand All @@ -478,13 +478,17 @@ async def _authenticate_user(

response = _get_auth_response(access_token, redirect_to_callback)

set_auth_cookie(response, access_token)
set_auth_cookie(request, response, access_token)

return response


@router.post("/login")
async def login(response: Response, form_data: OAuth2PasswordRequestForm = Depends()):
async def login(
request: Request,
response: Response,
form_data: OAuth2PasswordRequestForm = Depends(),
):
"""
Login a user using the password auth callback.
"""
Expand All @@ -497,13 +501,13 @@ async def login(response: Response, form_data: OAuth2PasswordRequestForm = Depen
form_data.username, form_data.password
)

return await _authenticate_user(user)
return await _authenticate_user(request, user)


@router.post("/logout")
async def logout(request: Request, response: Response):
"""Logout the user by calling the on_logout callback."""
clear_auth_cookie(response)
clear_auth_cookie(request, response)

if config.code.on_logout:
return await config.code.on_logout(request, response)
Expand Down Expand Up @@ -535,7 +539,7 @@ async def jwt_auth(request: Request):

try:
user = decode_jwt(token)
return await _authenticate_user(user)
return await _authenticate_user(request, user)
except InvalidTokenError:
raise HTTPException(status_code=401, detail="Invalid token")

Expand All @@ -551,7 +555,7 @@ async def header_auth(request: Request):

user = await config.code.header_auth_callback(request.headers)

return await _authenticate_user(user)
return await _authenticate_user(request, user)


@router.get("/auth/oauth/{provider_id}")
Expand Down Expand Up @@ -640,7 +644,7 @@ async def oauth_callback(
provider_id, token, raw_user_data, default_user
)

response = await _authenticate_user(user, redirect_to_callback=True)
response = await _authenticate_user(request, user, redirect_to_callback=True)

clear_oauth_state_cookie(response)

Expand Down Expand Up @@ -689,7 +693,7 @@ async def oauth_azure_hf_callback(
provider_id, token, raw_user_data, default_user, id_token
)

response = await _authenticate_user(user, redirect_to_callback=True)
response = await _authenticate_user(request, user, redirect_to_callback=True)

clear_oauth_state_cookie(response)

Expand Down
8 changes: 6 additions & 2 deletions backend/chainlit/socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,11 @@
from starlette.requests import cookie_parser
from typing_extensions import TypeAlias

from chainlit.auth import get_current_user, require_login
from chainlit.auth import (
get_current_user,
get_token_from_cookies,
require_login,
)
from chainlit.chat_context import chat_context
from chainlit.config import config
from chainlit.context import init_ws_context
Expand Down Expand Up @@ -83,7 +87,7 @@ def load_user_env(user_env):
def _get_token_from_cookie(environ: WSGIEnvironment) -> Optional[str]:
if cookie_header := environ.get("HTTP_COOKIE", None):
cookies = cookie_parser(cookie_header)
return cookies.get("access_token", None)
return get_token_from_cookies(cookies)

return None

Expand Down
Empty file added backend/tests/auth/__init__.py
Empty file.
148 changes: 148 additions & 0 deletions backend/tests/auth/test_cookie.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
import pytest
from fastapi import FastAPI, Form
from fastapi.testclient import TestClient
from starlette.requests import Request
from starlette.responses import Response

from chainlit.auth import (
clear_auth_cookie,
get_token_from_cookies,
set_auth_cookie,
)


@pytest.fixture
def test_app():
app = FastAPI()

@app.post("/set-cookie")
async def set_cookie_endpoint(request: Request, token: str = Form()):
response = Response()
set_auth_cookie(request, response, token)
return response

@app.get("/get-token")
async def get_token_endpoint(request: Request):
token = get_token_from_cookies(request.cookies)
return {"token": token}

@app.delete("/clear-cookie")
async def clear_cookie_endpoint(request: Request):
response = Response()
clear_auth_cookie(request, response)
return response

return app


@pytest.fixture
def client(test_app):
return TestClient(test_app)


def test_short_token(client):
"""Test with a <3000 shorter token."""

# Set a short token
short_token = "x" * 1000
set_response = client.post("/set-cookie", data={"token": short_token})
assert set_response.status_code == 200

# Verify cookies were set
cookies = set_response.cookies
assert cookies, "No cookies set"
assert "access_token" in cookies, f"No chunking for short cookies: {cookies}"

# Read back the token using client's cookie jar
get_response = client.get("/get-token")
assert get_response.status_code == 200
assert get_response.json()["token"] == short_token


def test_set_and_read_4kb_token(client):
"""Test full cookie lifecycle using actual client cookie handling."""

# Set a 4KB token
token_4kb = "x" * 4000
set_response = client.post("/set-cookie", data={"token": token_4kb})
assert set_response.status_code == 200

# Verify cookies were set
cookies = set_response.cookies
assert f"{cookies.keys()} should contain chunked cookies", any(
key.startswith("access_token_") for key in cookies.keys()
)

# Read back the token using client's cookie jar
get_response = client.get("/get-token")
assert get_response.status_code == 200

response_token = get_response.json()["token"]
assert len(response_token) == len(token_4kb)
assert response_token == token_4kb


def test_overwrite_shorter_token_chunked(client):
"""Test cookie chunk cleanup when replacing a large token with a smaller one."""
# Set initial long token
long_token = "LONG" * 2000 # 8000 characters
client.post("/set-cookie", data={"token": long_token})

# Verify initial chunks exist
first_cookies = client.cookies
assert len([k for k in first_cookies if k.startswith("access_token_")]) > 1

# Set shorter token (should clear previous chunks)
short_token = "SHORT" * 1000 # 4000 characters
client.post("/set-cookie", data={"token": short_token})

# Verify new cookie state
final_response = client.get("/get-token")
assert final_response.json()["token"] == short_token

# Verify only two chunks remain
final_cookies = client.cookies
chunk_cookies = [k for k in final_cookies if k.startswith("access_token_")]
assert len(chunk_cookies) == 2, f"Found {len(chunk_cookies)} residual cookies"


def test_overwrite_shorter_token_unchunked(client):
"""Test cookie chunk cleanup when replacing a large token with a smaller one."""
# Set initial long token
long_token = "LONG" * 1000 # 4000 characters
client.post("/set-cookie", data={"token": long_token})

# Verify initial chunks exist
first_cookies = client.cookies
assert len([k for k in first_cookies if k.startswith("access_token_")]) > 1

# Set shorter token (should clear previous chunks)
short_token = "SHORT"
client.post("/set-cookie", data={"token": short_token})

# Verify new cookie state
final_response = client.get("/get-token")
assert final_response.json()["token"] == short_token

# Verify no chunks remain
final_cookies = client.cookies
chunk_cookies = [k for k in final_cookies if k.startswith("access_token_")]
assert len(chunk_cookies) == 0, f"Found {len(chunk_cookies)} residual cookies"


def test_clear_auth_cookie(client):
"""Test cookie clearing removes all chunks."""
# Set initial token
client.post("/set-cookie", data={"token": "x" * 4000})

# Verify cookies exist
assert len(client.cookies) > 0

# Clear cookies
clear_response = client.delete("/clear-cookie")
assert clear_response.status_code == 200

# Verify cookies were cleared
assert len(clear_response.cookies) == 0
final_response = client.get("/get-token")
assert final_response.json()["token"] is None

0 comments on commit 85c98ea

Please sign in to comment.