From 64d6eb707fac78568aa03f7610fa4401334af4c4 Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Tue, 26 Dec 2023 14:00:07 +0100 Subject: [PATCH] WIP websockets denial response extension --- tests/protocols/test_websocket.py | 8 +- .../protocols/websockets/websockets_impl.py | 9 +- .../websockets/websockets_sansio_impl.py | 104 ++++++++++-------- uvicorn/protocols/websockets/wsproto_impl.py | 2 +- 4 files changed, 70 insertions(+), 53 deletions(-) diff --git a/tests/protocols/test_websocket.py b/tests/protocols/test_websocket.py index 015c5dbdb..a2363e46f 100644 --- a/tests/protocols/test_websocket.py +++ b/tests/protocols/test_websocket.py @@ -1369,8 +1369,7 @@ async def test_server_multiple_websocket_http_response_start_events( async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable): nonlocal exception_message assert scope["type"] == "websocket" - assert "extensions" in scope - assert "websocket.http.response" in scope["extensions"] + assert "websocket.http.response" in scope.get("extensions", {}) # Pull up first recv message. message = await receive() @@ -1385,13 +1384,14 @@ async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable try: await send(start_event) except Exception as exc: + print(exc) exception_message = str(exc) async def websocket_session(url: str): with pytest.raises(websockets.exceptions.InvalidStatusCode) as exc_info: async with websockets.client.connect(url): pass - assert exc_info.value.status_code == 404 + assert exc_info.value.status_code in (404, 500) config = Config( app=app, @@ -1564,7 +1564,7 @@ async def open_connection(url: str): ) async with run_server(config): headers = await open_connection(f"ws://127.0.0.1:{unused_tcp_port}") - assert headers.get_all("Server") == ["uvicorn", "over-ridden", "another-value"] + assert headers.get_all("server") == ["uvicorn", "over-ridden", "another-value"] @pytest.mark.anyio diff --git a/uvicorn/protocols/websockets/websockets_impl.py b/uvicorn/protocols/websockets/websockets_impl.py index 3f04c1dd5..46bcabcdc 100644 --- a/uvicorn/protocols/websockets/websockets_impl.py +++ b/uvicorn/protocols/websockets/websockets_impl.py @@ -290,8 +290,11 @@ async def asgi_send(self, message: "ASGISendEvent") -> None: self.extra_headers.extend( # ASGI spec requires bytes # But for compatibility we need to convert it to strings - (name.decode("latin-1"), value.decode("latin-1")) - for name, value in message["headers"] + ( + name.decode("latin-1").lower(), + value.decode("latin-1").lower(), + ) + for name, value in list(message.get("headers", [])) ) self.handshake_started_event.set() @@ -317,7 +320,7 @@ async def asgi_send(self, message: "ASGISendEvent") -> None: # websockets requires the status to be an enum. look it up. status = http.HTTPStatus(message["status"]) headers = [ - (name.decode("latin-1"), value.decode("latin-1")) + (name.decode("latin-1").lower(), value.decode("latin-1").lower()) for name, value in message.get("headers", []) ] self.initial_response = (status, headers, b"") diff --git a/uvicorn/protocols/websockets/websockets_sansio_impl.py b/uvicorn/protocols/websockets/websockets_sansio_impl.py index a75243602..2960c7af9 100644 --- a/uvicorn/protocols/websockets/websockets_sansio_impl.py +++ b/uvicorn/protocols/websockets/websockets_sansio_impl.py @@ -19,6 +19,8 @@ WebSocketCloseEvent, WebSocketDisconnectEvent, WebSocketReceiveEvent, + WebSocketResponseBodyEvent, + WebSocketResponseStartEvent, WebSocketScope, WebSocketSendEvent, ) @@ -67,6 +69,7 @@ def __init__( self.handshake_initiated = False self.handshake_complete = False self.close_sent = False + self.initial_response: tuple[int, list[tuple[str, str]], bytes] | None = None extensions = [] if self.config.ws_per_message_deflate: @@ -177,6 +180,7 @@ def handle_connect(self, event: Request) -> None: "headers": headers, "subprotocols": event.headers.get_all("Sec-WebSocket-Protocol"), "state": self.app_state.copy(), + "extensions": {"websocket.http.response": {}}, } self.queue.put_nowait({"type": "websocket.connect"}) task = self.loop.create_task(self.run_asgi()) @@ -262,24 +266,20 @@ async def run_asgi(self) -> None: self.transport.close() def send_500_response(self) -> None: - msg = b"Internal Server Error" - content = [ - b"HTTP/1.1 500 Internal Server Error\r\n" - b"content-type: text/plain; charset=utf-8\r\n", - b"content-length: " + str(len(msg)).encode("ascii") + b"\r\n", - b"connection: close\r\n", - b"\r\n", - msg, - ] - self.transport.write(b"".join(content)) + response = self.conn.reject(500, "Internal Server Error") + self.conn.send_response(response) + output = self.conn.data_to_send() + self.transport.writelines(output) async def send(self, message: ASGISendEvent) -> None: await self.writable.wait() message_type = message["type"] - if not self.handshake_complete: - if message_type == "websocket.accept" and not self.transport.is_closing(): + if not self.handshake_complete or ( + self.handshake_complete and self.initial_response is None + ): + if message_type == "websocket.accept": message = cast(WebSocketAcceptEvent, message) self.logger.info( '%s - "WebSocket %s" [accepted]', @@ -287,64 +287,59 @@ async def send(self, message: ASGISendEvent) -> None: get_path_with_query_string(self.scope), ) headers = [ - ( - key.decode("ascii"), - value.decode("ascii", errors="surrogateescape"), + (name.decode("latin-1").lower(), value.decode("latin-1").lower()) + for name, value in ( + self.default_headers + list(message.get("headers", [])) ) - for key, value in self.default_headers - + list(message.get("headers", [])) ] - accepted_subprotocol = message.get("subprotocol") if accepted_subprotocol: headers.append(("Sec-WebSocket-Protocol", accepted_subprotocol)) - - self.handshake_complete = True self.response.headers.update(headers) - self.conn.send_response(self.response) - output = self.conn.data_to_send() - self.transport.writelines(output) - elif message_type == "websocket.close" and not self.transport.is_closing(): + if not self.transport.is_closing(): + self.handshake_complete = True + self.conn.send_response(self.response) + output = self.conn.data_to_send() + self.transport.writelines(output) + + elif message_type == "websocket.close": message = cast(WebSocketCloseEvent, message) - self.queue.put_nowait( - { - "type": "websocket.disconnect", - "code": message.get("code", 1000) or 1000, - } - ) + self.queue.put_nowait({"type": "websocket.disconnect", "code": 1006}) self.logger.info( '%s - "WebSocket %s" 403', self.scope["client"], get_path_with_query_string(self.scope), ) - extra_headers = [ - ( - key.decode("ascii"), - value.decode("ascii", errors="surrogateescape"), - ) - for key, value in self.default_headers - ] - - response = self.conn.reject( - HTTPStatus.FORBIDDEN, message.get("reason", "") or "" - ) - response.headers.update(extra_headers) + response = self.conn.reject(HTTPStatus.FORBIDDEN, "") self.conn.send_response(response) output = self.conn.data_to_send() self.close_sent = True self.handshake_complete = True self.transport.writelines(output) self.transport.close() - + elif message_type == "websocket.http.response.start": + message = cast(WebSocketResponseStartEvent, message) + self.logger.info( + '%s - "WebSocket %s" %d', + self.scope["client"], + get_path_with_query_string(self.scope), + message["status"], + ) + headers = [ + (name.decode("latin-1"), value.decode("latin-1")) + for name, value in list(message.get("headers", [])) + ] + self.initial_response = (message["status"], headers, b"") else: msg = ( - "Expected ASGI message 'websocket.accept' or 'websocket.close', " + "Expected ASGI message 'websocket.accept', 'websocket.close' " + "or 'websocket.http.response.start' " "but got '%s'." ) raise RuntimeError(msg % message_type) - elif not self.close_sent: + elif not self.close_sent and self.initial_response is None: if message_type == "websocket.send" and not self.transport.is_closing(): message = cast(WebSocketSendEvent, message) bytes_data = message.get("bytes") @@ -372,6 +367,25 @@ async def send(self, message: ASGISendEvent) -> None: " but got '%s'." ) raise RuntimeError(msg % message_type) + elif self.initial_response is not None: + if message_type == "websocket.http.response.body": + message = cast(WebSocketResponseBodyEvent, message) + body = self.initial_response[2] + message["body"] + self.initial_response = self.initial_response[:2] + (body,) + if not message.get("more_body", False): + response = self.conn.reject(self.initial_response[0], body.decode()) + response.headers.update(self.initial_response[1]) + self.conn.send_response(response) + output = self.conn.data_to_send() + self.close_sent = True + self.transport.writelines(output) + self.transport.close() + else: + msg = ( + "Expected ASGI message 'websocket.http.response.body' " + "but got '%s'." + ) + raise RuntimeError(msg % message_type) else: msg = "Unexpected ASGI message '%s', after sending 'websocket.close'." diff --git a/uvicorn/protocols/websockets/wsproto_impl.py b/uvicorn/protocols/websockets/wsproto_impl.py index 1c17ee068..db7967f45 100644 --- a/uvicorn/protocols/websockets/wsproto_impl.py +++ b/uvicorn/protocols/websockets/wsproto_impl.py @@ -155,7 +155,7 @@ def shutdown(self) -> None: self.send_500_response() self.transport.close() - def on_task_complete(self, task: asyncio.Task) -> None: + def on_task_complete(self, task: asyncio.Task[None]) -> None: self.tasks.discard(task) # Event handlers