From cb2fbd0a33bc5714cf9cec36cf56357479100d1c Mon Sep 17 00:00:00 2001
From: ibraheem-opentensor <165814940+ibraheem-opentensor@users.noreply.github.com>
Date: Thu, 25 Jul 2024 11:12:40 -0700
Subject: [PATCH] Revert "fix streaming synapse regression"

---
 bittensor/axon.py             |  78 ++++++-------------------
 bittensor/stream.py           |  14 +----
 tests/unit_tests/test_axon.py | 107 +---------------------------------
 3 files changed, 22 insertions(+), 177 deletions(-)

diff --git a/bittensor/axon.py b/bittensor/axon.py
index 551be165bb..ca06335307 100644
--- a/bittensor/axon.py
+++ b/bittensor/axon.py
@@ -31,7 +31,6 @@
 import traceback
 import typing
 import uuid
-import warnings
 from inspect import signature, Signature, Parameter
 from typing import List, Optional, Tuple, Callable, Any, Dict, Awaitable
 
@@ -485,50 +484,17 @@ def verify_custom(synapse: MyCustomSynapse):
 
         async def endpoint(*args, **kwargs):
             start_time = time.time()
-            response = forward_fn(*args, **kwargs)
-            if isinstance(response, Awaitable):
-                response = await response
-            if isinstance(response, bittensor.Synapse):
-                return await self.middleware_cls.synapse_to_response(
-                    synapse=response, start_time=start_time
-                )
-            else:
-                response_synapse = getattr(response, "synapse", None)
-                if response_synapse is None:
-                    warnings.warn(
-                        "The response synapse is None. The input synapse will be used as the response synapse. "
-                        "Reliance on forward_fn modifying input synapse as a side-effects is deprecated. "
-                        "Explicitly set `synapse` on response object instead.",
-                        DeprecationWarning,
-                    )
-                    # Replace with `return response` in next major version
-                    response_synapse = args[0]
-
-                return await self.middleware_cls.synapse_to_response(
-                    synapse=response_synapse,
-                    start_time=start_time,
-                    response_override=response,
-                )
-
-        return_annotation = forward_sig.return_annotation
-
-        if isinstance(return_annotation, type) and issubclass(
-            return_annotation, bittensor.Synapse
-        ):
-            if issubclass(
-                return_annotation,
-                bittensor.StreamingSynapse,
-            ):
-                warnings.warn(
-                    "The forward_fn return annotation is a subclass of bittensor.StreamingSynapse. "
-                    "Most likely the correct return annotation would be BTStreamingResponse."
-                )
-            else:
-                return_annotation = JSONResponse
+            response_synapse = forward_fn(*args, **kwargs)
+            if isinstance(response_synapse, Awaitable):
+                response_synapse = await response_synapse
+            return await self.middleware_cls.synapse_to_response(
+                synapse=response_synapse, start_time=start_time
+            )
 
+        # replace the endpoint signature, but set return annotation to JSONResponse
         endpoint.__signature__ = Signature(  # type: ignore
             parameters=list(forward_sig.parameters.values()),
-            return_annotation=return_annotation,
+            return_annotation=JSONResponse,
         )
 
         # Add the endpoint to the router, making it available on both GET and POST methods
@@ -1454,21 +1420,14 @@ async def run(
 
     @classmethod
     async def synapse_to_response(
-        cls,
-        synapse: bittensor.Synapse,
-        start_time: float,
-        *,
-        response_override: Optional[Response] = None,
-    ) -> Response:
+        cls, synapse: bittensor.Synapse, start_time: float
+    ) -> JSONResponse:
         """
         Converts the Synapse object into a JSON response with HTTP headers.
 
         Args:
-            synapse: The Synapse object representing the request.
-            start_time: The timestamp when the request processing started.
-            response_override:
-                Instead of serializing the synapse, mutate the provided response object.
-                This is only really useful for StreamingSynapse responses.
+            synapse (bittensor.Synapse): The Synapse object representing the request.
+            start_time (float): The timestamp when the request processing started.
 
         Returns:
             Response: The final HTTP response, with updated headers, ready to be sent back to the client.
@@ -1487,14 +1446,11 @@ async def synapse_to_response(
 
         synapse.axon.process_time = time.time() - start_time
 
-        if response_override:
-            response = response_override
-        else:
-            serialized_synapse = await serialize_response(response_content=synapse)
-            response = JSONResponse(
-                status_code=synapse.axon.status_code,
-                content=serialized_synapse,
-            )
+        serialized_synapse = await serialize_response(response_content=synapse)
+        response = JSONResponse(
+            status_code=synapse.axon.status_code,
+            content=serialized_synapse,
+        )
 
         try:
             updated_headers = synapse.to_headers()
diff --git a/bittensor/stream.py b/bittensor/stream.py
index 3a82edc15a..e0dc17c42c 100644
--- a/bittensor/stream.py
+++ b/bittensor/stream.py
@@ -1,5 +1,3 @@
-import typing
-
 from aiohttp import ClientResponse
 
 import bittensor
@@ -51,24 +49,16 @@ class BTStreamingResponse(_StreamingResponse):
         provided by the subclass.
         """
 
-        def __init__(
-            self,
-            model: BTStreamingResponseModel,
-            *,
-            synapse: typing.Optional["StreamingSynapse"] = None,
-            **kwargs,
-        ):
+        def __init__(self, model: BTStreamingResponseModel, **kwargs):
             """
             Initializes the BTStreamingResponse with the given token streamer model.
 
             Args:
                 model: A BTStreamingResponseModel instance containing the token streamer callable, which is responsible for generating the content of the response.
-                synapse: The response Synapse to be used to update the response headers etc.
                 **kwargs: Additional keyword arguments passed to the parent StreamingResponse class.
             """
             super().__init__(content=iter(()), **kwargs)
             self.token_streamer = model.token_streamer
