Skip to content

Commit

Permalink
Merge pull request opentensor#2159 from backend-developers-ltd/fix_st…
Browse files Browse the repository at this point in the history
…reaming_synapse

fix streaming synapse regression
  • Loading branch information
ibraheem-opentensor authored Jul 25, 2024
2 parents efa6239 + d279f7f commit 7fc1e4b
Show file tree
Hide file tree
Showing 3 changed files with 177 additions and 22 deletions.
78 changes: 61 additions & 17 deletions bittensor/axon.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import traceback
import typing
import uuid
import warnings
from inspect import signature, Signature, Parameter
from typing import List, Optional, Tuple, Callable, Any, Dict, Awaitable

Expand Down Expand Up @@ -484,17 +485,50 @@ def verify_custom(synapse: MyCustomSynapse):

async def endpoint(*args, **kwargs):
start_time = time.time()
response_synapse = forward_fn(*args, **kwargs)
if isinstance(response_synapse, Awaitable):
response_synapse = await response_synapse
return await self.middleware_cls.synapse_to_response(
synapse=response_synapse, start_time=start_time
)
response = forward_fn(*args, **kwargs)
if isinstance(response, Awaitable):
response = await response
if isinstance(response, bittensor.Synapse):
return await self.middleware_cls.synapse_to_response(
synapse=response, start_time=start_time
)
else:
response_synapse = getattr(response, "synapse", None)
if response_synapse is None:
warnings.warn(
"The response synapse is None. The input synapse will be used as the response synapse. "
"Reliance on forward_fn modifying input synapse as a side-effects is deprecated. "
"Explicitly set `synapse` on response object instead.",
DeprecationWarning,
)
# Replace with `return response` in next major version
response_synapse = args[0]

return await self.middleware_cls.synapse_to_response(
synapse=response_synapse,
start_time=start_time,
response_override=response,
)

return_annotation = forward_sig.return_annotation

if isinstance(return_annotation, type) and issubclass(
return_annotation, bittensor.Synapse
):
if issubclass(
return_annotation,
bittensor.StreamingSynapse,
):
warnings.warn(
"The forward_fn return annotation is a subclass of bittensor.StreamingSynapse. "
"Most likely the correct return annotation would be BTStreamingResponse."
)
else:
return_annotation = JSONResponse

# replace the endpoint signature, but set return annotation to JSONResponse
endpoint.__signature__ = Signature( # type: ignore
parameters=list(forward_sig.parameters.values()),
return_annotation=JSONResponse,
return_annotation=return_annotation,
)

# Add the endpoint to the router, making it available on both GET and POST methods
Expand Down Expand Up @@ -1420,14 +1454,21 @@ async def run(

@classmethod
async def synapse_to_response(
cls, synapse: bittensor.Synapse, start_time: float
) -> JSONResponse:
cls,
synapse: bittensor.Synapse,
start_time: float,
*,
response_override: Optional[Response] = None,
) -> Response:
"""
Converts the Synapse object into a JSON response with HTTP headers.
Args:
synapse (bittensor.Synapse): The Synapse object representing the request.
start_time (float): The timestamp when the request processing started.
synapse: The Synapse object representing the request.
start_time: The timestamp when the request processing started.
response_override:
Instead of serializing the synapse, mutate the provided response object.
This is only really useful for StreamingSynapse responses.
Returns:
Response: The final HTTP response, with updated headers, ready to be sent back to the client.
Expand All @@ -1446,11 +1487,14 @@ async def synapse_to_response(

synapse.axon.process_time = time.time() - start_time

serialized_synapse = await serialize_response(response_content=synapse)
response = JSONResponse(
status_code=synapse.axon.status_code,
content=serialized_synapse,
)
if response_override:
response = response_override
else:
serialized_synapse = await serialize_response(response_content=synapse)
response = JSONResponse(
status_code=synapse.axon.status_code,
content=serialized_synapse,
)

try:
updated_headers = synapse.to_headers()
Expand Down
14 changes: 12 additions & 2 deletions bittensor/stream.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import typing

from aiohttp import ClientResponse
import bittensor

Expand Down Expand Up @@ -49,16 +51,24 @@ class BTStreamingResponse(_StreamingResponse):
provided by the subclass.
"""

def __init__(self, model: BTStreamingResponseModel, **kwargs):
def __init__(
self,
model: BTStreamingResponseModel,
*,
synapse: typing.Optional["StreamingSynapse"] = None,
**kwargs,
):
"""
Initializes the BTStreamingResponse with the given token streamer model.
Args:
model: A BTStreamingResponseModel instance containing the token streamer callable, which is responsible for generating the content of the response.
synapse: The response Synapse to be used to update the response headers etc.
**kwargs: Additional keyword arguments passed to the parent StreamingResponse class.
"""
super().__init__(content=iter(()), **kwargs)
self.token_streamer = model.token_streamer
self.synapse = synapse

