From ee914098b44f5897e789b84d6168ed31b5ecb288 Mon Sep 17 00:00:00 2001 From: Ajinkya Indulkar <26824103+ajndkr@users.noreply.github.com> Date: Mon, 29 May 2023 20:15:18 +0200 Subject: [PATCH] feat: add LLM caching (#72) * :sparkles: add llm cache to LangchainRouter * :heavy_plus_sign: add `cache` extras * :memo: update README * :recycle: refactor `register` module * :white_check_mark: update unit tests * :memo: update docs * :heavy_plus_sign: add `gptcache` extras * :sparkles: add gptcache integration * :memo: update docs * :memo: update README --- README.md | 6 +- docs/lanarky/lanarky.register.rst | 19 ++++ docs/langchain/cache.rst | 105 ++++++++++++++++++++++ docs/langchain/index.rst | 1 + lanarky/register/__init__.py | 17 ++++ lanarky/{register.py => register/base.py} | 37 +------- lanarky/register/callbacks.py | 37 ++++++++ lanarky/routing/__init__.py | 3 +- lanarky/routing/langchain.py | 59 +++++++++++- lanarky/routing/utils.py | 7 ++ pyproject.toml | 6 ++ tests/routing/test_langchain_router.py | 19 +++- tests/test_register.py | 2 +- 13 files changed, 275 insertions(+), 43 deletions(-) create mode 100644 docs/langchain/cache.rst create mode 100644 lanarky/register/__init__.py rename lanarky/{register.py => register/base.py} (50%) create mode 100644 lanarky/register/callbacks.py diff --git a/README.md b/README.md index 2343a9f..6e61dbc 100644 --- a/README.md +++ b/README.md @@ -75,10 +75,10 @@ before running the examples. 
 - [x] Add support for [LangChain](https://github.com/hwchase17/langchain)
 - [x] Add [Gradio](https://github.com/gradio-app/gradio) UI for fast prototyping
-- [ ] Add SQL database integration
-- [ ] Add support for [Guardrails](https://github.com/ShreyaR/guardrails)
+- [x] Add support for in-memory, Redis and [GPTCache](https://github.com/zilliztech/GPTCache) LLM caching
 - [ ] Add support for [LlamaIndex](https://github.com/jerryjliu/llama_index)
-- [ ] Add [GPTCache](https://github.com/zilliztech/GPTCache) integration
+- [ ] Add SQL database integration
+- [ ] Add support for [Rebuff](https://github.com/woop/rebuff)
 
 ## 🤩 Stargazers
 
diff --git a/docs/lanarky/lanarky.register.rst b/docs/lanarky/lanarky.register.rst
index 2295709..363d2ce 100644
--- a/docs/lanarky/lanarky.register.rst
+++ b/docs/lanarky/lanarky.register.rst
@@ -5,3 +5,22 @@ registry
    :members:
    :undoc-members:
    :show-inheritance:
+
+Submodules
+----------
+
+lanarky.register.base module
+-----------------------------
+
+.. automodule:: lanarky.register.base
+   :members:
+   :undoc-members:
+   :show-inheritance:
+
+lanarky.register.callbacks module
+---------------------------------
+
+.. automodule:: lanarky.register.callbacks
+   :members:
+   :undoc-members:
+   :show-inheritance:
diff --git a/docs/langchain/cache.rst b/docs/langchain/cache.rst
new file mode 100644
index 0000000..d173de8
--- /dev/null
+++ b/docs/langchain/cache.rst
@@ -0,0 +1,105 @@
+Deploy Applications with LLM Caching
+=====================================
+
+Langchain offers multiple LLM cache solutions. Reference: `How to cache LLM calls <https://python.langchain.com/en/latest/modules/models/llms/examples/llm_caching.html>`_
+
+To simplify the use of these solutions for Lanarky users, the ``LangchainRouter`` can be used to setup LLM caching for your application.
+
+We'll use a simple ``LLMChain`` application as an example:
+
+.. 
code-block:: python
+
+    from dotenv import load_dotenv
+    from fastapi import FastAPI
+    from langchain import LLMChain
+    from langchain.llms import OpenAI
+
+    from lanarky.routing import LangchainRouter
+
+    load_dotenv()
+
+    app = FastAPI()
+
+    langchain_router = LangchainRouter(
+        langchain_url="/chat",
+        langchain_object=LLMChain.from_string(
+            llm=OpenAI(temperature=0), template="Answer the query.\n{query}"
+        ),
+        streaming_mode=0,
+    )
+
+    app.include_router(langchain_router)
+
+The ``LangchainRouter`` class uses the ``llm_cache_mode`` parameter to setup LLM caching.
+There are four available modes:
+
+- ``llm_cache_mode=0``: No LLM caching
+- ``llm_cache_mode=1``: In-memory LLM caching
+- ``llm_cache_mode=2``: Redis LLM caching
+- ``llm_cache_mode=3``: GPTCache LLM caching
+
+In-Memory Caching
+-----------------
+
+To setup in-memory caching, use the following ``LangchainRouter`` configuration:
+
+.. code-block:: python
+    langchain_router = LangchainRouter(
+        langchain_url="/chat",
+        langchain_object=LLMChain.from_string(
+            llm=OpenAI(temperature=0), template="Answer the query.\n{query}"
+        ),
+        streaming_mode=0,
+        llm_cache_mode=1,
+    )
+
+
+Redis Caching
+-------------
+
+To setup Redis caching, first install the required dependencies:
+
+.. code-block:: bash
+
+    pip install "lanarky[redis]"
+
+Next, setup a Redis server. We recommend using Docker:
+
+.. code-block:: bash
+
+    docker run --name redis -p 6379:6379 -d redis
+
+Finally, use the following ``LangchainRouter`` configuration:
+
+.. code-block:: python
+    langchain_router = LangchainRouter(
+        langchain_url="/chat",
+        langchain_object=LLMChain.from_string(
+            llm=OpenAI(temperature=0), template="Answer the query.\n{query}"
+        ),
+        streaming_mode=0,
+        llm_cache_mode=2,
+        llm_cache_kwargs={"url": "redis://localhost:6379/"},
+    )
+
+
+GPTCache Caching
+----------------
+
+To setup GPTCache caching, first install the required dependencies:
+
+.. 
code-block:: bash + + pip install "lanarky[gptcache]" + +Then, use the following ``LangchainRouter`` configuration: + +.. code-block:: python + langchain_router = LangchainRouter( + langchain_url="/chat", + langchain_object=LLMChain.from_string( + llm=OpenAI(temperature=0), template="Answer the query.\n{query}" + ), + streaming_mode=0, + llm_cache_mode=3, + ) diff --git a/docs/langchain/index.rst b/docs/langchain/index.rst index 1466a0d..761ce06 100644 --- a/docs/langchain/index.rst +++ b/docs/langchain/index.rst @@ -12,4 +12,5 @@ your Langchain application. :maxdepth: 1 deploy + cache custom_callbacks diff --git a/lanarky/register/__init__.py b/lanarky/register/__init__.py new file mode 100644 index 0000000..f4b09c4 --- /dev/null +++ b/lanarky/register/__init__.py @@ -0,0 +1,17 @@ +from .callbacks import ( + STREAMING_CALLBACKS, + STREAMING_JSON_CALLBACKS, + WEBSOCKET_CALLBACKS, + register_streaming_callback, + register_streaming_json_callback, + register_websocket_callback, +) + +__all__ = [ + "STREAMING_CALLBACKS", + "STREAMING_JSON_CALLBACKS", + "WEBSOCKET_CALLBACKS", + "register_streaming_callback", + "register_streaming_json_callback", + "register_websocket_callback", +] diff --git a/lanarky/register.py b/lanarky/register/base.py similarity index 50% rename from lanarky/register.py rename to lanarky/register/base.py index 7dab34e..69f50bb 100644 --- a/lanarky/register.py +++ b/lanarky/register/base.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, Optional, Union +from typing import Any, Optional, Union def register( @@ -28,38 +28,3 @@ def _register_cls(cls: Any, required_kwargs: Optional[list] = None) -> Any: return cls return _register_cls - - -STREAMING_CALLBACKS: dict[str, Any] = {} -WEBSOCKET_CALLBACKS: dict[str, Any] = {} -STREAMING_JSON_CALLBACKS: dict[str, Any] = {} - - -def register_streaming_callback(key: Union[list[str], str]) -> Callable: - """Register a streaming callback handler.""" - - def _register_cls(cls: Any) -> Callable: - 
register(key, STREAMING_CALLBACKS)(cls=cls)
-        return cls
-
-    return _register_cls
-
-
-def register_websocket_callback(key: Union[list[str], str]) -> Callable:
-    """Register a websocket callback handler."""
-
-    def _register_cls(cls: Any) -> Callable:
-        register(key, WEBSOCKET_CALLBACKS)(cls=cls)
-        return cls
-
-    return _register_cls
-
-
-def register_streaming_json_callback(key: Union[list[str], str]) -> Callable:
-    """Register an streaming json callback handler."""
-
-    def _register_cls(cls: Any) -> Callable:
-        register(key, STREAMING_JSON_CALLBACKS)(cls=cls)
-        return cls
-
-    return _register_cls
diff --git a/lanarky/register/callbacks.py b/lanarky/register/callbacks.py
new file mode 100644
index 0000000..aee7abe
--- /dev/null
+++ b/lanarky/register/callbacks.py
@@ -0,0 +1,37 @@
+from typing import Any, Callable, Union
+
+from .base import register
+
+STREAMING_CALLBACKS: dict[str, Any] = {}
+WEBSOCKET_CALLBACKS: dict[str, Any] = {}
+STREAMING_JSON_CALLBACKS: dict[str, Any] = {}
+
+
+def register_streaming_callback(key: Union[list[str], str]) -> Callable:
+    """Register a streaming callback handler."""
+
+    def _register_cls(cls: Any) -> Callable:
+        register(key, STREAMING_CALLBACKS)(cls=cls)
+        return cls
+
+    return _register_cls
+
+
+def register_websocket_callback(key: Union[list[str], str]) -> Callable:
+    """Register a websocket callback handler."""
+
+    def _register_cls(cls: Any) -> Callable:
+        register(key, WEBSOCKET_CALLBACKS)(cls=cls)
+        return cls
+
+    return _register_cls
+
+
+def register_streaming_json_callback(key: Union[list[str], str]) -> Callable:
+    """Register a streaming json callback handler."""
+
+    def _register_cls(cls: Any) -> Callable:
+        register(key, STREAMING_JSON_CALLBACKS)(cls=cls)
+        return cls
+
+    return _register_cls
diff --git a/lanarky/routing/__init__.py b/lanarky/routing/__init__.py
index fec21ce..00b2cbc 100644
--- a/lanarky/routing/__init__.py
+++ b/lanarky/routing/__init__.py
@@ -1,3 +1,4 @@
 from .langchain import LangchainRouter
+from .utils import LLMCacheMode, StreamingMode -__all__ = ["LangchainRouter"] +__all__ = ["LangchainRouter", "StreamingMode", "LLMCacheMode"] diff --git a/lanarky/routing/langchain.py b/lanarky/routing/langchain.py index e7cfe8d..b7a6aa3 100644 --- a/lanarky/routing/langchain.py +++ b/lanarky/routing/langchain.py @@ -7,6 +7,7 @@ from langchain.chains.base import Chain from .utils import ( + LLMCacheMode, StreamingMode, create_langchain_dependency, create_langchain_endpoint, @@ -24,6 +25,8 @@ def __init__( langchain_object: Optional[Type[Chain]] = None, langchain_endpoint_kwargs: Optional[dict[str, Any]] = None, streaming_mode: Optional[StreamingMode] = None, + llm_cache_mode: Optional[LLMCacheMode] = None, + llm_cache_kwargs: Optional[dict[str, Any]] = None, **kwargs, ): super().__init__(**kwargs) @@ -32,6 +35,8 @@ def __init__( self.langchain_object = langchain_object self.langchain_endpoint_kwargs = langchain_endpoint_kwargs or {} self.streaming_mode = streaming_mode + self.llm_cache_mode = llm_cache_mode + self.llm_cache_kwargs = llm_cache_kwargs or {} self.langchain_dependencies = [] @@ -39,7 +44,7 @@ def __init__( def setup(self) -> None: """Sets up the Langchain router.""" - if self.langchain_url: + if self.langchain_url is not None: self.add_langchain_api_route( self.langchain_url, self.langchain_object, @@ -47,6 +52,58 @@ def setup(self) -> None: **self.langchain_endpoint_kwargs, ) + if self.llm_cache_mode is not None: + self.setup_llm_cache() + + def setup_llm_cache(self) -> None: + """Sets up the LLM cache.""" + import langchain + + if self.llm_cache_mode == LLMCacheMode.IN_MEMORY: + from langchain.cache import InMemoryCache + + langchain.llm_cache = InMemoryCache() + + elif self.llm_cache_mode == LLMCacheMode.REDIS: + try: + from redis import Redis # type: ignore + except ImportError: + raise ImportError( + """Redis is not installed. 
Install it with `pip install "lanarky[redis]"`.""" + ) + from langchain.cache import RedisCache + + langchain.llm_cache = RedisCache( + redis_=Redis.from_url(**self.llm_cache_kwargs) + ) + + elif self.llm_cache_mode == LLMCacheMode.GPTCACHE: + try: + from gptcache import Cache # type: ignore + from gptcache.manager.factory import manager_factory # type: ignore + from gptcache.processor.pre import get_prompt # type: ignore + except ImportError: + raise ImportError( + """GPTCache is not installed. Install it with `pip install "lanarky[gptcache]"`.""" + ) + import hashlib + + from langchain.cache import GPTCache + + def init_gptcache(cache_obj: Cache, llm: str): + hashed_llm = hashlib.sha256(llm.encode()).hexdigest() + cache_obj.init( + pre_embedding_func=get_prompt, + data_manager=manager_factory( + manager="map", data_dir=f"map_cache_{hashed_llm}" + ), + ) + + langchain.llm_cache = GPTCache(init_gptcache) + + else: + raise ValueError(f"Invalid LLM cache mode: {self.llm_cache_mode}") + def add_langchain_api_route( self, url: str, diff --git a/lanarky/routing/utils.py b/lanarky/routing/utils.py index 3996e5a..7260595 100644 --- a/lanarky/routing/utils.py +++ b/lanarky/routing/utils.py @@ -16,6 +16,13 @@ class StreamingMode(IntEnum): JSON = 2 +class LLMCacheMode(IntEnum): + OFF = 0 + IN_MEMORY = 1 + REDIS = 2 + GPTCACHE = 3 + + def create_langchain_dependency(langchain_object: Type[Chain]) -> params.Depends: """Creates a langchain object dependency.""" diff --git a/pyproject.toml b/pyproject.toml index c12ad07..19b5eac 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,6 +16,8 @@ langchain = ">=0.0.183" urllib3 = "<=1.26.15" # added due to poetry errors python-dotenv = "^1.0.0" typing-extensions = "4.5.0" # added due to https://github.com/hwchase17/langchain/issues/5113 +redis = {version = "^4.5.5", optional = true} +gptcache = {version = "^0.1.28", optional = true} [tool.poetry.group.dev.dependencies] pre-commit = "^3.3.1" @@ -36,6 +38,10 @@ pytest-cov = "^4.0.0" 
pytest-asyncio = "^0.21.0" coveralls = "^3.3.1" +[tool.poetry.extras] +redis = ["redis"] +gptcache = ["gptcache"] + [build-system] requires = ["poetry-core"] build-backend = "poetry.core.masonry.api" diff --git a/tests/routing/test_langchain_router.py b/tests/routing/test_langchain_router.py index 596d2b3..cd28abb 100644 --- a/tests/routing/test_langchain_router.py +++ b/tests/routing/test_langchain_router.py @@ -3,7 +3,7 @@ import pytest from langchain.chains import ConversationChain -from lanarky.routing import LangchainRouter +from lanarky.routing import LangchainRouter, LLMCacheMode @pytest.fixture @@ -19,6 +19,8 @@ def test_langchain_router_init(): assert router.langchain_url is None assert router.langchain_endpoint_kwargs == {} assert router.langchain_dependencies == [] + assert router.llm_cache_mode is None + assert router.llm_cache_kwargs == {} def test_langchain_router_add_routes(chain): @@ -56,3 +58,18 @@ def test_langchain_router_add_routes(chain): assert router.routes[2].path == "/chat" assert router.routes[2].response_model is None assert "LangchainRequest" in router.routes[2].body_field.type_.schema()["title"] + + +def test_langchain_router_enable_llm_cache(chain): + router = LangchainRouter( + langchain_url="/chat", + langchain_object=chain, + streaming_mode=0, + llm_cache_mode=1, + ) + + assert router.llm_cache_mode == LLMCacheMode.IN_MEMORY + + import langchain + + assert langchain.llm_cache is not None diff --git a/tests/test_register.py b/tests/test_register.py index 630f9b6..abb118c 100644 --- a/tests/test_register.py +++ b/tests/test_register.py @@ -5,10 +5,10 @@ from lanarky.register import ( STREAMING_CALLBACKS, WEBSOCKET_CALLBACKS, - register, register_streaming_callback, register_websocket_callback, ) +from lanarky.register.base import register @pytest.fixture