-            self.synapse = synapse
 
         async def stream_response(self, send: Send):
             """
@@ -149,4 +139,4 @@ def create_streaming_response(
         """
 
         model_instance = BTStreamingResponseModel(token_streamer=token_streamer)
-        return self.BTStreamingResponse(model_instance, synapse=self)
+        return self.BTStreamingResponse(model_instance)
diff --git a/tests/unit_tests/test_axon.py b/tests/unit_tests/test_axon.py
index 050b9ae915..cfb46c32c2 100644
--- a/tests/unit_tests/test_axon.py
+++ b/tests/unit_tests/test_axon.py
@@ -22,21 +22,20 @@
 
 import re
 from dataclasses import dataclass
-from typing import Any, Optional
+from typing import Any
 from unittest import IsolatedAsyncioTestCase
 from unittest.mock import AsyncMock, MagicMock, patch
 
 # Third Party
-import fastapi
 import netaddr
-import pydantic
+
 import pytest
 from starlette.requests import Request
 from fastapi.testclient import TestClient
 
 # Bittensor
 import bittensor
-from bittensor import Synapse, RunException, StreamingSynapse
+from bittensor import Synapse, RunException
 from bittensor.axon import AxonMiddleware
 from bittensor.axon import axon as Axon
 
@@ -533,39 +532,6 @@ def http_client(self, axon):
     async def no_verify_fn(self, synapse):
         return
 
-    class NonDeterministicHeaders(pydantic.BaseModel):
-        """
-        Helper class to verify headers.
-
-        Size headers are non-determistic as for example, header_size depends on non-deterministic
-        processing-time value.
-        """
-
-        bt_header_axon_process_time: float = pydantic.Field(gt=0, lt=30)
-        timeout: float = pydantic.Field(gt=0, lt=30)
-        header_size: int = pydantic.Field(None, gt=10, lt=400)
-        total_size: int = pydantic.Field(gt=100, lt=10000)
-        content_length: Optional[int] = pydantic.Field(
-            None, alias="content-length", gt=100, lt=10000
-        )
-
-    def assert_headers(self, response, expected_headers):
-        expected_headers = {
-            "bt_header_axon_status_code": "200",
-            "bt_header_axon_status_message": "Success",
-            **expected_headers,
-        }
-        headers = dict(response.headers)
-        non_deterministic_headers_names = {
-            field.alias or field_name
-            for field_name, field in self.NonDeterministicHeaders.model_fields.items()
-        }
-        non_deterministic_headers = {
-            field: headers.pop(field, None) for field in non_deterministic_headers_names
-        }
-        assert headers == expected_headers
-        self.NonDeterministicHeaders.model_validate(non_deterministic_headers)
-
     async def test_unknown_path(self, http_client):
         response = http_client.get("/no_such_path")
         assert (response.status_code, response.json()) == (
@@ -591,14 +557,6 @@ async def test_ping__without_verification(self, http_client, axon):
         assert response.status_code == 200
         response_synapse = Synapse(**response.json())
         assert response_synapse.axon.status_code == 200
-        self.assert_headers(
-            response,
-            {
-                "computed_body_hash": "a7ffc6f8bf1ed76651c14756a061d662f580ff4de43b49fa82d80a4b80f8434a",
-                "content-type": "application/json",
-                "name": "Synapse",
-            },
-        )
 
     @pytest.fixture
     def custom_synapse_cls(self):
@@ -607,17 +565,6 @@ class CustomSynapse(Synapse):
 
         return CustomSynapse
 
-    @pytest.fixture
-    def streaming_synapse_cls(self):
-        class CustomStreamingSynapse(StreamingSynapse):
-            async def process_streaming_response(self, response):
-                pass
-
-            def extract_response_json(self, response) -> dict:
-                return {}
-
-        return CustomStreamingSynapse
-
     async def test_synapse__explicitly_set_status_code(
         self, http_client, axon, custom_synapse_cls, no_verify_axon
     ):
@@ -666,51 +613,3 @@ async def forward_fn(synapse: custom_synapse_cls):
         response_data = response.json()
         assert sorted(response_data.keys()) == ["message"]
         assert re.match(r"Internal Server Error #[\da-f\-]+", response_data["message"])
-
-    @pytest.mark.parametrize(
-        "forward_fn_return_annotation",
-        [
-            None,
-            fastapi.Response,
-            bittensor.StreamingSynapse,
-        ],
-    )
-    async def test_streaming_synapse(
-        self,
-        http_client,
-        axon,
-        streaming_synapse_cls,
-        no_verify_axon,
-        forward_fn_return_annotation,
-    ):
-        tokens = [f"data{i}\n" for i in range(10)]
-
-        async def streamer(send):
-            for token in tokens:
-                await send(
-                    {
-                        "type": "http.response.body",
-                        "body": token.encode(),
-                        "more_body": True,
-                    }
-                )
-            await send({"type": "http.response.body", "body": b"", "more_body": False})
-
-        async def forward_fn(synapse: streaming_synapse_cls):
-            return synapse.create_streaming_response(token_streamer=streamer)
-
-        if forward_fn_return_annotation is not None:
-            forward_fn.__annotations__["return"] = forward_fn_return_annotation
-
-        axon.attach(forward_fn)
-
-        response = http_client.post_synapse(streaming_synapse_cls())
-        assert (response.status_code, response.text) == (200, "".join(tokens))
-        self.assert_headers(
-            response,
-            {
-                "content-type": "text/event-stream",
-                "name": "CustomStreamingSynapse",
-                "computed_body_hash": "a7ffc6f8bf1ed76651c14756a061d662f580ff4de43b49fa82d80a4b80f8434a",
-            },
-        )