async def stream_response(self, send: Send):
"""
Expand Down Expand Up @@ -139,4 +149,4 @@ def create_streaming_response(
"""
model_instance = BTStreamingResponseModel(token_streamer=token_streamer)

return self.BTStreamingResponse(model_instance)
return self.BTStreamingResponse(model_instance, synapse=self)
107 changes: 104 additions & 3 deletions tests/unit_tests/test_axon.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,20 +22,21 @@
import re
from dataclasses import dataclass

from typing import Any
from typing import Any, Optional
from unittest import IsolatedAsyncioTestCase
from unittest.mock import AsyncMock, MagicMock, patch

# Third Party
import fastapi
import netaddr

import pydantic
import pytest
from starlette.requests import Request
from fastapi.testclient import TestClient

# Bittensor
import bittensor
from bittensor import Synapse, RunException
from bittensor import Synapse, RunException, StreamingSynapse
from bittensor.axon import AxonMiddleware
from bittensor.axon import axon as Axon

Expand Down Expand Up @@ -532,6 +533,39 @@ def http_client(self, axon):
async def no_verify_fn(self, synapse):
return

class NonDeterministicHeaders(pydantic.BaseModel):
"""
Helper class to verify headers.
Size headers are non-determistic as for example, header_size depends on non-deterministic
processing-time value.
"""

bt_header_axon_process_time: float = pydantic.Field(gt=0, lt=30)
timeout: float = pydantic.Field(gt=0, lt=30)
header_size: int = pydantic.Field(None, gt=10, lt=400)
total_size: int = pydantic.Field(gt=100, lt=10000)
content_length: Optional[int] = pydantic.Field(
None, alias="content-length", gt=100, lt=10000
)

def assert_headers(self, response, expected_headers):
expected_headers = {
"bt_header_axon_status_code": "200",
"bt_header_axon_status_message": "Success",
**expected_headers,
}
headers = dict(response.headers)
non_deterministic_headers_names = {
field.alias or field_name
for field_name, field in self.NonDeterministicHeaders.model_fields.items()
}
non_deterministic_headers = {
field: headers.pop(field, None) for field in non_deterministic_headers_names
}
assert headers == expected_headers
self.NonDeterministicHeaders.model_validate(non_deterministic_headers)

async def test_unknown_path(self, http_client):
response = http_client.get("/no_such_path")
assert (response.status_code, response.json()) == (
Expand All @@ -557,6 +591,14 @@ async def test_ping__without_verification(self, http_client, axon):
assert response.status_code == 200
response_synapse = Synapse(**response.json())
assert response_synapse.axon.status_code == 200
self.assert_headers(
response,
{
"computed_body_hash": "a7ffc6f8bf1ed76651c14756a061d662f580ff4de43b49fa82d80a4b80f8434a",
"content-type": "application/json",
"name": "Synapse",
},
)

@pytest.fixture
def custom_synapse_cls(self):
Expand All @@ -565,6 +607,17 @@ class CustomSynapse(Synapse):

return CustomSynapse

@pytest.fixture
def streaming_synapse_cls(self):
class CustomStreamingSynapse(StreamingSynapse):
async def process_streaming_response(self, response):
pass

def extract_response_json(self, response) -> dict:
return {}

return CustomStreamingSynapse

async def test_synapse__explicitly_set_status_code(
self, http_client, axon, custom_synapse_cls, no_verify_axon
):
Expand Down Expand Up @@ -613,3 +666,51 @@ async def forward_fn(synapse: custom_synapse_cls):
response_data = response.json()
assert sorted(response_data.keys()) == ["message"]
assert re.match(r"Internal Server Error #[\da-f\-]+", response_data["message"])

@pytest.mark.parametrize(
"forward_fn_return_annotation",
[
None,
fastapi.Response,
bittensor.StreamingSynapse,
],
)
async def test_streaming_synapse(
self,
http_client,
axon,
streaming_synapse_cls,
no_verify_axon,
forward_fn_return_annotation,
):
tokens = [f"data{i}\n" for i in range(10)]

async def streamer(send):
for token in tokens:
await send(
{
"type": "http.response.body",
"body": token.encode(),
"more_body": True,
}
)
await send({"type": "http.response.body", "body": b"", "more_body": False})

async def forward_fn(synapse: streaming_synapse_cls):
return synapse.create_streaming_response(token_streamer=streamer)

if forward_fn_return_annotation is not None:
forward_fn.__annotations__["return"] = forward_fn_return_annotation

axon.attach(forward_fn)

response = http_client.post_synapse(streaming_synapse_cls())
assert (response.status_code, response.text) == (200, "".join(tokens))
self.assert_headers(
response,
{
"content-type": "text/event-stream",
"name": "CustomStreamingSynapse",
"computed_body_hash": "a7ffc6f8bf1ed76651c14756a061d662f580ff4de43b49fa82d80a4b80f8434a",
},
)

0 comments on commit 7fc1e4b

Please sign in to comment.