Skip to content

Commit

Permalink
set streaming response headers
Browse files Browse the repository at this point in the history
  • Loading branch information
mjurbanski-reef committed Jul 24, 2024
1 parent 2d8839b commit d279f7f
Show file tree
Hide file tree
Showing 3 changed files with 99 additions and 14 deletions.
47 changes: 36 additions & 11 deletions bittensor/axon.py
Original file line number Diff line number Diff line change
Expand Up @@ -492,8 +492,23 @@ async def endpoint(*args, **kwargs):
return await self.middleware_cls.synapse_to_response(
synapse=response, start_time=start_time
)
else: # e.g. BTStreamingResponse
return response
else:
response_synapse = getattr(response, "synapse", None)
if response_synapse is None:
warnings.warn(
"The response synapse is None. The input synapse will be used as the response synapse. "
"Reliance on forward_fn modifying input synapse as a side-effects is deprecated. "
"Explicitly set `synapse` on response object instead.",
DeprecationWarning,
)
# Replace with `return response` in next major version
response_synapse = args[0]

return await self.middleware_cls.synapse_to_response(
synapse=response_synapse,
start_time=start_time,
response_override=response,
)

return_annotation = forward_sig.return_annotation

Expand Down Expand Up @@ -1439,14 +1454,21 @@ async def run(

@classmethod
async def synapse_to_response(
cls, synapse: bittensor.Synapse, start_time: float
) -> JSONResponse:
cls,
synapse: bittensor.Synapse,
start_time: float,
*,
response_override: Optional[Response] = None,
) -> Response:
"""
Converts the Synapse object into a JSON response with HTTP headers.
Args:
synapse (bittensor.Synapse): The Synapse object representing the request.
start_time (float): The timestamp when the request processing started.
synapse: The Synapse object representing the request.
start_time: The timestamp when the request processing started.
response_override:
Instead of serializing the synapse, mutate the provided response object.
This is only really useful for StreamingSynapse responses.
Returns:
Response: The final HTTP response, with updated headers, ready to be sent back to the client.
Expand All @@ -1465,11 +1487,14 @@ async def synapse_to_response(

synapse.axon.process_time = time.time() - start_time

serialized_synapse = await serialize_response(response_content=synapse)
response = JSONResponse(
status_code=synapse.axon.status_code,
content=serialized_synapse,
)
if response_override:
response = response_override
else:
serialized_synapse = await serialize_response(response_content=synapse)
response = JSONResponse(
status_code=synapse.axon.status_code,
content=serialized_synapse,
)

try:
updated_headers = synapse.to_headers()
Expand Down
14 changes: 12 additions & 2 deletions bittensor/stream.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import typing

from aiohttp import ClientResponse
import bittensor

Expand Down Expand Up @@ -49,16 +51,24 @@ class BTStreamingResponse(_StreamingResponse):
provided by the subclass.
"""

def __init__(self, model: BTStreamingResponseModel, **kwargs):
def __init__(
self,
model: BTStreamingResponseModel,
*,
synapse: typing.Optional["StreamingSynapse"] = None,
**kwargs,
):
"""
Initializes the BTStreamingResponse with the given token streamer model.
Args:
model: A BTStreamingResponseModel instance containing the token streamer callable, which is responsible for generating the content of the response.
synapse: The response Synapse to be used to update the response headers etc.
**kwargs: Additional keyword arguments passed to the parent StreamingResponse class.
"""
super().__init__(content=iter(()), **kwargs)
self.token_streamer = model.token_streamer
self.synapse = synapse

async def stream_response(self, send: Send):
"""
Expand Down Expand Up @@ -139,4 +149,4 @@ def create_streaming_response(
"""
model_instance = BTStreamingResponseModel(token_streamer=token_streamer)

return self.BTStreamingResponse(model_instance)
return self.BTStreamingResponse(model_instance, synapse=self)
52 changes: 51 additions & 1 deletion tests/unit_tests/test_axon.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,14 @@
import re
from dataclasses import dataclass

from typing import Any
from typing import Any, Optional
from unittest import IsolatedAsyncioTestCase
from unittest.mock import AsyncMock, MagicMock, patch

# Third Party
import fastapi
import netaddr
import pydantic
import pytest
from starlette.requests import Request
from fastapi.testclient import TestClient
Expand Down Expand Up @@ -532,6 +533,39 @@ def http_client(self, axon):
async def no_verify_fn(self, synapse):
return

class NonDeterministicHeaders(pydantic.BaseModel):
"""
Helper class to verify headers.
Size headers are non-determistic as for example, header_size depends on non-deterministic
processing-time value.
"""

bt_header_axon_process_time: float = pydantic.Field(gt=0, lt=30)
timeout: float = pydantic.Field(gt=0, lt=30)
header_size: int = pydantic.Field(None, gt=10, lt=400)
total_size: int = pydantic.Field(gt=100, lt=10000)
content_length: Optional[int] = pydantic.Field(
None, alias="content-length", gt=100, lt=10000
)

def assert_headers(self, response, expected_headers):
expected_headers = {
"bt_header_axon_status_code": "200",
"bt_header_axon_status_message": "Success",
**expected_headers,
}
headers = dict(response.headers)
non_deterministic_headers_names = {
field.alias or field_name
for field_name, field in self.NonDeterministicHeaders.model_fields.items()
}
non_deterministic_headers = {
field: headers.pop(field, None) for field in non_deterministic_headers_names
}
assert headers == expected_headers
self.NonDeterministicHeaders.model_validate(non_deterministic_headers)

async def test_unknown_path(self, http_client):
response = http_client.get("/no_such_path")
assert (response.status_code, response.json()) == (
Expand All @@ -557,6 +591,14 @@ async def test_ping__without_verification(self, http_client, axon):
assert response.status_code == 200
response_synapse = Synapse(**response.json())
assert response_synapse.axon.status_code == 200
self.assert_headers(
response,
{
"computed_body_hash": "a7ffc6f8bf1ed76651c14756a061d662f580ff4de43b49fa82d80a4b80f8434a",
"content-type": "application/json",
"name": "Synapse",
},
)

@pytest.fixture
def custom_synapse_cls(self):
Expand Down Expand Up @@ -664,3 +706,11 @@ async def forward_fn(synapse: streaming_synapse_cls):

response = http_client.post_synapse(streaming_synapse_cls())
assert (response.status_code, response.text) == (200, "".join(tokens))
self.assert_headers(
response,
{
"content-type": "text/event-stream",
"name": "CustomStreamingSynapse",
"computed_body_hash": "a7ffc6f8bf1ed76651c14756a061d662f580ff4de43b49fa82d80a4b80f8434a",
},
)

0 comments on commit d279f7f

Please sign in to comment.