Skip to content

Commit

Permalink
WIP websockets denial response extension
Browse files Browse the repository at this point in the history
  • Loading branch information
Kludex committed Dec 26, 2023
1 parent bf00ada commit 64d6eb7
Show file tree
Hide file tree
Showing 4 changed files with 70 additions and 53 deletions.
8 changes: 4 additions & 4 deletions tests/protocols/test_websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -1369,8 +1369,7 @@ async def test_server_multiple_websocket_http_response_start_events(
async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
nonlocal exception_message
assert scope["type"] == "websocket"
assert "extensions" in scope
assert "websocket.http.response" in scope["extensions"]
assert "websocket.http.response" in scope.get("extensions", {})

# Pull up first recv message.
message = await receive()
Expand All @@ -1385,13 +1384,14 @@ async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable
try:
await send(start_event)
except Exception as exc:
print(exc)
exception_message = str(exc)

async def websocket_session(url: str):
with pytest.raises(websockets.exceptions.InvalidStatusCode) as exc_info:
async with websockets.client.connect(url):
pass
assert exc_info.value.status_code == 404
assert exc_info.value.status_code in (404, 500)

config = Config(
app=app,
Expand Down Expand Up @@ -1564,7 +1564,7 @@ async def open_connection(url: str):
)
async with run_server(config):
headers = await open_connection(f"ws://127.0.0.1:{unused_tcp_port}")
assert headers.get_all("Server") == ["uvicorn", "over-ridden", "another-value"]
assert headers.get_all("server") == ["uvicorn", "over-ridden", "another-value"]


@pytest.mark.anyio
Expand Down
9 changes: 6 additions & 3 deletions uvicorn/protocols/websockets/websockets_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,8 +290,11 @@ async def asgi_send(self, message: "ASGISendEvent") -> None:
self.extra_headers.extend(
# ASGI spec requires bytes
# But for compatibility we need to convert it to strings
(name.decode("latin-1"), value.decode("latin-1"))
for name, value in message["headers"]
(
name.decode("latin-1").lower(),
value.decode("latin-1").lower(),
)
for name, value in list(message.get("headers", []))
)
self.handshake_started_event.set()

Expand All @@ -317,7 +320,7 @@ async def asgi_send(self, message: "ASGISendEvent") -> None:
# websockets requires the status to be an enum. look it up.
status = http.HTTPStatus(message["status"])
headers = [
(name.decode("latin-1"), value.decode("latin-1"))
(name.decode("latin-1").lower(), value.decode("latin-1").lower())
for name, value in message.get("headers", [])
]
self.initial_response = (status, headers, b"")
Expand Down
104 changes: 59 additions & 45 deletions uvicorn/protocols/websockets/websockets_sansio_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
WebSocketCloseEvent,
WebSocketDisconnectEvent,
WebSocketReceiveEvent,
WebSocketResponseBodyEvent,
WebSocketResponseStartEvent,
WebSocketScope,
WebSocketSendEvent,
)
Expand Down Expand Up @@ -67,6 +69,7 @@ def __init__(
self.handshake_initiated = False
self.handshake_complete = False
self.close_sent = False
self.initial_response: tuple[int, list[tuple[str, str]], bytes] | None = None

extensions = []
if self.config.ws_per_message_deflate:
Expand Down Expand Up @@ -177,6 +180,7 @@ def handle_connect(self, event: Request) -> None:
"headers": headers,
"subprotocols": event.headers.get_all("Sec-WebSocket-Protocol"),
"state": self.app_state.copy(),
"extensions": {"websocket.http.response": {}},
}
self.queue.put_nowait({"type": "websocket.connect"})
task = self.loop.create_task(self.run_asgi())
Expand Down Expand Up @@ -262,89 +266,80 @@ async def run_asgi(self) -> None:
self.transport.close()

def send_500_response(self) -> None:
msg = b"Internal Server Error"
content = [
b"HTTP/1.1 500 Internal Server Error\r\n"
b"content-type: text/plain; charset=utf-8\r\n",
b"content-length: " + str(len(msg)).encode("ascii") + b"\r\n",
b"connection: close\r\n",
b"\r\n",
msg,
]
self.transport.write(b"".join(content))
response = self.conn.reject(500, "Internal Server Error")
self.conn.send_response(response)
output = self.conn.data_to_send()
self.transport.writelines(output)

async def send(self, message: ASGISendEvent) -> None:
await self.writable.wait()

message_type = message["type"]

