Skip to content

Commit

Permalink
fix(playground): authenticate websockets (#4924)
Browse files Browse the repository at this point in the history
  • Loading branch information
axiomofjoy authored and mikeldking committed Oct 11, 2024
1 parent 656a04a commit d887b30
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 7 deletions.
33 changes: 30 additions & 3 deletions src/phoenix/server/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,16 +36,17 @@
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker
from starlette.datastructures import State as StarletteState
from starlette.exceptions import HTTPException
from starlette.exceptions import HTTPException, WebSocketException
from starlette.middleware import Middleware
from starlette.middleware.authentication import AuthenticationMiddleware
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from starlette.requests import Request
from starlette.responses import PlainTextResponse, Response
from starlette.responses import JSONResponse, PlainTextResponse, Response
from starlette.staticfiles import StaticFiles
from starlette.status import HTTP_401_UNAUTHORIZED
from starlette.templating import Jinja2Templates
from starlette.types import Scope, StatefulLifespan
from starlette.websockets import WebSocket
from strawberry.extensions import SchemaExtension
from strawberry.fastapi import GraphQLRouter
from strawberry.schema import BaseSchema
Expand Down Expand Up @@ -632,6 +633,29 @@ async def plain_text_http_exception_handler(request: Request, exc: HTTPException
return PlainTextResponse(str(exc.detail), status_code=exc.status_code, headers=headers)


async def websocket_denial_response_handler(websocket: WebSocket, exc: WebSocketException) -> None:
"""
Overrides the default exception handler for WebSocketException to ensure
that the HTTP response returned when a WebSocket connection is denied has
the same status code as the raised exception. This is in keeping with the
WebSocket Denial Response Extension of the ASGI specificiation described
below.
"Websocket connections start with the client sending a HTTP request
containing the appropriate upgrade headers. On receipt of this request a
server can choose to either upgrade the connection or respond with an HTTP
response (denying the upgrade). The core ASGI specification does not allow
for any control over the denial response, instead specifying that the HTTP
status code 403 should be returned, whereas this extension allows an ASGI
framework to control the denial response."
For details, see:
- https://asgi.readthedocs.io/en/latest/extensions.html#websocket-denial-response
"""
assert isinstance(exc, WebSocketException)
await websocket.send_denial_response(JSONResponse(status_code=exc.code, content=exc.reason))


def create_app(
db: DbSessionFactory,
export_path: Path,
Expand Down Expand Up @@ -778,7 +802,10 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
scaffolder_config=scaffolder_config,
),
middleware=middlewares,
exception_handlers={HTTPException: plain_text_http_exception_handler},
exception_handlers={
HTTPException: plain_text_http_exception_handler,
WebSocketException: websocket_denial_response_handler, # type: ignore[dict-item]
},
debug=debug,
swagger_ui_parameters={
"defaultModelsExpandDepth": -1, # hides the schema section in the Swagger UI
Expand Down
16 changes: 12 additions & 4 deletions src/phoenix/server/bearer_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,11 @@
Callable,
Optional,
Tuple,
cast,
)

import grpc
from fastapi import HTTPException, Request
from fastapi import HTTPException, Request, WebSocket, WebSocketException
from grpc_interceptor import AsyncServerInterceptor
from grpc_interceptor.exceptions import Unauthenticated
from starlette.authentication import AuthCredentials, AuthenticationBackend, BaseUser
Expand Down Expand Up @@ -116,12 +117,19 @@ async def intercept(
raise Unauthenticated()


async def is_authenticated(request: Request) -> None:
async def is_authenticated(
# fastapi dependencies require non-optional types
request: Request = cast(Request, None),
websocket: WebSocket = cast(WebSocket, None),
) -> None:
"""
Raises a 401 if the request is not authenticated.
Raises a 401 if the request or websocket connection is not authenticated.
"""
if not isinstance((user := request.user), PhoenixUser):
assert request or websocket
if request and not isinstance((user := request.user), PhoenixUser):
raise HTTPException(status_code=HTTP_401_UNAUTHORIZED, detail="Invalid token")
if websocket and not isinstance((user := websocket.user), PhoenixUser):
raise WebSocketException(code=HTTP_401_UNAUTHORIZED, reason="Invalid token")
claims = user.claims
if claims.status is ClaimSetStatus.EXPIRED:
raise HTTPException(status_code=HTTP_401_UNAUTHORIZED, detail="Expired token")
Expand Down

0 comments on commit d887b30

Please sign in to comment.