feat: add LLM caching (#72)
* ✨ add llm cache to LangchainRouter

* ➕ add `cache` extras

* 📝 update README

* ♻️ refactor `register` module

* ✅ update unit tests

* 📝 update docs

* ➕ add `gptcache` extras

* ✨ add gptcache integration

* 📝 update docs

* 📝 update README
ajndkr authored May 29, 2023

1 parent c157274 commit ee91409
Showing 13 changed files with 275 additions and 43 deletions.
6 changes: 3 additions & 3 deletions README.md
@@ -75,10 +75,10 @@ before running the examples.

 - [x] Add support for [LangChain](https://github.com/hwchase17/langchain)
 - [x] Add [Gradio](https://github.com/gradio-app/gradio) UI for fast prototyping
-- [ ] Add SQL database integration
 - [ ] Add support for [Guardrails](https://github.com/ShreyaR/guardrails)
+- [x] Add support for in-memory, Redis and [GPTCache](https://github.com/zilliztech/GPTCache) LLM caching
 - [ ] Add support for [LlamaIndex](https://github.com/jerryjliu/llama_index)
-- [ ] Add [GPTCache](https://github.com/zilliztech/GPTCache) integration
+- [ ] Add SQL database integration
 - [ ] Add support for [Rebuff](https://github.com/woop/rebuff)

## 🤩 Stargazers

19 changes: 19 additions & 0 deletions docs/lanarky/lanarky.register.rst
@@ -5,3 +5,22 @@ registry
    :members:
    :undoc-members:
    :show-inheritance:
+
+Submodules
+----------
+
+lanarky.register.base module
+----------------------------
+
+.. automodule:: lanarky.register.base
+   :members:
+   :undoc-members:
+   :show-inheritance:
+
+lanarky.register.callbacks module
+---------------------------------
+
+.. automodule:: lanarky.register.callbacks
+   :members:
+   :undoc-members:
+   :show-inheritance:
105 changes: 105 additions & 0 deletions docs/langchain/cache.rst
@@ -0,0 +1,105 @@
Deploy Applications with LLM Caching
=====================================

Langchain offers multiple LLM cache solutions. Reference: `How to cache LLM calls <https://python.langchain.com/en/latest/modules/models/llms/examples/llm_caching.html>`_

To simplify the use of these solutions for Lanarky users, the ``LangchainRouter`` can be used to set up LLM caching for your application.

We'll use a simple ``LLMChain`` application as an example:

.. code-block:: python

   from dotenv import load_dotenv
   from fastapi import FastAPI
   from langchain import LLMChain
   from langchain.llms import OpenAI

   from lanarky.routing import LangchainRouter

   load_dotenv()

   app = FastAPI()

   langchain_router = LangchainRouter(
       langchain_url="/chat",
       langchain_object=LLMChain.from_string(
           llm=OpenAI(temperature=0), template="Answer the query.\n{query}"
       ),
       streaming_mode=0,
   )

   app.include_router(langchain_router)

The ``LangchainRouter`` class uses the ``llm_cache_mode`` parameter to set up LLM caching.
There are four available modes:

- ``llm_cache_mode=0``: No LLM caching
- ``llm_cache_mode=1``: In-memory LLM caching
- ``llm_cache_mode=2``: Redis LLM caching
- ``llm_cache_mode=3``: GPTCache LLM caching
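
These integers map onto the ``LLMCacheMode`` enum added in ``lanarky/routing/utils.py``
(shown later in this diff) and exported from ``lanarky.routing``, so the named members
can be passed instead of raw integers. A minimal sketch, reusing the chain from the
example above:

.. code-block:: python

   from lanarky.routing import LangchainRouter, LLMCacheMode

   langchain_router = LangchainRouter(
       langchain_url="/chat",
       langchain_object=chain,  # the LLMChain from the example above
       streaming_mode=0,
       llm_cache_mode=LLMCacheMode.IN_MEMORY,  # equivalent to llm_cache_mode=1
   )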

In-Memory Caching
-----------------

To set up in-memory caching, use the following ``LangchainRouter`` configuration:

.. code-block:: python

   langchain_router = LangchainRouter(
       langchain_url="/chat",
       langchain_object=LLMChain.from_string(
           llm=OpenAI(temperature=0), template="Answer the query.\n{query}"
       ),
       streaming_mode=0,
       llm_cache_mode=1,
   )

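Once the server is running, an identical repeated request should be answered from the
cache instead of triggering a second OpenAI call. A hypothetical client call (it assumes
the endpoint accepts the chain's input keys as JSON; the request schema is not specified
by this commit):

.. code-block:: python

   import requests

   # The second, identical request should be served from the in-memory cache.
   for _ in range(2):
       response = requests.post(
           "http://localhost:8000/chat",
           json={"query": "What is the capital of France?"},
       )
       print(response.json())

Note that ``InMemoryCache`` lives in the server process, so the cache is cleared on
every restart.
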
Redis Caching
-------------

To set up Redis caching, first install the required dependencies:

.. code-block:: bash

   pip install "lanarky[redis]"

Next, set up a Redis server. We recommend using Docker:

.. code-block:: bash

   docker run --name redis -p 6379:6379 -d redis

Finally, use the following ``LangchainRouter`` configuration:

.. code-block:: python

   langchain_router = LangchainRouter(
       langchain_url="/chat",
       langchain_object=LLMChain.from_string(
           llm=OpenAI(temperature=0), template="Answer the query.\n{query}"
       ),
       streaming_mode=0,
       llm_cache_mode=2,
       llm_cache_kwargs={"url": "redis://localhost:6379/"},
   )

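``llm_cache_kwargs`` is forwarded verbatim to ``Redis.from_url`` (see the
``setup_llm_cache`` changes in ``lanarky/routing/langchain.py`` below), so any redis-py
client option can be passed alongside the URL. A sketch, where ``socket_timeout`` is
merely an illustrative redis-py option:

.. code-block:: python

   langchain_router = LangchainRouter(
       langchain_url="/chat",
       langchain_object=chain,  # the LLMChain from the first example
       streaming_mode=0,
       llm_cache_mode=2,
       llm_cache_kwargs={
           "url": "redis://localhost:6379/",
           "socket_timeout": 1.0,  # any Redis.from_url kwarg works here
       },
   )
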
GPTCache Caching
----------------

To set up GPTCache caching, first install the required dependencies:

.. code-block:: bash

   pip install "lanarky[gptcache]"

Then, use the following ``LangchainRouter`` configuration:

.. code-block:: python

   langchain_router = LangchainRouter(
       langchain_url="/chat",
       langchain_object=LLMChain.from_string(
           llm=OpenAI(temperature=0), template="Answer the query.\n{query}"
       ),
       streaming_mode=0,
       llm_cache_mode=3,
   )

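Under the hood (see ``setup_llm_cache`` in ``lanarky/routing/langchain.py`` below), this
mode initializes GPTCache with a ``map`` data manager that writes one
``map_cache_<sha256-of-llm-config>`` directory per LLM configuration, so cached
responses persist on disk across server restarts. ``llm_cache_kwargs`` is not used in
this mode.
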
1 change: 1 addition & 0 deletions docs/langchain/index.rst
@@ -12,4 +12,5 @@ your Langchain application.
    :maxdepth: 1

    deploy
+   cache
    custom_callbacks
17 changes: 17 additions & 0 deletions lanarky/register/__init__.py
@@ -0,0 +1,17 @@
from .callbacks import (
    STREAMING_CALLBACKS,
    STREAMING_JSON_CALLBACKS,
    WEBSOCKET_CALLBACKS,
    register_streaming_callback,
    register_streaming_json_callback,
    register_websocket_callback,
)

__all__ = [
    "STREAMING_CALLBACKS",
    "STREAMING_JSON_CALLBACKS",
    "WEBSOCKET_CALLBACKS",
    "register_streaming_callback",
    "register_streaming_json_callback",
    "register_websocket_callback",
]
37 changes: 1 addition & 36 deletions lanarky/register.py → lanarky/register/base.py
@@ -1,4 +1,4 @@
-from typing import Any, Callable, Optional, Union
+from typing import Any, Optional, Union


 def register(
@@ -28,38 +28,3 @@ def _register_cls(cls: Any, required_kwargs: Optional[list] = None) -> Any:
         return cls

     return _register_cls
-
-
-STREAMING_CALLBACKS: dict[str, Any] = {}
-WEBSOCKET_CALLBACKS: dict[str, Any] = {}
-STREAMING_JSON_CALLBACKS: dict[str, Any] = {}
-
-
-def register_streaming_callback(key: Union[list[str], str]) -> Callable:
-    """Register a streaming callback handler."""
-
-    def _register_cls(cls: Any) -> Callable:
-        register(key, STREAMING_CALLBACKS)(cls=cls)
-        return cls
-
-    return _register_cls
-
-
-def register_websocket_callback(key: Union[list[str], str]) -> Callable:
-    """Register a websocket callback handler."""
-
-    def _register_cls(cls: Any) -> Callable:
-        register(key, WEBSOCKET_CALLBACKS)(cls=cls)
-        return cls
-
-    return _register_cls
-
-
-def register_streaming_json_callback(key: Union[list[str], str]) -> Callable:
-    """Register an streaming json callback handler."""
-
-    def _register_cls(cls: Any) -> Callable:
-        register(key, STREAMING_JSON_CALLBACKS)(cls=cls)
-        return cls
-
-    return _register_cls
37 changes: 37 additions & 0 deletions lanarky/register/callbacks.py
@@ -0,0 +1,37 @@
from typing import Any, Callable, Union

from .base import register

STREAMING_CALLBACKS: dict[str, Any] = {}
WEBSOCKET_CALLBACKS: dict[str, Any] = {}
STREAMING_JSON_CALLBACKS: dict[str, Any] = {}


def register_streaming_callback(key: Union[list[str], str]) -> Callable:
    """Register a streaming callback handler."""

    def _register_cls(cls: Any) -> Callable:
        register(key, STREAMING_CALLBACKS)(cls=cls)
        return cls

    return _register_cls


def register_websocket_callback(key: Union[list[str], str]) -> Callable:
    """Register a websocket callback handler."""

    def _register_cls(cls: Any) -> Callable:
        register(key, WEBSOCKET_CALLBACKS)(cls=cls)
        return cls

    return _register_cls


def register_streaming_json_callback(key: Union[list[str], str]) -> Callable:
    """Register a streaming json callback handler."""

    def _register_cls(cls: Any) -> Callable:
        register(key, STREAMING_JSON_CALLBACKS)(cls=cls)
        return cls

    return _register_cls
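
For readers unfamiliar with the registry pattern: each ``register_*`` helper returns a
class decorator that files the decorated class in the corresponding registry dict under
the given key. A hypothetical usage sketch (``"MyChain"`` and ``MyStreamingHandler`` are
made-up names for illustration):

.. code-block:: python

   from lanarky.register import register_streaming_callback

   # Files MyStreamingHandler in STREAMING_CALLBACKS under "MyChain", so the
   # router can later look the handler up by chain type. A real handler would
   # subclass one of lanarky's callback handler classes.
   @register_streaming_callback("MyChain")
   class MyStreamingHandler:
       ...
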
3 changes: 2 additions & 1 deletion lanarky/routing/__init__.py
@@ -1,3 +1,4 @@
 from .langchain import LangchainRouter
+from .utils import LLMCacheMode, StreamingMode

-__all__ = ["LangchainRouter"]
+__all__ = ["LangchainRouter", "StreamingMode", "LLMCacheMode"]
59 changes: 58 additions & 1 deletion lanarky/routing/langchain.py
@@ -7,6 +7,7 @@
 from langchain.chains.base import Chain

 from .utils import (
+    LLMCacheMode,
     StreamingMode,
     create_langchain_dependency,
     create_langchain_endpoint,
@@ -24,6 +25,8 @@ def __init__(
         langchain_object: Optional[Type[Chain]] = None,
         langchain_endpoint_kwargs: Optional[dict[str, Any]] = None,
         streaming_mode: Optional[StreamingMode] = None,
+        llm_cache_mode: Optional[LLMCacheMode] = None,
+        llm_cache_kwargs: Optional[dict[str, Any]] = None,
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -32,21 +35,75 @@ def __init__(
         self.langchain_object = langchain_object
         self.langchain_endpoint_kwargs = langchain_endpoint_kwargs or {}
         self.streaming_mode = streaming_mode
+        self.llm_cache_mode = llm_cache_mode
+        self.llm_cache_kwargs = llm_cache_kwargs or {}

         self.langchain_dependencies = []

         self.setup()

     def setup(self) -> None:
         """Sets up the Langchain router."""
-        if self.langchain_url:
+        if self.langchain_url is not None:
             self.add_langchain_api_route(
                 self.langchain_url,
                 self.langchain_object,
                 self.streaming_mode,
                 **self.langchain_endpoint_kwargs,
             )

+        if self.llm_cache_mode is not None:
+            self.setup_llm_cache()
+
+    def setup_llm_cache(self) -> None:
+        """Sets up the LLM cache."""
+        import langchain
+
+        if self.llm_cache_mode == LLMCacheMode.IN_MEMORY:
+            from langchain.cache import InMemoryCache
+
+            langchain.llm_cache = InMemoryCache()
+
+        elif self.llm_cache_mode == LLMCacheMode.REDIS:
+            try:
+                from redis import Redis  # type: ignore
+            except ImportError:
+                raise ImportError(
+                    """Redis is not installed. Install it with `pip install "lanarky[redis]"`."""
+                )
+            from langchain.cache import RedisCache
+
+            langchain.llm_cache = RedisCache(
+                redis_=Redis.from_url(**self.llm_cache_kwargs)
+            )
+
+        elif self.llm_cache_mode == LLMCacheMode.GPTCACHE:
+            try:
+                from gptcache import Cache  # type: ignore
+                from gptcache.manager.factory import manager_factory  # type: ignore
+                from gptcache.processor.pre import get_prompt  # type: ignore
+            except ImportError:
+                raise ImportError(
+                    """GPTCache is not installed. Install it with `pip install "lanarky[gptcache]"`."""
+                )
+            import hashlib
+
+            from langchain.cache import GPTCache
+
+            def init_gptcache(cache_obj: Cache, llm: str):
+                hashed_llm = hashlib.sha256(llm.encode()).hexdigest()
+                cache_obj.init(
+                    pre_embedding_func=get_prompt,
+                    data_manager=manager_factory(
+                        manager="map", data_dir=f"map_cache_{hashed_llm}"
+                    ),
+                )
+
+            langchain.llm_cache = GPTCache(init_gptcache)
+
+        else:
+            raise ValueError(f"Invalid LLM cache mode: {self.llm_cache_mode}")
+
     def add_langchain_api_route(
         self,
         url: str,
7 changes: 7 additions & 0 deletions lanarky/routing/utils.py
@@ -16,6 +16,13 @@ class StreamingMode(IntEnum):
     JSON = 2


+class LLMCacheMode(IntEnum):
+    OFF = 0
+    IN_MEMORY = 1
+    REDIS = 2
+    GPTCACHE = 3
+
+
 def create_langchain_dependency(langchain_object: Type[Chain]) -> params.Depends:
     """Creates a langchain object dependency."""

6 changes: 6 additions & 0 deletions pyproject.toml
@@ -16,6 +16,8 @@ langchain = ">=0.0.183"
 urllib3 = "<=1.26.15" # added due to poetry errors
 python-dotenv = "^1.0.0"
 typing-extensions = "4.5.0" # added due to https://github.com/hwchase17/langchain/issues/5113
+redis = {version = "^4.5.5", optional = true}
+gptcache = {version = "^0.1.28", optional = true}

 [tool.poetry.group.dev.dependencies]
 pre-commit = "^3.3.1"
@@ -36,6 +38,10 @@ pytest-cov = "^4.0.0"
 pytest-asyncio = "^0.21.0"
 coveralls = "^3.3.1"

+[tool.poetry.extras]
+redis = ["redis"]
+gptcache = ["gptcache"]
+
 [build-system]
 requires = ["poetry-core"]
 build-backend = "poetry.core.masonry.api"