if not self.handshake_complete:
if message_type == "websocket.accept" and not self.transport.is_closing():
if not self.handshake_complete or (
self.handshake_complete and self.initial_response is None
):
if message_type == "websocket.accept":
message = cast(WebSocketAcceptEvent, message)
self.logger.info(
'%s - "WebSocket %s" [accepted]',
self.scope["client"],
get_path_with_query_string(self.scope),
)
headers = [
(
key.decode("ascii"),
value.decode("ascii", errors="surrogateescape"),
(name.decode("latin-1").lower(), value.decode("latin-1").lower())
for name, value in (
self.default_headers + list(message.get("headers", []))
)
for key, value in self.default_headers
+ list(message.get("headers", []))
]

accepted_subprotocol = message.get("subprotocol")
if accepted_subprotocol:
headers.append(("Sec-WebSocket-Protocol", accepted_subprotocol))

self.handshake_complete = True
self.response.headers.update(headers)
self.conn.send_response(self.response)
output = self.conn.data_to_send()
self.transport.writelines(output)

elif message_type == "websocket.close" and not self.transport.is_closing():
if not self.transport.is_closing():
self.handshake_complete = True
self.conn.send_response(self.response)
output = self.conn.data_to_send()
self.transport.writelines(output)

elif message_type == "websocket.close":
message = cast(WebSocketCloseEvent, message)
self.queue.put_nowait(
{
"type": "websocket.disconnect",
"code": message.get("code", 1000) or 1000,
}
)
self.queue.put_nowait({"type": "websocket.disconnect", "code": 1006})
self.logger.info(
'%s - "WebSocket %s" 403',
self.scope["client"],
get_path_with_query_string(self.scope),
)
extra_headers = [
(
key.decode("ascii"),
value.decode("ascii", errors="surrogateescape"),
)
for key, value in self.default_headers
]

response = self.conn.reject(
HTTPStatus.FORBIDDEN, message.get("reason", "") or ""
)
response.headers.update(extra_headers)
response = self.conn.reject(HTTPStatus.FORBIDDEN, "")
self.conn.send_response(response)
output = self.conn.data_to_send()
self.close_sent = True
self.handshake_complete = True
self.transport.writelines(output)
self.transport.close()

elif message_type == "websocket.http.response.start":
message = cast(WebSocketResponseStartEvent, message)
self.logger.info(
'%s - "WebSocket %s" %d',
self.scope["client"],
get_path_with_query_string(self.scope),
message["status"],
)
headers = [
(name.decode("latin-1"), value.decode("latin-1"))
for name, value in list(message.get("headers", []))
]
self.initial_response = (message["status"], headers, b"")
else:
msg = (
"Expected ASGI message 'websocket.accept' or 'websocket.close', "
"Expected ASGI message 'websocket.accept', 'websocket.close' "
"or 'websocket.http.response.start' "
"but got '%s'."
)
raise RuntimeError(msg % message_type)

elif not self.close_sent:
elif not self.close_sent and self.initial_response is None:
if message_type == "websocket.send" and not self.transport.is_closing():
message = cast(WebSocketSendEvent, message)
bytes_data = message.get("bytes")
Expand Down Expand Up @@ -372,6 +367,25 @@ async def send(self, message: ASGISendEvent) -> None:
" but got '%s'."
)
raise RuntimeError(msg % message_type)
elif self.initial_response is not None:
if message_type == "websocket.http.response.body":
message = cast(WebSocketResponseBodyEvent, message)
body = self.initial_response[2] + message["body"]
self.initial_response = self.initial_response[:2] + (body,)
if not message.get("more_body", False):
response = self.conn.reject(self.initial_response[0], body.decode())
response.headers.update(self.initial_response[1])
self.conn.send_response(response)
output = self.conn.data_to_send()
self.close_sent = True
self.transport.writelines(output)
self.transport.close()
else:
msg = (
"Expected ASGI message 'websocket.http.response.body' "
"but got '%s'."
)
raise RuntimeError(msg % message_type)

else:
msg = "Unexpected ASGI message '%s', after sending 'websocket.close'."
Expand Down
2 changes: 1 addition & 1 deletion uvicorn/protocols/websockets/wsproto_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ def shutdown(self) -> None:
self.send_500_response()
self.transport.close()

def on_task_complete(self, task: asyncio.Task) -> None:
def on_task_complete(self, task: asyncio.Task[None]) -> None:
self.tasks.discard(task)

# Event handlers
Expand Down

0 comments on commit 64d6eb7

Please sign in to comment.