Skip to content

Commit

Permalink
feat: add streaming JSON response (#60)
Browse files Browse the repository at this point in the history
* ✨ add new register function

* ✨ add new response schemas

* ✨ add AsyncStreamingJSONResponseCallback

* ✨ add additional streaming json callback handlers

* ✨ add get_streaming_json_callback

* ✨ add StreamingJSONResponse

* ✅ add new unit tests

- ✅ update fixture scope

* 🔨 update examples

* 📝 update contributing

* ⚡ update register decorator to accept list of keys

* ♻️ refactor callbacks

💡 improve error message

* ⚡ add `as_json` parameter to StreamingResponse

🔥 remove StreamingJSONResponse

* 🔨 update examples

* 📝 update docs

* ✅ update unit test

* ⬆️ bump langchain version

* ✅ update conftest.py
  • Loading branch information
ajndkr authored May 23, 2023
1 parent 7d9cd02 commit 8a3aedc
Show file tree
Hide file tree
Showing 24 changed files with 362 additions and 245 deletions.
2 changes: 1 addition & 1 deletion CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,5 +27,5 @@ poetry install
run the following commands:

```bash
pre-commit install
poetry run pre-commit install
```
8 changes: 8 additions & 0 deletions docs/advanced/custom_callbacks.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,14 @@ Let's first create a custom chain called ``ConversationalRetrievalWithSourcesCha

.. code-block:: python
import re
from typing import Any, Dict, List, Optional
from langchain.chains import ConversationalRetrievalChain
from langchain.callbacks.manager import (
AsyncCallbackManagerForChainRun,
CallbackManagerForChainRun,
)
class ConversationalRetrievalWithSourcesChain(ConversationalRetrievalChain):
"""Chain for chatting with sources over documents."""
Expand Down
2 changes: 1 addition & 1 deletion docs/features.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
Langchain Support
-----------------

- Token streaming over HTTP and Websocket
- Supports output streaming over HTTP and Websocket
- Supports multiple Chains and Agents

Gradio Testing
Expand Down
10 changes: 5 additions & 5 deletions docs/lanarky/lanarky.callbacks.rst
Original file line number Diff line number Diff line change
Expand Up @@ -25,18 +25,18 @@ lanarky.callbacks.llm module
:undoc-members:
:show-inheritance:

lanarky.callbacks.qa\_with\_sources module
------------------------------------------
lanarky.callbacks.retrieval\_qa module
--------------------------------------

.. automodule:: lanarky.callbacks.qa_with_sources
.. automodule:: lanarky.callbacks.retrieval_qa
:members:
:undoc-members:
:show-inheritance:

lanarky.callbacks.retrieval\_qa module
lanarky.callbacks.agents module
--------------------------------------

.. automodule:: lanarky.callbacks.retrieval_qa
.. automodule:: lanarky.callbacks.agents
:members:
:undoc-members:
:show-inheritance:
10 changes: 10 additions & 0 deletions examples/app/conversation_chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,16 @@ async def chat(
)


@app.post("/chat_json")
async def chat_json(
    request: QueryRequest,
    chain: ConversationChain = Depends(conversation_chain),
) -> StreamingResponse:
    """Stream the conversation chain's output as server-sent JSON events."""
    response = StreamingResponse.from_chain(
        chain,
        request.query,
        as_json=True,
        media_type="text/event-stream",
    )
    return response


@app.get("/")
async def get(request: Request):
    """Serve the demo chat page."""
    context = {"request": request}
    return templates.TemplateResponse("index.html", context)
Expand Down
10 changes: 10 additions & 0 deletions examples/app/conversational_retrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,3 +80,13 @@ async def chat(
"chat_history": [(human, ai) for human, ai in request.history],
}
return StreamingResponse.from_chain(chain, inputs, media_type="text/event-stream")


@app.post("/chat_json")
async def chat_json(
    request: QueryRequest,
    chain: ConversationalRetrievalChain = Depends(conversational_retrieval_chain),
) -> StreamingResponse:
    """Stream the retrieval chain's output as server-sent JSON events."""
    response = StreamingResponse.from_chain(
        chain,
        request.query,
        as_json=True,
        media_type="text/event-stream",
    )
    return response
10 changes: 10 additions & 0 deletions examples/app/retrieval_qa_w_sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,16 @@ async def chat(
)


@app.post("/chat_json")
async def chat_json(
    request: QueryRequest,
    chain: RetrievalQAWithSourcesChain = Depends(retrieval_qa_chain),
) -> StreamingResponse:
    """Stream the QA-with-sources chain's output as server-sent JSON events."""
    response = StreamingResponse.from_chain(
        chain,
        request.query,
        as_json=True,
        media_type="text/event-stream",
    )
    return response


@app.get("/")
async def get(request: Request):
    """Serve the demo chat page."""
    context = {"request": request}
    return templates.TemplateResponse("index.html", context)
Expand Down
10 changes: 10 additions & 0 deletions examples/app/zero_shot_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,16 @@ async def chat(
)


@app.post("/chat_json")
async def chat_json(
    request: QueryRequest,
    agent: AgentExecutor = Depends(zero_shot_agent),
) -> StreamingResponse:
    """Stream the agent's final answer as server-sent JSON events."""
    response = StreamingResponse.from_chain(
        agent,
        request.query,
        as_json=True,
        media_type="text/event-stream",
    )
    return response


@app.get("/")
async def get(request: Request):
    """Serve the demo chat page."""
    context = {"request": request}
    return templates.TemplateResponse("index.html", context)
Expand Down
84 changes: 37 additions & 47 deletions lanarky/callbacks/__init__.py
Original file line number Diff line number Diff line change
@@ -1,54 +1,22 @@
from langchain.chains.base import Chain

from lanarky.register import STREAMING_CALLBACKS, WEBSOCKET_CALLBACKS

from .agents import AsyncAgentsStreamingCallback, AsyncAgentsWebsocketCallback
from .base import AsyncStreamingResponseCallback, AsyncWebsocketCallback
from .llm import (
AsyncConversationChainStreamingCallback,
AsyncConversationChainWebsocketCallback,
AsyncLLMChainStreamingCallback,
AsyncLLMChainWebsocketCallback,
)
from .qa_with_sources import (
AsyncConversationalRetrievalChainStreamingCallback,
AsyncConversationalRetrievalChainWebsocketCallback,
AsyncQAWithSourcesChainStreamingCallback,
AsyncQAWithSourcesChainWebsocketCallback,
AsyncRetrievalQAWithSourcesChainStreamingCallback,
AsyncRetrievalQAWithSourcesChainWebsocketCallback,
AsyncVectorDBQAWithSourcesChainStreamingCallback,
AsyncVectorDBQAWithSourcesChainWebsocketCallback,
)
from .retrieval_qa import (
AsyncRetrievalQAStreamingCallback,
AsyncRetrievalQAWebsocketCallback,
AsyncVectorDBQAStreamingCallback,
AsyncVectorDBQAWebsocketCallback,
from lanarky.register import (
STREAMING_CALLBACKS,
STREAMING_JSON_CALLBACKS,
WEBSOCKET_CALLBACKS,
)

__all__ = [
"AsyncLLMChainStreamingCallback",
"AsyncLLMChainWebsocketCallback",
"AsyncConversationChainStreamingCallback",
"AsyncConversationChainWebsocketCallback",
"AsyncRetrievalQAStreamingCallback",
"AsyncRetrievalQAWebsocketCallback",
"AsyncVectorDBQAStreamingCallback",
"AsyncVectorDBQAWebsocketCallback",
"AsyncQAWithSourcesChainStreamingCallback",
"AsyncQAWithSourcesChainWebsocketCallback",
"AsyncRetrievalQAWithSourcesChainStreamingCallback",
"AsyncRetrievalQAWithSourcesChainWebsocketCallback",
"AsyncVectorDBQAWithSourcesChainStreamingCallback",
"AsyncVectorDBQAWithSourcesChainWebsocketCallback",
"AsyncConversationalRetrievalChainStreamingCallback",
"AsyncConversationalRetrievalChainWebsocketCallback",
"AsyncAgentsStreamingCallback",
"AsyncAgentsWebsocketCallback",
]
from .agents import * # noqa: F401, F403
from .base import (
AsyncStreamingJSONResponseCallback,
AsyncStreamingResponseCallback,
AsyncWebsocketCallback,
)
from .llm import * # noqa: F401, F403
from .retrieval_qa import * # noqa: F401, F403

ERROR_MESSAGE = """Error! Chain type '{chain_type}' is not currently supported by '{callable_name}'.
Available chain types: {chain_types}
To use a custom chain type, you must register a new callback handler.
See the documentation for more details: https://lanarky.readthedocs.io/en/latest/advanced/custom_callbacks.html
Expand All @@ -66,7 +34,9 @@ def get_streaming_callback(
except KeyError:
raise KeyError(
ERROR_MESSAGE.format(
chain_type=chain_type, callable_name="AsyncStreamingResponseCallback"
chain_type=chain_type,
callable_name="AsyncStreamingResponseCallback",
chain_types="\n".join(list(STREAMING_CALLBACKS.keys())),
)
)

Expand All @@ -80,6 +50,26 @@ def get_websocket_callback(chain: Chain, *args, **kwargs) -> AsyncWebsocketCallb
except KeyError:
raise KeyError(
ERROR_MESSAGE.format(
chain_type=chain_type, callable_name="AsyncWebsocketCallback"
chain_type=chain_type,
callable_name="AsyncWebsocketCallback",
chain_types="\n".join(list(WEBSOCKET_CALLBACKS.keys())),
)
)


def get_streaming_json_callback(
    chain: Chain, *args, **kwargs
) -> AsyncStreamingJSONResponseCallback:
    """Get the streaming JSON callback for the given chain type.

    The chain's class name is used as the registry key.

    Args:
        chain: chain instance whose type selects the registered callback.
        *args: positional arguments forwarded to the callback constructor.
        **kwargs: keyword arguments forwarded to the callback constructor.

    Raises:
        KeyError: if no streaming JSON callback is registered for the
            chain type.
    """
    chain_type = chain.__class__.__name__
    try:
        # Keep only the registry lookup inside the try block: a KeyError
        # raised by the callback constructor itself must not be mistaken
        # for an unregistered chain type.
        callback = STREAMING_JSON_CALLBACKS[chain_type]
    except KeyError:
        raise KeyError(
            ERROR_MESSAGE.format(
                chain_type=chain_type,
                callable_name="AsyncStreamingJSONResponseCallback",
                # iterating a dict yields its keys; no list() needed
                chain_types="\n".join(STREAMING_JSON_CALLBACKS),
            )
        )
    return callback(*args, **kwargs)
34 changes: 28 additions & 6 deletions lanarky/callbacks/agents.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,18 @@
from typing import Any, Dict, List
from typing import Any

from lanarky.register import register_streaming_callback, register_websocket_callback
from lanarky.register import (
register_streaming_callback,
register_streaming_json_callback,
register_websocket_callback,
)
from lanarky.schemas import StreamingJSONResponse

from .base import AsyncLanarkyCallback
from .llm import AsyncLLMChainStreamingCallback, AsyncLLMChainWebsocketCallback
from .llm import (
AsyncLLMChainStreamingCallback,
AsyncLLMChainStreamingJSONCallback,
AsyncLLMChainWebsocketCallback,
)


class AsyncAgentsLanarkyCallback(AsyncLanarkyCallback):
Expand All @@ -12,12 +21,12 @@ class AsyncAgentsLanarkyCallback(AsyncLanarkyCallback):
Adapted from `langchain/callbacks/streaming_stdout_final_only.py <https://github.com/hwchase17/langchain/blob/master/langchain/callbacks/streaming_stdout_final_only.py>`_
"""

answer_prefix_tokens: List[str] = ["Final", " Answer", ":"]
last_tokens: List[str] = [""] * len(answer_prefix_tokens)
answer_prefix_tokens: list[str] = ["Final", " Answer", ":"]
last_tokens: list[str] = [""] * len(answer_prefix_tokens)
answer_reached: bool = False

async def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
self, serialized: dict[str, Any], prompts: list[str], **kwargs: Any
) -> None:
"""Run when LLM starts running."""
self.last_tokens = [""] * len(self.answer_prefix_tokens)
Expand Down Expand Up @@ -59,3 +68,16 @@ async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
if self._check_if_answer_reached(token):
message = self._construct_message(token)
await self.websocket.send_json(message)


@register_streaming_json_callback("AgentExecutor")
class AsyncAgentsStreamingJSONCallback(
    AsyncAgentsLanarkyCallback, AsyncLLMChainStreamingJSONCallback
):
    """AsyncStreamingJSONCallback handler for AgentExecutor."""

    async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        """Run on new LLM token. Only available when streaming is enabled."""
        # Suppress tokens until the agent's final answer prefix is seen.
        if not self._check_if_answer_reached(token):
            return
        payload = StreamingJSONResponse(token=token)
        await self.send(self._construct_message(payload))
37 changes: 29 additions & 8 deletions lanarky/callbacks/base.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
from abc import abstractmethod
from typing import Any

Expand All @@ -6,7 +7,7 @@
from pydantic import BaseModel, Field
from starlette.types import Message, Send

from lanarky.schemas import WebsocketResponse
from lanarky.schemas import StreamingJSONResponse, WebsocketResponse


class AsyncLanarkyCallback(AsyncCallbackHandler, BaseModel):
Expand All @@ -21,31 +22,51 @@ class Config:
arbitrary_types_allowed = True

@abstractmethod
def _construct_message(self, message: str) -> Any: # pragma: no cover
def _construct_message(self, content: Any) -> Any: # pragma: no cover
"""Construct a Message from a string."""
pass


class AsyncStreamingResponseCallback(AsyncLanarkyCallback):
"""Async Callback handler for FastAPI StreamingResponse."""
"""Async Callback handler for StreamingResponse."""

send: Send = Field(...)

def _construct_message(self, message_str: str) -> Message:
def _construct_message(self, content: str) -> Message:
"""Construct a Message from a string."""
return {
"type": "http.response.body",
"body": message_str.encode("utf-8"),
"body": content.encode("utf-8"),
"more_body": True,
}


class AsyncWebsocketCallback(AsyncLanarkyCallback):
"""Async Callback handler for FastAPI websocket connection."""
"""Async Callback handler for WebsocketConnection."""

websocket: WebSocket = Field(...)
response: WebsocketResponse = Field(...)

def _construct_message(self, message_str: str) -> dict:
def _construct_message(self, content: str) -> dict:
"""Construct a WebsocketResponse from a string."""
return {**self.response.dict(), **{"message": message_str.encode("utf-8")}}
return {**self.response.dict(), **{"message": content.encode("utf-8")}}


class AsyncStreamingJSONResponseCallback(AsyncStreamingResponseCallback):
    """Async Callback handler for StreamingJSONResponse.

    Serializes each response chunk as compact UTF-8 JSON before sending
    it over the ASGI ``send`` channel inherited from
    ``AsyncStreamingResponseCallback``.
    """

    # NOTE: the ``send: Send = Field(...)`` declaration that was here is
    # redundant — it is already declared on AsyncStreamingResponseCallback
    # and is inherited unchanged.

    def _construct_message(self, content: StreamingJSONResponse) -> Message:
        """Construct an HTTP body Message from a StreamingJSONResponse.

        The payload is serialized without whitespace (compact separators,
        no indent), with non-ASCII characters preserved and NaN/Infinity
        rejected; ``more_body=True`` keeps the response stream open for
        subsequent chunks.
        """
        body = json.dumps(
            content.dict(),
            ensure_ascii=False,
            allow_nan=False,
            indent=None,
            separators=(",", ":"),
        ).encode("utf-8")
        return {
            "type": "http.response.body",
            "body": body,
            "more_body": True,
        }
Loading

0 comments on commit 8a3aedc

Please sign in to comment.