diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index c6673f2..3a68312 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -27,5 +27,5 @@ poetry install run the following commands: ```bash -pre-commit install +poetry run pre-commit install ``` diff --git a/docs/advanced/custom_callbacks.rst b/docs/advanced/custom_callbacks.rst index 005c892..65167bf 100644 --- a/docs/advanced/custom_callbacks.rst +++ b/docs/advanced/custom_callbacks.rst @@ -10,6 +10,14 @@ Let's first create a custom chain called ``ConversationalRetrievalWithSourcesCha .. code-block:: python + import re + from typing import Any, Dict, List, Optional + from langchain.chains import ConversationalRetrievalChain + from langchain.callbacks.manager import ( + AsyncCallbackManagerForChainRun, + CallbackManagerForChainRun, + ) + class ConversationalRetrievalWithSourcesChain(ConversationalRetrievalChain): """Chain for chatting with sources over documents.""" diff --git a/docs/features.rst b/docs/features.rst index f1e20ea..19278d1 100644 --- a/docs/features.rst +++ b/docs/features.rst @@ -4,7 +4,7 @@ Langchain Support ----------------- -- Token streaming over HTTP and Websocket +- Supports output streaming over HTTP and Websocket - Supports multiple Chains and Agents Gradio Testing diff --git a/docs/lanarky/lanarky.callbacks.rst b/docs/lanarky/lanarky.callbacks.rst index 8f9c9ba..06ee90a 100644 --- a/docs/lanarky/lanarky.callbacks.rst +++ b/docs/lanarky/lanarky.callbacks.rst @@ -25,18 +25,18 @@ lanarky.callbacks.llm module :undoc-members: :show-inheritance: -lanarky.callbacks.qa\_with\_sources module ------------------------------------------- +lanarky.callbacks.retrieval\_qa module +-------------------------------------- -.. automodule:: lanarky.callbacks.qa_with_sources +.. automodule:: lanarky.callbacks.retrieval_qa :members: :undoc-members: :show-inheritance: -lanarky.callbacks.retrieval\_qa module +lanarky.callbacks.agents module -------------------------------------- -.. automodule:: lanarky.callbacks.retrieval_qa +.. automodule:: lanarky.callbacks.agents :members: :undoc-members: :show-inheritance: diff --git a/examples/app/conversation_chain.py b/examples/app/conversation_chain.py index 8fad391..ea2756f 100644 --- a/examples/app/conversation_chain.py +++ b/examples/app/conversation_chain.py @@ -49,6 +49,16 @@ async def chat( ) +@app.post("/chat_json") +async def chat_json( + request: QueryRequest, + chain: ConversationChain = Depends(conversation_chain), +) -> StreamingResponse: + return StreamingResponse.from_chain( + chain, request.query, as_json=True, media_type="text/event-stream" + ) + + @app.get("/") async def get(request: Request): return templates.TemplateResponse("index.html", {"request": request}) diff --git a/examples/app/conversational_retrieval.py b/examples/app/conversational_retrieval.py index b367ef9..d6eff4e 100644 --- a/examples/app/conversational_retrieval.py +++ b/examples/app/conversational_retrieval.py @@ -80,3 +80,13 @@ async def chat( "chat_history": [(human, ai) for human, ai in request.history], } return StreamingResponse.from_chain(chain, inputs, media_type="text/event-stream") + + +@app.post("/chat_json") +async def chat_json( + request: QueryRequest, + chain: ConversationalRetrievalChain = Depends(conversational_retrieval_chain), +) -> StreamingResponse: + return StreamingResponse.from_chain( + chain, request.query, as_json=True, media_type="text/event-stream" + ) diff --git a/examples/app/retrieval_qa_w_sources.py b/examples/app/retrieval_qa_w_sources.py index 8611142..e8583d9 100644 --- a/examples/app/retrieval_qa_w_sources.py +++ b/examples/app/retrieval_qa_w_sources.py @@ -62,6 +62,16 @@ async def chat( ) +@app.post("/chat_json") +async def chat_json( + request: QueryRequest, + chain: RetrievalQAWithSourcesChain = Depends(retrieval_qa_chain), +) -> StreamingResponse: + return StreamingResponse.from_chain( + chain, request.query, as_json=True, media_type="text/event-stream" + ) + + @app.get("/") async def get(request: Request): return templates.TemplateResponse("index.html", {"request": request}) diff --git a/examples/app/zero_shot_agent.py b/examples/app/zero_shot_agent.py index 527f31f..bff581a 100644 --- a/examples/app/zero_shot_agent.py +++ b/examples/app/zero_shot_agent.py @@ -51,6 +51,16 @@ async def chat( ) +@app.post("/chat_json") +async def chat_json( + request: QueryRequest, + agent: AgentExecutor = Depends(zero_shot_agent), +) -> StreamingResponse: + return StreamingResponse.from_chain( + agent, request.query, as_json=True, media_type="text/event-stream" + ) + + @app.get("/") async def get(request: Request): return templates.TemplateResponse("index.html", {"request": request}) diff --git a/lanarky/callbacks/__init__.py b/lanarky/callbacks/__init__.py index 7323d6a..d0f84d1 100644 --- a/lanarky/callbacks/__init__.py +++ b/lanarky/callbacks/__init__.py @@ -1,54 +1,22 @@ from langchain.chains.base import Chain -from lanarky.register import STREAMING_CALLBACKS, WEBSOCKET_CALLBACKS - -from .agents import AsyncAgentsStreamingCallback, AsyncAgentsWebsocketCallback -from .base import AsyncStreamingResponseCallback, AsyncWebsocketCallback -from .llm import ( - AsyncConversationChainStreamingCallback, - AsyncConversationChainWebsocketCallback, - AsyncLLMChainStreamingCallback, - AsyncLLMChainWebsocketCallback, -) -from .qa_with_sources import ( - AsyncConversationalRetrievalChainStreamingCallback, - AsyncConversationalRetrievalChainWebsocketCallback, - AsyncQAWithSourcesChainStreamingCallback, - AsyncQAWithSourcesChainWebsocketCallback, - AsyncRetrievalQAWithSourcesChainStreamingCallback, - AsyncRetrievalQAWithSourcesChainWebsocketCallback, - AsyncVectorDBQAWithSourcesChainStreamingCallback, - AsyncVectorDBQAWithSourcesChainWebsocketCallback, -) -from .retrieval_qa import ( - AsyncRetrievalQAStreamingCallback, - AsyncRetrievalQAWebsocketCallback, - AsyncVectorDBQAStreamingCallback, - AsyncVectorDBQAWebsocketCallback, +from lanarky.register import ( + STREAMING_CALLBACKS, + STREAMING_JSON_CALLBACKS, + WEBSOCKET_CALLBACKS, ) -__all__ = [ - "AsyncLLMChainStreamingCallback", - "AsyncLLMChainWebsocketCallback", - "AsyncConversationChainStreamingCallback", - "AsyncConversationChainWebsocketCallback", - "AsyncRetrievalQAStreamingCallback", - "AsyncRetrievalQAWebsocketCallback", - "AsyncVectorDBQAStreamingCallback", - "AsyncVectorDBQAWebsocketCallback", - "AsyncQAWithSourcesChainStreamingCallback", - "AsyncQAWithSourcesChainWebsocketCallback", - "AsyncRetrievalQAWithSourcesChainStreamingCallback", - "AsyncRetrievalQAWithSourcesChainWebsocketCallback", - "AsyncVectorDBQAWithSourcesChainStreamingCallback", - "AsyncVectorDBQAWithSourcesChainWebsocketCallback", - "AsyncConversationalRetrievalChainStreamingCallback", - "AsyncConversationalRetrievalChainWebsocketCallback", - "AsyncAgentsStreamingCallback", - "AsyncAgentsWebsocketCallback", -] +from .agents import * # noqa: F401, F403 +from .base import ( + AsyncStreamingJSONResponseCallback, + AsyncStreamingResponseCallback, + AsyncWebsocketCallback, +) +from .llm import * # noqa: F401, F403 +from .retrieval_qa import * # noqa: F401, F403 ERROR_MESSAGE = """Error! Chain type '{chain_type}' is not currently supported by '{callable_name}'. +Available chain types: {chain_types} To use a custom chain type, you must register a new callback handler. See the documentation for more details: https://lanarky.readthedocs.io/en/latest/advanced/custom_callbacks.html @@ -66,7 +34,9 @@ def get_streaming_callback( except KeyError: raise KeyError( ERROR_MESSAGE.format( - chain_type=chain_type, callable_name="AsyncStreamingResponseCallback" + chain_type=chain_type, + callable_name="AsyncStreamingResponseCallback", + chain_types="\n".join(list(STREAMING_CALLBACKS.keys())), ) ) @@ -80,6 +50,26 @@ def get_websocket_callback(chain: Chain, *args, **kwargs) -> AsyncWebsocketCallb except KeyError: raise KeyError( ERROR_MESSAGE.format( - chain_type=chain_type, callable_name="AsyncWebsocketCallback" + chain_type=chain_type, + callable_name="AsyncWebsocketCallback", + chain_types="\n".join(list(WEBSOCKET_CALLBACKS.keys())), + ) + ) + + +def get_streaming_json_callback( + chain: Chain, *args, **kwargs +) -> AsyncStreamingJSONResponseCallback: + """Get the streaming JSON callback for the given chain type.""" + chain_type = chain.__class__.__name__ + try: + callback = STREAMING_JSON_CALLBACKS[chain_type] + return callback(*args, **kwargs) + except KeyError: + raise KeyError( + ERROR_MESSAGE.format( + chain_type=chain_type, + callable_name="AsyncStreamingJSONResponseCallback", + chain_types="\n".join(list(STREAMING_JSON_CALLBACKS.keys())), ) ) diff --git a/lanarky/callbacks/agents.py b/lanarky/callbacks/agents.py index da93cb6..337648f 100644 --- a/lanarky/callbacks/agents.py +++ b/lanarky/callbacks/agents.py @@ -1,9 +1,18 @@ -from typing import Any, Dict, List +from typing import Any -from lanarky.register import register_streaming_callback, register_websocket_callback +from lanarky.register import ( + register_streaming_callback, + register_streaming_json_callback, + register_websocket_callback, +) +from lanarky.schemas import StreamingJSONResponse from .base import AsyncLanarkyCallback -from .llm import AsyncLLMChainStreamingCallback, AsyncLLMChainWebsocketCallback +from .llm import ( + AsyncLLMChainStreamingCallback, + AsyncLLMChainStreamingJSONCallback, + AsyncLLMChainWebsocketCallback, +) class AsyncAgentsLanarkyCallback(AsyncLanarkyCallback): @@ -12,12 +21,12 @@ class AsyncAgentsLanarkyCallback(AsyncLanarkyCallback): Adapted from `langchain/callbacks/streaming_stdout_final_only.py `_ """ - answer_prefix_tokens: List[str] = ["Final", " Answer", ":"] - last_tokens: List[str] = [""] * len(answer_prefix_tokens) + answer_prefix_tokens: list[str] = ["Final", " Answer", ":"] + last_tokens: list[str] = [""] * len(answer_prefix_tokens) answer_reached: bool = False async def on_llm_start( - self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any + self, serialized: dict[str, Any], prompts: list[str], **kwargs: Any ) -> None: """Run when LLM starts running.""" self.last_tokens = [""] * len(self.answer_prefix_tokens) @@ -59,3 +68,16 @@ async def on_llm_new_token(self, token: str, **kwargs: Any) -> None: if self._check_if_answer_reached(token): message = self._construct_message(token) await self.websocket.send_json(message) + + +@register_streaming_json_callback("AgentExecutor") +class AsyncAgentsStreamingJSONCallback( + AsyncAgentsLanarkyCallback, AsyncLLMChainStreamingJSONCallback +): + """AsyncStreamingJSONCallback handler for AgentExecutor.""" + + async def on_llm_new_token(self, token: str, **kwargs: Any) -> None: + """Run on new LLM token. Only available when streaming is enabled.""" + if self._check_if_answer_reached(token): + message = self._construct_message(StreamingJSONResponse(token=token)) + await self.send(message) diff --git a/lanarky/callbacks/base.py b/lanarky/callbacks/base.py index 650f66f..cf085af 100644 --- a/lanarky/callbacks/base.py +++ b/lanarky/callbacks/base.py @@ -1,3 +1,4 @@ +import json from abc import abstractmethod from typing import Any @@ -6,7 +7,7 @@ from pydantic import BaseModel, Field from starlette.types import Message, Send -from lanarky.schemas import WebsocketResponse +from lanarky.schemas import StreamingJSONResponse, WebsocketResponse class AsyncLanarkyCallback(AsyncCallbackHandler, BaseModel): @@ -21,31 +22,51 @@ class Config: arbitrary_types_allowed = True @abstractmethod - def _construct_message(self, message: str) -> Any: # pragma: no cover + def _construct_message(self, content: Any) -> Any: # pragma: no cover """Construct a Message from a string.""" pass class AsyncStreamingResponseCallback(AsyncLanarkyCallback): - """Async Callback handler for FastAPI StreamingResponse.""" + """Async Callback handler for StreamingResponse.""" send: Send = Field(...) - def _construct_message(self, message_str: str) -> Message: + def _construct_message(self, content: str) -> Message: """Construct a Message from a string.""" return { "type": "http.response.body", - "body": message_str.encode("utf-8"), + "body": content.encode("utf-8"), "more_body": True, } class AsyncWebsocketCallback(AsyncLanarkyCallback): - """Async Callback handler for FastAPI websocket connection.""" + """Async Callback handler for WebsocketConnection.""" websocket: WebSocket = Field(...) response: WebsocketResponse = Field(...) - def _construct_message(self, message_str: str) -> dict: + def _construct_message(self, content: str) -> dict: """Construct a WebsocketResponse from a string.""" - return {**self.response.dict(), **{"message": message_str.encode("utf-8")}} + return {**self.response.dict(), **{"message": content.encode("utf-8")}} + + +class AsyncStreamingJSONResponseCallback(AsyncStreamingResponseCallback): + """Async Callback handler for StreamingJSONResponse.""" + + send: Send = Field(...) + + def _construct_message(self, content: StreamingJSONResponse) -> Message: + """Construct a Message from a dictionary.""" + return { + "type": "http.response.body", + "body": json.dumps( + content.dict(), + ensure_ascii=False, + allow_nan=False, + indent=None, + separators=(",", ":"), + ).encode("utf-8"), + "more_body": True, + } diff --git a/lanarky/callbacks/llm.py b/lanarky/callbacks/llm.py index ebbb009..58dd8b9 100644 --- a/lanarky/callbacks/llm.py +++ b/lanarky/callbacks/llm.py @@ -1,11 +1,22 @@ from typing import Any -from lanarky.register import register_streaming_callback, register_websocket_callback +from lanarky.register import ( + register_streaming_callback, + register_streaming_json_callback, + register_websocket_callback, +) +from lanarky.schemas import StreamingJSONResponse -from .base import AsyncStreamingResponseCallback, AsyncWebsocketCallback +from .base import ( + AsyncStreamingJSONResponseCallback, + AsyncStreamingResponseCallback, + AsyncWebsocketCallback, +) +SUPPORTED_CHAINS = ["LLMChain", "ConversationChain"] -@register_streaming_callback("LLMChain") + +@register_streaming_callback(SUPPORTED_CHAINS) class AsyncLLMChainStreamingCallback(AsyncStreamingResponseCallback): """AsyncStreamingResponseCallback handler for LLMChain.""" @@ -15,7 +26,7 @@ async def on_llm_new_token(self, token: str, **kwargs: Any) -> None: await self.send(message) -@register_websocket_callback("LLMChain") +@register_websocket_callback(SUPPORTED_CHAINS) class AsyncLLMChainWebsocketCallback(AsyncWebsocketCallback): """AsyncWebsocketCallback handler for LLMChain.""" @@ -25,15 +36,11 @@ async def on_llm_new_token(self, token: str, **kwargs: Any) -> None: await self.websocket.send_json(message) -@register_streaming_callback("ConversationChain") -class AsyncConversationChainStreamingCallback(AsyncLLMChainStreamingCallback): - """AsyncStreamingResponseCallback handler for ConversationChain.""" - - pass - - -@register_websocket_callback("ConversationChain") -class AsyncConversationChainWebsocketCallback(AsyncLLMChainWebsocketCallback): - """AsyncWebsocketCallback handler for ConversationChain.""" +@register_streaming_json_callback(SUPPORTED_CHAINS) +class AsyncLLMChainStreamingJSONCallback(AsyncStreamingJSONResponseCallback): + """AsyncStreamingJSONResponseCallback handler for LLMChain.""" - pass + async def on_llm_new_token(self, token: str, **kwargs: Any) -> None: + """Run on new LLM token. Only available when streaming is enabled.""" + message = self._construct_message(StreamingJSONResponse(token=token)) + await self.send(message) diff --git a/lanarky/callbacks/qa_with_sources.py b/lanarky/callbacks/qa_with_sources.py deleted file mode 100644 index 5377c40..0000000 --- a/lanarky/callbacks/qa_with_sources.py +++ /dev/null @@ -1,96 +0,0 @@ -from lanarky.register import register_streaming_callback, register_websocket_callback - -from .retrieval_qa import ( - AsyncBaseRetrievalQAStreamingCallback, - AsyncBaseRetrievalQAWebsocketCallback, -) - - -@register_streaming_callback("BaseQAWithSources") -class AsyncBaseQAWithSourcesChainStreamingCallback( - AsyncBaseRetrievalQAStreamingCallback -): - """AsyncStreamingResponseCallback handler for BaseQAWithSources.""" - - pass - - -@register_websocket_callback("BaseQAWithSources") -class AsyncBaseQAWithSourcesChainWebsocketCallback( - AsyncBaseRetrievalQAWebsocketCallback -): - """AsyncWebsocketCallback handler for BaseQAWithSources.""" - - pass - - -@register_streaming_callback("QAWithSourcesChain") -class AsyncQAWithSourcesChainStreamingCallback( - AsyncBaseQAWithSourcesChainStreamingCallback -): - """AsyncStreamingResponseCallback handler for QAWithSourcesChain.""" - - pass - - -@register_websocket_callback("QAWithSourcesChain") -class AsyncQAWithSourcesChainWebsocketCallback( - AsyncBaseQAWithSourcesChainWebsocketCallback -): - """AsyncWebsocketCallback handler for QAWithSourcesChain.""" - - pass - - -@register_streaming_callback("VectorDBQAWithSourcesChain") -class AsyncVectorDBQAWithSourcesChainStreamingCallback( - AsyncBaseQAWithSourcesChainStreamingCallback -): - """AsyncStreamingResponseCallback handler for VectorDBQAWithSourcesChain.""" - - pass - - -@register_websocket_callback("VectorDBQAWithSourcesChain") -class AsyncVectorDBQAWithSourcesChainWebsocketCallback( - AsyncBaseQAWithSourcesChainWebsocketCallback -): - """AsyncWebsocketCallback handler for VectorDBQAWithSourcesChain.""" - - pass - - -@register_streaming_callback("RetrievalQAWithSourcesChain") -class AsyncRetrievalQAWithSourcesChainStreamingCallback( - AsyncBaseQAWithSourcesChainStreamingCallback -): - """AsyncStreamingResponseCallback handler for RetrievalQAWithSourcesChain.""" - - pass - - -@register_websocket_callback("RetrievalQAWithSourcesChain") -class AsyncRetrievalQAWithSourcesChainWebsocketCallback( - AsyncBaseQAWithSourcesChainWebsocketCallback -): - """AsyncWebsocketCallback handler for RetrievalQAWithSourcesChain.""" - - pass - - -@register_streaming_callback("ConversationalRetrievalChain") -class AsyncConversationalRetrievalChainStreamingCallback( - AsyncBaseQAWithSourcesChainStreamingCallback -): - """AsyncStreamingResponseCallback handler for ConversationalRetrievalChain.""" - - pass - - -@register_websocket_callback("ConversationalRetrievalChain") -class AsyncConversationalRetrievalChainWebsocketCallback( - AsyncBaseQAWithSourcesChainWebsocketCallback -): - """AsyncWebsocketCallback handler for ConversationalRetrievalChain.""" - - pass diff --git a/lanarky/callbacks/retrieval_qa.py b/lanarky/callbacks/retrieval_qa.py index 952bf92..158f6d5 100644 --- a/lanarky/callbacks/retrieval_qa.py +++ b/lanarky/callbacks/retrieval_qa.py @@ -1,78 +1,92 @@ -from typing import Any, Dict - -from lanarky.register import register_streaming_callback, register_websocket_callback - -from .llm import AsyncLLMChainStreamingCallback, AsyncLLMChainWebsocketCallback - +from typing import Any + +from lanarky.register import ( + register_streaming_callback, + register_streaming_json_callback, + register_websocket_callback, +) +from lanarky.schemas import BaseRetrievalQAStreamingJSONResponse + +from .llm import ( + AsyncLLMChainStreamingCallback, + AsyncLLMChainStreamingJSONCallback, + AsyncLLMChainWebsocketCallback, +) + +SUPPORTED_CHAINS = [ + "RetrievalQA", + "ConversationRetrievalQA", + "VectorDBQA", + "QAWithSourcesChain", + "VectorDBQAWithSourcesChain", + "RetrievalQAWithSourcesChain", + "ConversationalRetrievalChain", +] SOURCE_DOCUMENT_TEMPLATE = """ page content: {page_content} -source: {source} +{document_metadata} """ +@register_streaming_callback(SUPPORTED_CHAINS) class AsyncBaseRetrievalQAStreamingCallback(AsyncLLMChainStreamingCallback): """AsyncStreamingResponseCallback handler for BaseRetrievalQA.""" source_document_template: str = SOURCE_DOCUMENT_TEMPLATE - async def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None: + async def on_chain_end(self, outputs: dict[str, Any], **kwargs: Any) -> None: """Run when chain ends running.""" if "source_documents" in outputs: - message = self._construct_message("\n\nSOURCE DOCUMENTS: \n") + message = self._construct_message("\n\nSOURCE DOCUMENTS:\n") await self.send(message) for document in outputs["source_documents"]: + document_metadata = "\n".join( + [f"{k}: {v}" for k, v in document.metadata.items()] + ) message = self._construct_message( self.source_document_template.format( page_content=document.page_content, - source=document.metadata["source"], + document_metadata=document_metadata, ) ) await self.send(message) +@register_websocket_callback(SUPPORTED_CHAINS) class AsyncBaseRetrievalQAWebsocketCallback(AsyncLLMChainWebsocketCallback): """AsyncWebsocketCallback handler for BaseRetrievalQA.""" source_document_template: str = SOURCE_DOCUMENT_TEMPLATE - async def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None: + async def on_chain_end(self, outputs: dict[str, Any], **kwargs: Any) -> None: """Run when chain ends running.""" if "source_documents" in outputs: - message = self._construct_message("\n\nSOURCE DOCUMENTS: \n") + message = self._construct_message("\n\nSOURCE DOCUMENTS:\n") await self.websocket.send_json(message) for document in outputs["source_documents"]: + document_metadata = "\n".join( + [f"{k}: {v}" for k, v in document.metadata.items()] + ) message = self._construct_message( self.source_document_template.format( page_content=document.page_content, - source=document.metadata["source"], + document_metadata=document_metadata, ) ) await self.websocket.send_json(message) -@register_streaming_callback("RetrievalQA") -class AsyncRetrievalQAStreamingCallback(AsyncBaseRetrievalQAStreamingCallback): - """AsyncStreamingResponseCallback handler for RetrievalQA.""" - - pass - - -@register_streaming_callback("VectorDBQA") -class AsyncVectorDBQAStreamingCallback(AsyncBaseRetrievalQAStreamingCallback): - """AsyncStreamingResponseCallback handler for VectorDBQA.""" - - pass +@register_streaming_json_callback(SUPPORTED_CHAINS) +class AsyncBaseRetrievalQAStreamingJSONCallback(AsyncLLMChainStreamingJSONCallback): + """AsyncStreamingJSONResponseCallback handler for BaseRetrievalQA.""" - -@register_websocket_callback("RetrievalQA") -class AsyncRetrievalQAWebsocketCallback(AsyncBaseRetrievalQAWebsocketCallback): - """AsyncWebsocketCallback handler for RetrievalQA.""" - - pass - - -@register_websocket_callback("VectorDBQA") -class AsyncVectorDBQAWebsocketCallback(AsyncBaseRetrievalQAWebsocketCallback): - """AsyncWebsocketCallback handler for VectorDBQA.""" - - pass + async def on_chain_end(self, outputs: dict[str, Any], **kwargs: Any) -> None: + """Run when chain ends running.""" + if "source_documents" in outputs: + source_documents = [ + document.dict() for document in outputs["source_documents"] + ] + message = self._construct_message( + BaseRetrievalQAStreamingJSONResponse(source_documents=source_documents) + ) + await self.send(message) diff --git a/lanarky/register.py b/lanarky/register.py index 0b0538d..7dab34e 100644 --- a/lanarky/register.py +++ b/lanarky/register.py @@ -1,28 +1,41 @@ -from typing import Any, Callable, Dict, List, Optional, Tuple +from typing import Any, Callable, Optional, Union -def register(key: str, _registry: Dict[str, Tuple[Any, List[str]]]) -> Any: +def register( + key: Union[list[str], str], _registry: dict[str, tuple[Any, list[str]]] +) -> Any: """Add a class/function to a registry with required keyword arguments. ``_registry`` is a dictionary mapping from a key to a tuple of the class/function and a list of required keyword arguments, if keyword arguments are passed. Otherwise it is a dictionary mapping from a key to the class/function. + + Args: + key: key or list of keys to register the class/function under. + _registry: registry to add the class/function to. """ - def _register_cls(cls: Any, required_kwargs: Optional[List] = None) -> Any: - if key in _registry: - raise KeyError(f"{cls} already registered as {key}") - _registry[key] = cls if required_kwargs is None else (cls, required_kwargs) + def _register_cls(cls: Any, required_kwargs: Optional[list] = None) -> Any: + if isinstance(key, str): + keys = [key] + else: + keys = key + + for _key in keys: + if _key in _registry: + raise KeyError(f"{cls} already registered as {_key}") + _registry[_key] = cls if required_kwargs is None else (cls, required_kwargs) return cls return _register_cls -STREAMING_CALLBACKS: Dict[str, Any] = {} -WEBSOCKET_CALLBACKS: Dict[str, Any] = {} +STREAMING_CALLBACKS: dict[str, Any] = {} +WEBSOCKET_CALLBACKS: dict[str, Any] = {} +STREAMING_JSON_CALLBACKS: dict[str, Any] = {} -def register_streaming_callback(key: str) -> Callable: +def register_streaming_callback(key: Union[list[str], str]) -> Callable: """Register a streaming callback handler.""" def _register_cls(cls: Any) -> Callable: @@ -32,7 +45,7 @@ def _register_cls(cls: Any) -> Callable: return _register_cls -def register_websocket_callback(key: str) -> Callable: +def register_websocket_callback(key: Union[list[str], str]) -> Callable: """Register a websocket callback handler.""" def _register_cls(cls: Any) -> Callable: @@ -40,3 +53,13 @@ def _register_cls(cls: Any) -> Callable: return cls return _register_cls + + +def register_streaming_json_callback(key: Union[list[str], str]) -> Callable: + """Register an streaming json callback handler.""" + + def _register_cls(cls: Any) -> Callable: + register(key, STREAMING_JSON_CALLBACKS)(cls=cls) + return cls + + return _register_cls diff --git a/lanarky/responses/streaming.py b/lanarky/responses/streaming.py index cf586a6..6e5518c 100644 --- a/lanarky/responses/streaming.py +++ b/lanarky/responses/streaming.py @@ -11,7 +11,7 @@ from starlette.background import BackgroundTask from starlette.types import Send -from lanarky.callbacks import get_streaming_callback +from lanarky.callbacks import get_streaming_callback, get_streaming_json_callback class StreamingResponse(_StreamingResponse): @@ -56,12 +56,19 @@ async def stream_response(self, send: Send) -> None: @staticmethod def _create_chain_executor( - chain: Chain, inputs: Union[dict[str, Any], Any], **callback_kwargs + chain: Chain, + inputs: Union[dict[str, Any], Any], + as_json: bool = False, + **callback_kwargs, ) -> Callable[[Send], Awaitable[Any]]: + get_callback_fn = ( + get_streaming_json_callback if as_json else get_streaming_callback + ) + async def wrapper(send: Send): return await chain.acall( inputs=inputs, - callbacks=[get_streaming_callback(chain, send=send, **callback_kwargs)], + callbacks=[get_callback_fn(chain, send=send, **callback_kwargs)], ) return wrapper @@ -71,11 +78,14 @@ def from_chain( cls, chain: Chain, inputs: Union[dict[str, Any], Any], + as_json: bool = False, background: Optional[BackgroundTask] = None, callback_kwargs: dict[str, Any] = {}, **kwargs: Any, ) -> "StreamingResponse": - chain_executor = cls._create_chain_executor(chain, inputs, **callback_kwargs) + chain_executor = cls._create_chain_executor( + chain, inputs, as_json, **callback_kwargs + ) return cls( chain_executor=chain_executor, diff --git a/lanarky/schemas.py b/lanarky/schemas.py index 93a513b..e154c05 100644 --- a/lanarky/schemas.py +++ b/lanarky/schemas.py @@ -1,5 +1,5 @@ from enum import Enum -from typing import Union +from typing import Any, Union from pydantic import BaseModel @@ -29,3 +29,11 @@ class WebsocketResponse(BaseModel): class Config: use_enum_values = True + + +class StreamingJSONResponse(BaseModel): + token: str = "" + + +class BaseRetrievalQAStreamingJSONResponse(StreamingJSONResponse): + source_documents: list[dict[str, Any]] diff --git a/pyproject.toml b/pyproject.toml index 301ea6a..15c915a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,9 +12,10 @@ packages = [{include = "lanarky"}] [tool.poetry.dependencies] python = "^3.9" fastapi = ">=0.95.2" -langchain = ">=0.0.175" +langchain = ">=0.0.178" urllib3 = "<=1.26.15" # added due to poetry errors python-dotenv = "^1.0.0" +typing-extensions = "4.5.0" # added due to https://github.com/hwchase17/langchain/issues/5113 [tool.poetry.group.dev.dependencies] pre-commit = "^3.3.1" diff --git a/tests/callbacks/test_base.py b/tests/callbacks/test_base.py index 4dd7d39..46eac9d 100644 --- a/tests/callbacks/test_base.py +++ b/tests/callbacks/test_base.py @@ -1,4 +1,5 @@ from lanarky.callbacks.base import ( + AsyncStreamingJSONResponseCallback, AsyncStreamingResponseCallback, AsyncWebsocketCallback, ) @@ -10,3 +11,6 @@ def test_always_verbose(send, websocket, bot_response): callback = AsyncWebsocketCallback(websocket=websocket, response=bot_response) assert callback.always_verbose is True + + callback = AsyncStreamingJSONResponseCallback(send=send) + assert callback.always_verbose is True diff --git a/tests/callbacks/test_init.py b/tests/callbacks/test_init.py index c6c46fb..498dc9a 100644 --- a/tests/callbacks/test_init.py +++ b/tests/callbacks/test_init.py @@ -6,7 +6,9 @@ from lanarky.callbacks import ( AsyncLLMChainStreamingCallback, AsyncLLMChainWebsocketCallback, + AsyncStreamingJSONResponseCallback, get_streaming_callback, + get_streaming_json_callback, get_websocket_callback, ) @@ -37,3 +39,17 @@ class CustomChain: chain = Mock(spec=CustomChain) get_websocket_callback(chain, websocket=websocket, response=bot_response) + + +def test_get_streaming_json_callback(send): + chain = Mock(spec=LLMChain) + callback = get_streaming_json_callback(chain, send=send) + assert isinstance(callback, AsyncStreamingJSONResponseCallback) + + with pytest.raises(KeyError): + + class CustomChain: + pass + + chain = Mock(spec=CustomChain) + get_streaming_json_callback(chain, send=send) diff --git a/tests/callbacks/test_llm.py b/tests/callbacks/test_llm.py index 062d9a2..059be82 100644 --- a/tests/callbacks/test_llm.py +++ b/tests/callbacks/test_llm.py @@ -2,8 +2,10 @@ from lanarky.callbacks.llm import ( AsyncLLMChainStreamingCallback, + AsyncLLMChainStreamingJSONCallback, AsyncLLMChainWebsocketCallback, ) +from lanarky.schemas import StreamingJSONResponse @pytest.mark.asyncio @@ -30,3 +32,14 @@ async def test_async_llm_chain_websocket_callback_on_llm_new_token( message = callback._construct_message("test_token") callback.websocket.send_json.assert_awaited_once_with(message) + + +@pytest.mark.asyncio +async def test_async_llm_chain_streaming_json_callback_on_llm_new_token(send): + callback = AsyncLLMChainStreamingJSONCallback(send=send) + + await callback.on_llm_new_token("test_token") + + message = callback._construct_message(StreamingJSONResponse(token="test_token")) + + callback.send.assert_awaited_once_with(message) diff --git a/tests/callbacks/test_retrieval_qa.py b/tests/callbacks/test_retrieval_qa.py index 6a3eb9f..2d39d73 100644 --- a/tests/callbacks/test_retrieval_qa.py +++ b/tests/callbacks/test_retrieval_qa.py @@ -4,8 +4,10 @@ from lanarky.callbacks.retrieval_qa import ( AsyncBaseRetrievalQAStreamingCallback, + AsyncBaseRetrievalQAStreamingJSONCallback, AsyncBaseRetrievalQAWebsocketCallback, ) +from lanarky.schemas import BaseRetrievalQAStreamingJSONResponse @pytest.fixture @@ -21,7 +23,7 @@ def outputs(): @pytest.fixture def messages(): return [ - "\n\nSOURCE DOCUMENTS: \n", + "\n\nSOURCE DOCUMENTS:\n", "\npage content: Page 1 content\nsource: Source 1\n", "\npage content: Page 2 content\nsource: Source 2\n", ] @@ -40,12 +42,27 @@ async def test_streaming_on_chain_end(send, outputs, messages): @pytest.mark.asyncio async def test_websocket_on_chain_end(websocket, bot_response, outputs, messages): - ws_callback = AsyncBaseRetrievalQAWebsocketCallback( + callback = AsyncBaseRetrievalQAWebsocketCallback( websocket=websocket, response=bot_response, ) - await ws_callback.on_chain_end(outputs) + await callback.on_chain_end(outputs) + + callback.websocket.send_json.assert_has_calls( + [call(callback._construct_message(message)) for message in messages] + ) + + +@pytest.mark.asyncio +async def test_streaming_json_on_chain_end(send, outputs): + callback = AsyncBaseRetrievalQAStreamingJSONCallback(send=send) + + await callback.on_chain_end(outputs) + + source_documents = [document.dict() for document in outputs["source_documents"]] - ws_callback.websocket.send_json.assert_has_calls( - [call(ws_callback._construct_message(message)) for message in messages] + callback.send.assert_awaited_once_with( + callback._construct_message( + BaseRetrievalQAStreamingJSONResponse(source_documents=source_documents) + ) ) diff --git a/tests/conftest.py b/tests/conftest.py index 995cf3e..ba9e222 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,23 +2,23 @@ import pytest from fastapi import WebSocket -from langchain.chains import LLMChain +from langchain.chains.llm import LLMChain from starlette.types import Send from lanarky.schemas import Message, MessageType, Sender, WebsocketResponse -@pytest.fixture(scope="session") +@pytest.fixture(scope="function") def send(): return AsyncMock(spec=Send) -@pytest.fixture(scope="session") +@pytest.fixture(scope="function") def websocket(): return AsyncMock(spec=WebSocket) -@pytest.fixture(scope="session") +@pytest.fixture(scope="function") def bot_response(): return WebsocketResponse( sender=Sender.BOT, message=Message.NULL, message_type=MessageType.STREAM diff --git a/tests/responses/test_streaming.py b/tests/responses/test_streaming.py index ebcfad6..ac62c4d 100644 --- a/tests/responses/test_streaming.py +++ b/tests/responses/test_streaming.py @@ -3,8 +3,8 @@ import pytest from starlette.background import BackgroundTask -from lanarky.callbacks import get_streaming_callback -from lanarky.responses.streaming import StreamingResponse +from lanarky.callbacks import get_streaming_callback, get_streaming_json_callback +from lanarky.responses import StreamingResponse @pytest.fixture @@ -33,7 +33,7 @@ def test_init_from_chain(streaming_response: StreamingResponse) -> None: @pytest.mark.asyncio -async def test_create_chain_executor( +async def test_streaming_create_chain_executor( chain: MagicMock, inputs: dict[str, str], send ) -> None: chain_executor = StreamingResponse._create_chain_executor( @@ -82,3 +82,22 @@ async def test_stream_response_error( await streaming_response.stream_response(send) assert background.kwargs["outputs"] == "Something went wrong" + + +@pytest.mark.asyncio +async def test_streaming_json_create_chain_executor( + chain: MagicMock, inputs: dict[str, str], send +) -> None: + chain_executor = StreamingResponse._create_chain_executor( + chain=chain, inputs=inputs, as_json=True + ) + + assert callable(chain_executor) + + chain.acall.assert_not_called() + + await chain_executor(send) + + chain.acall.assert_called_once_with( + inputs=inputs, callbacks=[get_streaming_json_callback(chain=chain, send=send)] + )