Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(core): add EventSource protocol #150

Merged
merged 8 commits into from
Nov 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 18 additions & 9 deletions lanarky/callbacks/base.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import json
from abc import abstractmethod
from typing import Any

from fastapi import WebSocket
from langchain.callbacks.base import AsyncCallbackHandler
from langchain.globals import get_llm_cache
from sse_starlette.sse import ServerSentEvent, ensure_bytes
from starlette.types import Message, Send

from lanarky.schemas import StreamingJSONResponse, WebsocketResponse
Expand Down Expand Up @@ -36,6 +36,11 @@ def _construct_message(self, content: Any) -> Any: # pragma: no cover
"""Constructs a Message from a string."""
pass

@abstractmethod
def _construct_chunk(self, content: Any) -> Any:
"""Constructs a message chunk"""
pass


class AsyncStreamingResponseCallback(AsyncLanarkyCallback):
"""Async Callback handler for StreamingResponse."""
Expand All @@ -47,12 +52,17 @@ def __init__(self, send: Send, **kwargs: Any) -> None:

def _construct_message(self, content: str) -> Message:
"""Constructs a Message from a string."""
chunk = self._construct_chunk(content)
return {
"type": "http.response.body",
"body": content.encode("utf-8"),
"body": ensure_bytes(chunk, None),
"more_body": True,
}

def _construct_chunk(self, content: str) -> ServerSentEvent:
"""Constructs a message chunk"""
return ServerSentEvent(data=content)


class AsyncWebsocketCallback(AsyncLanarkyCallback):
"""Async Callback handler for WebsocketConnection."""
Expand All @@ -75,14 +85,13 @@ class AsyncStreamingJSONResponseCallback(AsyncStreamingResponseCallback):

def _construct_message(self, content: StreamingJSONResponse) -> Message:
"""Constructs a Message from a dictionary."""
chunk = self._construct_chunk(content)
return {
"type": "http.response.body",
"body": json.dumps(
content.model_dump(),
ensure_ascii=False,
allow_nan=False,
indent=None,
separators=(",", ":"),
).encode("utf-8"),
"body": ensure_bytes(chunk, None),
"more_body": True,
}

def _construct_chunk(self, content: StreamingJSONResponse) -> ServerSentEvent:
"""Constructs a message chunk"""
return ServerSentEvent(data=content.model_dump())
Empty file added lanarky/py.typed
Empty file.
19 changes: 12 additions & 7 deletions lanarky/responses/streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
from functools import partial
from typing import Any, Awaitable, Callable, Optional, Union

from fastapi.responses import StreamingResponse as _StreamingResponse
from langchain.chains.base import Chain
from sse_starlette.sse import EventSourceResponse, ServerSentEvent, ensure_bytes
from starlette.background import BackgroundTask
from starlette.types import Receive, Send

Expand All @@ -16,7 +16,7 @@
logger = logging.getLogger(__name__)


class StreamingResponse(_StreamingResponse):
class StreamingResponse(EventSourceResponse):
"""StreamingResponse class wrapper for langchain chains."""

def __init__(
Expand Down Expand Up @@ -56,18 +56,23 @@ async def stream_response(self, send: Send) -> None:
try:
outputs = await self.chain_executor(send)
if self.background is not None:
self.background.kwargs["outputs"] = outputs
self.background.kwargs.update({"outputs": outputs})
except Exception as e:
logger.error(f"chain execution error: {e}")
if self.background is not None:
self.background.kwargs["outputs"] = str(e)
self.background.kwargs.update({"outputs": {}, "error": e})
# FIXME: use enum instead of hardcoding event name
chunk = ServerSentEvent(
data=dict(status_code=500, detail="Internal Server Error"),
event="error",
)
await send(
{
"type": "http.response.body",
"body": str(e).encode(self.charset),
"more_body": False,
"body": ensure_bytes(chunk, None),
"more_body": True,
}
)
return

await send({"type": "http.response.body", "body": b"", "more_body": False})

Expand Down
27 changes: 26 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

20 changes: 11 additions & 9 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@ packages = [{include = "lanarky"}]
[tool.poetry.dependencies]
python = "^3.9"
fastapi = ">=0.97.0"
langchain = ">=0.0.200"
pydantic = ">=1,<3"
sse-starlette = "^1.6.5"
langchain = ">=0.0.200"
redis = {version = "^4.5.5", optional = true}
gptcache = {version = "^0.1.31", optional = true}
openai = {version = "^1", optional = true}
Expand All @@ -23,26 +24,27 @@ tiktoken = {version = "^0.4.0", optional = true}
pre-commit = "^3.3.3"
uvicorn = {extras = ["standard"], version = "<1"}

[tool.poetry.group.docs.dependencies]
furo = "^2023.5.20"
sphinx-autobuild = "^2021.3.14"
myst-parser = "^1.0.0"
sphinx-copybutton = "^0.5.2"
autodoc-pydantic = "^1.8.0"
toml = "^0.10.2"

[tool.poetry.group.tests.dependencies]
pytest = "^7.3.2"
pytest-cov = "^4.1.0"
pytest-asyncio = "^0.21.0"
coveralls = "^3.3.1"
httpx = "^0.24.1"
httpx-sse = "^0.3.1"

[tool.poetry.extras]
openai = ["openai", "tiktoken"]
redis = ["redis"]
gptcache = ["gptcache"]

[tool.poetry.group.docs.dependencies]
furo = "^2023.5.20"
sphinx-autobuild = "^2021.3.14"
myst-parser = "^1.0.0"
sphinx-copybutton = "^0.5.2"
autodoc-pydantic = "^1.8.0"
toml = "^0.10.2"

[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"
Expand Down
6 changes: 5 additions & 1 deletion tests/responses/test_streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,11 @@ async def test_stream_response_error(

await streaming_response.stream_response(send)

assert background.kwargs["outputs"] == "Something went wrong"
assert background.kwargs["outputs"] == {}
assert (
isinstance(background.kwargs["error"], Exception)
and str(background.kwargs["error"]) == "Something went wrong"
)


@pytest.mark.asyncio
Expand Down