From 4e8e587998467c6b963fc13185424edcd6cd4c2e Mon Sep 17 00:00:00 2001 From: Ajinkya Indulkar <26824103+ajndkr@users.noreply.github.com> Date: Tue, 30 May 2023 21:45:19 +0200 Subject: [PATCH] feat: Allow overriding registered callbacks (#75) * :zap: update `LangchainRouter` import * :zap: add `override` param to register * :memo: update docs --- README.md | 2 +- docs/getting_started.rst | 2 +- docs/lanarky/lanarky.callbacks.rst | 21 ++++----- docs/lanarky/lanarky.register.rst | 15 +++---- docs/lanarky/lanarky.responses.rst | 9 ++-- docs/lanarky/lanarky.routing.rst | 13 +++--- docs/lanarky/lanarky.rst | 10 +---- docs/lanarky/lanarky.schemas.rst | 2 +- docs/lanarky/lanarky.testing.rst | 9 ++-- docs/lanarky/lanarky.websockets.rst | 5 +-- docs/langchain/cache.rst | 5 ++- docs/langchain/deploy.rst | 4 +- examples/app/conversation_chain.py | 2 +- examples/app/retrieval_qa_w_sources.py | 2 +- examples/app/zero_shot_agent.py | 2 +- examples/requirements.txt | 2 +- lanarky/__init__.py | 3 ++ lanarky/register/base.py | 8 +++- lanarky/register/callbacks.py | 12 +++--- tests/test_register.py | 60 ++++++++++++++++++++++++++ 20 files changed, 117 insertions(+), 71 deletions(-) diff --git a/README.md b/README.md index 6e61dbc..54a9cfd 100644 --- a/README.md +++ b/README.md @@ -48,7 +48,7 @@ from fastapi import FastAPI from langchain import ConversationChain from langchain.chat_models import ChatOpenAI -from lanarky.routing import LangchainRouter +from lanarky import LangchainRouter load_dotenv() app = FastAPI() diff --git a/docs/getting_started.rst b/docs/getting_started.rst index eb2c36b..1d80a6f 100644 --- a/docs/getting_started.rst +++ b/docs/getting_started.rst @@ -13,7 +13,7 @@ You can get quickly started with Lanarky and deploy your first Langchain app in from langchain import ConversationChain from langchain.chat_models import ChatOpenAI - from lanarky.routing import LangchainRouter + from lanarky import LangchainRouter load_dotenv() app = FastAPI() diff --git a/docs/lanarky/lanarky.callbacks.rst b/docs/lanarky/lanarky.callbacks.rst index 06ee90a..4e00292 100644 --- a/docs/lanarky/lanarky.callbacks.rst +++ b/docs/lanarky/lanarky.callbacks.rst @@ -1,40 +1,37 @@ callbacks -========================= +========= .. automodule:: lanarky.callbacks :members: :undoc-members: :show-inheritance: -Submodules ----------- - -lanarky.callbacks.base module ------------------------------ +lanarky.callbacks.base +---------------------- .. automodule:: lanarky.callbacks.base :members: :undoc-members: :show-inheritance: -lanarky.callbacks.llm module ----------------------------- +lanarky.callbacks.llm +--------------------- .. automodule:: lanarky.callbacks.llm :members: :undoc-members: :show-inheritance: -lanarky.callbacks.retrieval\_qa module --------------------------------------- +lanarky.callbacks.retrieval\_qa +------------------------------- .. automodule:: lanarky.callbacks.retrieval_qa :members: :undoc-members: :show-inheritance: -lanarky.callbacks.agents module --------------------------------------- +lanarky.callbacks.agents +------------------------- .. automodule:: lanarky.callbacks.agents :members: diff --git a/docs/lanarky/lanarky.register.rst b/docs/lanarky/lanarky.register.rst index 363d2ce..c79a1a6 100644 --- a/docs/lanarky/lanarky.register.rst +++ b/docs/lanarky/lanarky.register.rst @@ -1,24 +1,21 @@ -registry ------------------------ +register +========= .. automodule:: lanarky.register :members: :undoc-members: :show-inheritance: -Submodules ----------- - -lanarky.register.base module ------------------------------ +lanarky.register.base +---------------------- .. automodule:: lanarky.register.base :members: :undoc-members: :show-inheritance: -lanarky.register.callbacks module ----------------------------- +lanarky.register.callbacks +-------------------------- .. automodule:: lanarky.register.callbacks :members: diff --git a/docs/lanarky/lanarky.responses.rst b/docs/lanarky/lanarky.responses.rst index d985d82..45caab0 100644 --- a/docs/lanarky/lanarky.responses.rst +++ b/docs/lanarky/lanarky.responses.rst @@ -1,16 +1,13 @@ responses -========================= +========= .. automodule:: lanarky.responses :members: :undoc-members: :show-inheritance: -Submodules ----------- - -lanarky.responses.streaming module ----------------------------------- +lanarky.responses.streaming +---------------------------- .. automodule:: lanarky.responses.streaming :members: diff --git a/docs/lanarky/lanarky.routing.rst b/docs/lanarky/lanarky.routing.rst index a89f70b..bb0f8b5 100644 --- a/docs/lanarky/lanarky.routing.rst +++ b/docs/lanarky/lanarky.routing.rst @@ -1,24 +1,21 @@ routing -========================= +======== .. automodule:: lanarky.routing :members: :undoc-members: :show-inheritance: -Submodules ----------- - -lanarky.routing.langchain module ----------------------------------- +lanarky.routing.langchain +-------------------------- .. automodule:: lanarky.routing.langchain :members: :undoc-members: :show-inheritance: -lanarky.routing.utils module ----------------------------------- +lanarky.routing.utils +---------------------- .. automodule:: lanarky.routing.utils :members: diff --git a/docs/lanarky/lanarky.rst b/docs/lanarky/lanarky.rst index 83d8a36..1a0ad62 100644 --- a/docs/lanarky/lanarky.rst +++ b/docs/lanarky/lanarky.rst @@ -1,16 +1,10 @@ lanarky =============== -.. automodule:: lanarky - :members: - :undoc-members: - :show-inheritance: - -Modules ------------ +Welcome to Lanarky's API Reference! .. toctree:: - :maxdepth: 1 + :maxdepth: 2 lanarky.callbacks lanarky.responses diff --git a/docs/lanarky/lanarky.schemas.rst b/docs/lanarky/lanarky.schemas.rst index 4f6ffda..3ef4067 100644 --- a/docs/lanarky/lanarky.schemas.rst +++ b/docs/lanarky/lanarky.schemas.rst @@ -1,5 +1,5 @@ schemas ----------------------- +======= .. automodule:: lanarky.schemas :members: diff --git a/docs/lanarky/lanarky.testing.rst b/docs/lanarky/lanarky.testing.rst index ba3f65f..d67841b 100644 --- a/docs/lanarky/lanarky.testing.rst +++ b/docs/lanarky/lanarky.testing.rst @@ -1,15 +1,12 @@ testing -======================= +======= .. automodule:: lanarky.testing :members: :undoc-members: :show-inheritance: -Submodules ----------- - -lanarky.testing.gradio module +lanarky.testing.gradio ----------------------------- .. automodule:: lanarky.testing.gradio @@ -17,7 +14,7 @@ lanarky.testing.gradio module :undoc-members: :show-inheritance: -lanarky.testing.settings module +lanarky.testing.settings ------------------------------- .. automodule:: lanarky.testing.settings diff --git a/docs/lanarky/lanarky.websockets.rst b/docs/lanarky/lanarky.websockets.rst index 58c2337..0a17d59 100644 --- a/docs/lanarky/lanarky.websockets.rst +++ b/docs/lanarky/lanarky.websockets.rst @@ -6,10 +6,7 @@ websockets :undoc-members: :show-inheritance: -Submodules ----------- - -lanarky.websockets.base module +lanarky.websockets.base ------------------------------ .. automodule:: lanarky.websockets.base diff --git a/docs/langchain/cache.rst b/docs/langchain/cache.rst index d173de8..25a7ed1 100644 --- a/docs/langchain/cache.rst +++ b/docs/langchain/cache.rst @@ -14,7 +14,7 @@ We'll use a simple ``LLMChain`` application as an example: from langchain import LLMChain from langchain.llms import OpenAI - from lanarky.routing import LangchainRouter + from lanarky import LangchainRouter load_dotenv() @@ -44,6 +44,7 @@ In-Memory Caching To setup in-memory caching, use the following ``LangchainRouter`` configuration: .. code-block:: python + langchain_router = LangchainRouter( langchain_url="/chat", langchain_object=LLMChain.from_string( @@ -72,6 +73,7 @@ Next, setup a Redis server. We recommend using Docker: Finally, use the following ``LangchainRouter`` configuration: .. code-block:: python + langchain_router = LangchainRouter( langchain_url="/chat", langchain_object=LLMChain.from_string( @@ -95,6 +97,7 @@ To setup GPTCache caching, first install the required dependencies: Then, use the following ``LangchainRouter`` configuration: .. code-block:: python + langchain_router = LangchainRouter( langchain_url="/chat", langchain_object=LLMChain.from_string( diff --git a/docs/langchain/deploy.rst b/docs/langchain/deploy.rst index 3af8f53..23fb253 100644 --- a/docs/langchain/deploy.rst +++ b/docs/langchain/deploy.rst @@ -13,7 +13,7 @@ To better understand ``LangchainRouter``, let's break down the example below: from fastapi import FastAPI from langchain import ConversationChain from langchain.chat_models import ChatOpenAI - from lanarky.routing import LangchainRouter + from lanarky import LangchainRouter load_dotenv() app = FastAPI() @@ -40,7 +40,7 @@ Here's an example: from fastapi import FastAPI from langchain import ConversationChain from langchain.chat_models import ChatOpenAI - from lanarky.routing import LangchainRouter + from lanarky import LangchainRouter load_dotenv() app = FastAPI() diff --git a/examples/app/conversation_chain.py b/examples/app/conversation_chain.py index 81850f4..718bebd 100644 --- a/examples/app/conversation_chain.py +++ b/examples/app/conversation_chain.py @@ -4,7 +4,7 @@ from langchain import ConversationChain from langchain.chat_models import ChatOpenAI -from lanarky.routing import LangchainRouter +from lanarky import LangchainRouter from lanarky.testing import mount_gradio_app load_dotenv() diff --git a/examples/app/retrieval_qa_w_sources.py b/examples/app/retrieval_qa_w_sources.py index eaed31e..8696d78 100644 --- a/examples/app/retrieval_qa_w_sources.py +++ b/examples/app/retrieval_qa_w_sources.py @@ -6,7 +6,7 @@ from langchain.embeddings import OpenAIEmbeddings from langchain.vectorstores import FAISS -from lanarky.routing import LangchainRouter +from lanarky import LangchainRouter from lanarky.testing import mount_gradio_app load_dotenv() diff --git a/examples/app/zero_shot_agent.py b/examples/app/zero_shot_agent.py index 2b4dd7a..2a127a5 100644 --- a/examples/app/zero_shot_agent.py +++ b/examples/app/zero_shot_agent.py @@ -4,7 +4,7 @@ from langchain.agents import AgentExecutor, AgentType, initialize_agent, load_tools from langchain.chat_models import ChatOpenAI -from lanarky.routing import LangchainRouter +from lanarky import LangchainRouter from lanarky.testing import mount_gradio_app load_dotenv() diff --git a/examples/requirements.txt b/examples/requirements.txt index 6074f10..b2cc39f 100644 --- a/examples/requirements.txt +++ b/examples/requirements.txt @@ -104,7 +104,7 @@ jsonschema==4.17.3 # via altair kiwisolver==1.4.4 # via matplotlib -lanarky==0.7.2 +lanarky==0.7.4 # via -r requirements.in langchain==0.0.183 # via diff --git a/lanarky/__init__.py b/lanarky/__init__.py index e69de29..5dca670 100644 --- a/lanarky/__init__.py +++ b/lanarky/__init__.py @@ -0,0 +1,3 @@ +from .routing import LangchainRouter + +__all__ = ["LangchainRouter"] diff --git a/lanarky/register/base.py b/lanarky/register/base.py index 69f50bb..aa970ac 100644 --- a/lanarky/register/base.py +++ b/lanarky/register/base.py @@ -2,7 +2,10 @@ def register( - key: Union[list[str], str], _registry: dict[str, tuple[Any, list[str]]] + key: Union[list[str], str], + _registry: dict[str, tuple[Any, list[str]]], + *, + override: bool = False, ) -> Any: """Add a class/function to a registry with required keyword arguments. @@ -13,6 +16,7 @@ def register( Args: key: key or list of keys to register the class/function under. _registry: registry to add the class/function to. + override: if True, override existing keys in the registry. """ def _register_cls(cls: Any, required_kwargs: Optional[list] = None) -> Any: @@ -22,7 +26,7 @@ def _register_cls(cls: Any, required_kwargs: Optional[list] = None) -> Any: keys = key for _key in keys: - if _key in _registry: + if _key in _registry and not override: raise KeyError(f"{cls} already registered as {_key}") _registry[_key] = cls if required_kwargs is None else (cls, required_kwargs) return cls diff --git a/lanarky/register/callbacks.py b/lanarky/register/callbacks.py index aee7abe..1c52f59 100644 --- a/lanarky/register/callbacks.py +++ b/lanarky/register/callbacks.py @@ -7,31 +7,31 @@ STREAMING_JSON_CALLBACKS: dict[str, Any] = {} -def register_streaming_callback(key: Union[list[str], str]) -> Callable: +def register_streaming_callback(key: Union[list[str], str], **kwargs) -> Callable: """Register a streaming callback handler.""" def _register_cls(cls: Any) -> Callable: - register(key, STREAMING_CALLBACKS)(cls=cls) + register(key, STREAMING_CALLBACKS, **kwargs)(cls=cls) return cls return _register_cls -def register_websocket_callback(key: Union[list[str], str]) -> Callable: +def register_websocket_callback(key: Union[list[str], str], **kwargs) -> Callable: """Register a websocket callback handler.""" def _register_cls(cls: Any) -> Callable: - register(key, WEBSOCKET_CALLBACKS)(cls=cls) + register(key, WEBSOCKET_CALLBACKS, **kwargs)(cls=cls) return cls return _register_cls -def register_streaming_json_callback(key: Union[list[str], str]) -> Callable: +def register_streaming_json_callback(key: Union[list[str], str], **kwargs) -> Callable: """Register an streaming json callback handler.""" def _register_cls(cls: Any) -> Callable: - register(key, STREAMING_JSON_CALLBACKS)(cls=cls) + register(key, STREAMING_JSON_CALLBACKS, **kwargs)(cls=cls) return cls return _register_cls diff --git a/tests/test_register.py b/tests/test_register.py index abb118c..bab8e3d 100644 --- a/tests/test_register.py +++ b/tests/test_register.py @@ -4,8 +4,10 @@ from lanarky.register import ( STREAMING_CALLBACKS, + STREAMING_JSON_CALLBACKS, WEBSOCKET_CALLBACKS, register_streaming_callback, + register_streaming_json_callback, register_websocket_callback, ) from lanarky.register.base import register @@ -26,6 +28,12 @@ class MyClass: with pytest.raises(KeyError): register("my_key", registry)(cls=MyClass) + class MyOtherClass: + pass + + register("my_key", registry, override=True)(cls=MyOtherClass) + assert registry["my_key"] == MyOtherClass + def test_register_with_required_kwargs(registry: Dict[str, Tuple[Any, List[str]]]): class MyClass: @@ -46,6 +54,17 @@ class MyStreamingCallback: assert "my_streaming_callback" in STREAMING_CALLBACKS assert STREAMING_CALLBACKS["my_streaming_callback"] == MyStreamingCallback + with pytest.raises(KeyError): + register_streaming_callback("my_streaming_callback")(cls=MyStreamingCallback) + + class MyOtherStreamingCallback: + pass + + register_streaming_callback("my_streaming_callback", override=True)( + cls=MyOtherStreamingCallback + ) + assert STREAMING_CALLBACKS["my_streaming_callback"] == MyOtherStreamingCallback + def test_register_websocket_callback(registry: Dict[str, Tuple[Any, List[str]]]): class MyWebsocketCallback: @@ -54,3 +73,44 @@ class MyWebsocketCallback: register_websocket_callback("my_websocket_callback")(cls=MyWebsocketCallback) assert "my_websocket_callback" in WEBSOCKET_CALLBACKS assert WEBSOCKET_CALLBACKS["my_websocket_callback"] == MyWebsocketCallback + + with pytest.raises(KeyError): + register_websocket_callback("my_websocket_callback")(cls=MyWebsocketCallback) + + class MyOtherWebsocketCallback: + pass + + register_websocket_callback("my_websocket_callback", override=True)( + cls=MyOtherWebsocketCallback + ) + + assert WEBSOCKET_CALLBACKS["my_websocket_callback"] == MyOtherWebsocketCallback + + +def test_register_streaming_json_callback(registry: Dict[str, Tuple[Any, List[str]]]): + class MyStreamingJsonCallback: + pass + + register_streaming_json_callback("my_streaming_json_callback")( + cls=MyStreamingJsonCallback + ) + assert ( + STREAMING_JSON_CALLBACKS["my_streaming_json_callback"] + == MyStreamingJsonCallback + ) + + with pytest.raises(KeyError): + register_streaming_json_callback("my_streaming_json_callback")( + cls=MyStreamingJsonCallback + ) + + class MyOtherStreamingJsonCallback: + pass + + register_streaming_json_callback("my_streaming_json_callback", override=True)( + cls=MyOtherStreamingJsonCallback + ) + assert ( + STREAMING_JSON_CALLBACKS["my_streaming_json_callback"] + == MyOtherStreamingJsonCallback + )