Skip to content

Commit

Permalink
feat(core): add EventSource protocol (#150)
Browse files Browse the repository at this point in the history
* ➕ add sse-starlette dependency

➕ add httpx-sse dependency for unit tests

* 🔨 add py.typed

* ♻️ refactor StreamingResponse class

use sse_starlette.EventSourceResponse instead of fastapi.responses.StreamingResponse

send 500 error event when chain execution fails

* ⚡ update base callback handlers

use ServerSentEvent to construct message

* 🐛 fix ensure_bytes bugs

* ✅ update unit test

* 🐛 fix streaming_response bug

* 💡 update docstrings
  • Loading branch information
ajndkr authored Nov 16, 2023
1 parent 90bd4bf commit 3c67a51
Show file tree
Hide file tree
Showing 6 changed files with 72 additions and 27 deletions.
27 changes: 18 additions & 9 deletions lanarky/callbacks/base.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import json
from abc import abstractmethod
from typing import Any

from fastapi import WebSocket
from langchain.callbacks.base import AsyncCallbackHandler
from langchain.globals import get_llm_cache
from sse_starlette.sse import ServerSentEvent, ensure_bytes
from starlette.types import Message, Send

from lanarky.schemas import StreamingJSONResponse, WebsocketResponse
Expand Down Expand Up @@ -36,6 +36,11 @@ def _construct_message(self, content: Any) -> Any: # pragma: no cover
"""Constructs a Message from a string."""
pass

@abstractmethod
def _construct_chunk(self, content: Any) -> Any:
"""Constructs a message chunk"""
pass


class AsyncStreamingResponseCallback(AsyncLanarkyCallback):
"""Async Callback handler for StreamingResponse."""
Expand All @@ -47,12 +52,17 @@ def __init__(self, send: Send, **kwargs: Any) -> None:

def _construct_message(self, content: str) -> Message:
"""Constructs a Message from a string."""
chunk = self._construct_chunk(content)
return {
"type": "http.response.body",
"body": content.encode("utf-8"),
"body": ensure_bytes(chunk, None),
"more_body": True,
}

def _construct_chunk(self, content: str) -> ServerSentEvent:
"""Constructs a message chunk"""
return ServerSentEvent(data=content)


class AsyncWebsocketCallback(AsyncLanarkyCallback):
"""Async Callback handler for WebsocketConnection."""
Expand All @@ -75,14 +85,13 @@ class AsyncStreamingJSONResponseCallback(AsyncStreamingResponseCallback):

def _construct_message(self, content: StreamingJSONResponse) -> Message:
"""Constructs a Message from a dictionary."""
chunk = self._construct_chunk(content)
return {
"type": "http.response.body",
"body": json.dumps(
content.model_dump(),
ensure_ascii=False,
allow_nan=False,
indent=None,
separators=(",", ":"),
).encode("utf-8"),
"body": ensure_bytes(chunk, None),
"more_body": True,
}

def _construct_chunk(self, content: StreamingJSONResponse) -> ServerSentEvent:
"""Constructs a message chunk"""
return ServerSentEvent(data=content.model_dump())
Empty file added lanarky/py.typed
Empty file.
19 changes: 12 additions & 7 deletions lanarky/responses/streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
from functools import partial
from typing import Any, Awaitable, Callable, Optional, Union

from fastapi.responses import StreamingResponse as _StreamingResponse
from langchain.chains.base import Chain
from sse_starlette.sse import EventSourceResponse, ServerSentEvent, ensure_bytes
from starlette.background import BackgroundTask
from starlette.types import Receive, Send

Expand All @@ -16,7 +16,7 @@
logger = logging.getLogger(__name__)


class StreamingResponse(_StreamingResponse):
class StreamingResponse(EventSourceResponse):
"""StreamingResponse class wrapper for langchain chains."""

def __init__(
Expand Down Expand Up @@ -56,18 +56,23 @@ async def stream_response(self, send: Send) -> None:
try:
outputs = await self.chain_executor(send)
if self.background is not None:
self.background.kwargs["outputs"] = outputs
self.background.kwargs.update({"outputs": outputs})
except Exception as e:
logger.error(f"chain execution error: {e}")
if self.background is not None:
self.background.kwargs["outputs"] = str(e)
self.background.kwargs.update({"outputs": {}, "error": e})
# FIXME: use enum instead of hardcoding event name
chunk = ServerSentEvent(
data=dict(status_code=500, detail="Internal Server Error"),
event="error",
)
await send(
{
"type": "http.response.body",
"body": str(e).encode(self.charset),
"more_body": False,
"body": ensure_bytes(chunk, None),
"more_body": True,
}
)
return

await send({"type": "http.response.body", "body": b"", "more_body": False})

Expand Down
27 changes: 26 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

20 changes: 11 additions & 9 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@ packages = [{include = "lanarky"}]
[tool.poetry.dependencies]
python = "^3.9"
fastapi = ">=0.97.0"
langchain = ">=0.0.200"
pydantic = ">=1,<3"
sse-starlette = "^1.6.5"
langchain = ">=0.0.200"
redis = {version = "^4.5.5", optional = true}
gptcache = {version = "^0.1.31", optional = true}
openai = {version = "^1", optional = true}
Expand All @@ -23,26 +24,27 @@ tiktoken = {version = "^0.4.0", optional = true}
pre-commit = "^3.3.3"
uvicorn = {extras = ["standard"], version = "<1"}

[tool.poetry.group.docs.dependencies]
furo = "^2023.5.20"
sphinx-autobuild = "^2021.3.14"
myst-parser = "^1.0.0"
sphinx-copybutton = "^0.5.2"
autodoc-pydantic = "^1.8.0"
toml = "^0.10.2"

[tool.poetry.group.tests.dependencies]
pytest = "^7.3.2"
pytest-cov = "^4.1.0"
pytest-asyncio = "^0.21.0"
coveralls = "^3.3.1"
httpx = "^0.24.1"
httpx-sse = "^0.3.1"

[tool.poetry.extras]
openai = ["openai", "tiktoken"]
redis = ["redis"]
gptcache = ["gptcache"]

[tool.poetry.group.docs.dependencies]
furo = "^2023.5.20"
sphinx-autobuild = "^2021.3.14"
myst-parser = "^1.0.0"
sphinx-copybutton = "^0.5.2"
autodoc-pydantic = "^1.8.0"
toml = "^0.10.2"

[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"
Expand Down
6 changes: 5 additions & 1 deletion tests/responses/test_streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,11 @@ async def test_stream_response_error(

await streaming_response.stream_response(send)

assert background.kwargs["outputs"] == "Something went wrong"
assert background.kwargs["outputs"] == {}
assert (
isinstance(background.kwargs["error"], Exception)
and str(background.kwargs["error"]) == "Something went wrong"
)


@pytest.mark.asyncio
Expand Down

0 comments on commit 3c67a51

Please sign in to comment.