Skip to content

Commit

Permalink
feat: improve connection handling (#78)
Browse files Browse the repository at this point in the history
* ➕ add dependencies

* ⚡ remove print statements

* ✨ add openai_aiosession decorator
  • Loading branch information
ajndkr authored May 31, 2023
1 parent 31c3ea8 commit e8e8db5
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 5 deletions.
36 changes: 35 additions & 1 deletion lanarky/responses/streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,46 @@
* `gist@ninely <https://gist.github.com/ninely/88485b2e265d852d3feb8bd115065b1a>`_
* `langchain@#1705 <https://github.com/hwchase17/langchain/discussions/1706>`_
"""
import logging
from functools import wraps
from typing import Any, Awaitable, Callable, Optional, Union

import aiohttp
from fastapi.responses import StreamingResponse as _StreamingResponse
from langchain.chains.base import Chain
from starlette.background import BackgroundTask
from starlette.types import Send
from starlette.types import Receive, Scope, Send

from lanarky.callbacks import get_streaming_callback, get_streaming_json_callback

logger = logging.getLogger(__name__)


def openai_aiosession(func):
    """Decorator that provisions an ``aiohttp.ClientSession`` for openai.

    Before awaiting ``func``, a fresh ``aiohttp.ClientSession`` is stored in
    ``openai.aiosession`` (a contextvar the openai client uses for async HTTP
    requests); the session is always closed afterwards, even if ``func``
    raises.

    Raises:
        ImportError: if the optional ``openai`` dependency is not installed.
    """

    @wraps(func)
    async def wrapper(*args, **kwargs):
        # Import lazily so lanarky works without the optional openai extra.
        try:
            import openai  # type: ignore
        except ImportError as e:
            raise ImportError(
                "openai is not installed. Install it with `pip install 'lanarky[openai]'`."
            ) from e

        # Bind the session to a local so we close exactly the session we
        # opened, even if the contextvar is reset inside `func`.
        session = aiohttp.ClientSession()
        openai.aiosession.set(session)
        logger.info("openai.aiosession set: %s", session)

        try:
            # Propagate func's return value (the original dropped it).
            return await func(*args, **kwargs)
        finally:
            await session.close()
            logger.info("openai.aiosession closed: %s", session)

    return wrapper


# TODO: create OpenAIStreamingResponse for streaming with OpenAI only
class StreamingResponse(_StreamingResponse):
"""StreamingResponse class wrapper for langchain chains."""

Expand Down Expand Up @@ -54,6 +84,10 @@ async def stream_response(self, send: Send) -> None:

await send({"type": "http.response.body", "body": b"", "more_body": False})

    @openai_aiosession
    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """ASGI entry point; delegates to the parent ``StreamingResponse``.

        Wrapped with ``openai_aiosession`` so an ``aiohttp.ClientSession`` is
        set on ``openai.aiosession`` for the duration of the response and
        closed afterwards.

        NOTE(review): the decorator runs for every request, even for chains
        that never call openai — confirm this per-request session is intended.
        """
        await super().__call__(scope, receive, send)

@staticmethod
def _create_chain_executor(
chain: Chain,
Expand Down
5 changes: 1 addition & 4 deletions lanarky/routing/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,11 +79,8 @@ async def endpoint(
langchain_object: Chain = langchain_dependency,
) -> StreamingResponse:
"""Streaming chat endpoint."""
print(f"langchain_object: {langchain_object}")
inputs = request.dict()
print(f"inputs: {inputs}")
return StreamingResponse.from_chain(
langchain_object, inputs, media_type="text/event-stream"
langchain_object, request.dict(), media_type="text/event-stream"
)

return endpoint
Expand Down
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ python-dotenv = "^1.0.0"
typing-extensions = "4.5.0" # added due to https://github.com/hwchase17/langchain/issues/5113
redis = {version = "^4.5.5", optional = true}
gptcache = {version = "^0.1.28", optional = true}
openai = {version = "^0.27.7", optional = true}
tiktoken = {version = "^0.4.0", optional = true}

[tool.poetry.group.dev.dependencies]
pre-commit = "^3.3.1"
Expand All @@ -39,6 +41,7 @@ pytest-asyncio = "^0.21.0"
coveralls = "^3.3.1"

[tool.poetry.extras]
openai = ["openai", "tiktoken"]
redis = ["redis"]
gptcache = ["gptcache"]

Expand Down

0 comments on commit e8e8db5

Please sign in to comment.