fix: add name_prefix for creating dynamic request and response models (#70)

* ⚡ use name_prefix to create different pydantic models

* ✏️ fix docs

* 🔨 update examples
  - 📝 update examples readme
  - ⬆️ bump dependencies

* ✅ fix unit test
ajndkr authored May 29, 2023
1 parent 6ecc1af commit 8468957
Showing 11 changed files with 105 additions and 189 deletions.
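
The gist of the change, per the commit title and first bullet: `LangchainRouter` builds its pydantic request/response models dynamically from the chain it wraps, and a `name_prefix` now keeps those model names unique when several routes are registered on one app. A minimal sketch of the idea, assuming lanarky derives fields from the chain's input keys (the helper and field types below are illustrative, not lanarky's actual internals):

```python
from pydantic import create_model


def create_request_model(chain, name_prefix: str = ""):
    """Hypothetical helper: build a request model from a chain's input keys.

    The name_prefix gives each route's model a distinct class name, so two
    routes on the same app do not collide in the OpenAPI schema.
    """
    fields = {key: (str, ...) for key in chain.input_keys}  # str type assumed
    return create_model(f"{name_prefix}Request", **fields)


# e.g. create_request_model(chain, name_prefix="Chat") -> a ChatRequest model
```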
2 changes: 1 addition & 1 deletion README.md
@@ -54,7 +54,7 @@ load_dotenv()
 app = FastAPI()
 
 langchain_router = LangchainRouter(
-    url="/chat",
+    langchain_url="/chat",
     langchain_object=ConversationChain(
         llm=ChatOpenAI(temperature=0), verbose=True
     ),
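
For readers skimming the diff, the README snippet now reads in full (imports reconstructed from the example apps changed in this commit):

```python
from dotenv import load_dotenv
from fastapi import FastAPI
from langchain import ConversationChain
from langchain.chat_models import ChatOpenAI

from lanarky.routing import LangchainRouter

load_dotenv()

app = FastAPI()

langchain_router = LangchainRouter(
    langchain_url="/chat",  # renamed from url=
    langchain_object=ConversationChain(
        llm=ChatOpenAI(temperature=0), verbose=True
    ),
)

app.include_router(langchain_router)
```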
2 changes: 1 addition & 1 deletion docs/getting_started.rst
@@ -19,7 +19,7 @@ You can get quickly started with Lanarky and deploy your first Langchain app in
 app = FastAPI()
 langchain_router = LangchainRouter(
-    url="/chat",
+    langchain_url="/chat",
     langchain_object=ConversationChain(
         llm=ChatOpenAI(temperature=0), verbose=True
     ),
4 changes: 2 additions & 2 deletions docs/langchain/deploy.rst
@@ -19,7 +19,7 @@ To better understand ``LangchainRouter``, let's break down the example below:
 app = FastAPI()
 langchain_router = LangchainRouter(
-    url="/chat",
+    langchain_url="/chat",
     langchain_object=ConversationChain(
         llm=ChatOpenAI(temperature=0),
         verbose=True
@@ -46,7 +46,7 @@ Here's an example:
 app = FastAPI()
 langchain_router = LangchainRouter(
-    url="/chat",
+    langchain_url="/chat",
     langchain_object=ConversationChain(
         llm=ChatOpenAI(temperature=0, streaming=True),
         verbose=True
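
For existing apps, the upgrade is a one-line rename of the keyword argument, taken straight from these hunks:

```python
# before
langchain_router = LangchainRouter(url="/chat", langchain_object=chain)
# after
langchain_router = LangchainRouter(langchain_url="/chat", langchain_object=chain)
```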
6 changes: 3 additions & 3 deletions examples/README.md
@@ -75,7 +75,7 @@ uvicorn app.conversation_chain:app --reload
 ```bash
 curl -N -X POST \
   -H "Accept: text/event-stream" -H "Content-Type: application/json" \
-  -d '{"query": "write me a song about sparkling water" }' \
+  -d '{"input": "write me a song about sparkling water" }' \
   http://localhost:8000/chat
 ```

@@ -92,7 +92,7 @@ uvicorn app.retrieval_qa_w_sources:app --reload
 ```bash
 curl -N -X POST \
   -H "Accept: text/event-stream" -H "Content-Type: application/json" \
-  -d '{"query": "Give me list of text splitters available with code samples" }' \
+  -d '{"question": "Give me list of text splitters available with code samples" }' \
   http://localhost:8000/chat
 ```

@@ -138,6 +138,6 @@ uvicorn app.zero_shot_agent:app --reload
 ```bash
 curl -N -X POST \
   -H "Accept: text/event-stream" -H "Content-Type: application/json" \
-  -d '{"query": "what is the square root of 64?" }' \
+  -d '{"input": "what is the square root of 64?" }' \
   http://localhost:8000/chat
 ```
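
The request-body keys change because the dynamically generated request models now mirror each chain's own input keys instead of a hard-coded `query` field: `ConversationChain` and the zero-shot agent expect `input`, while `RetrievalQAWithSourcesChain` expects `question`. A quick way to check which key a chain wants (short sketch; constructing `ChatOpenAI` requires `OPENAI_API_KEY`):

```python
from langchain import ConversationChain
from langchain.chat_models import ChatOpenAI

chain = ConversationChain(llm=ChatOpenAI(temperature=0))
print(chain.input_keys)  # ['input'] -> send '{"input": ...}' in the request body
```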
73 changes: 21 additions & 52 deletions examples/app/conversation_chain.py
@@ -1,72 +1,41 @@
-from functools import lru_cache
-from typing import Callable
-
 from dotenv import load_dotenv
-from fastapi import Depends, FastAPI, Request, WebSocket
+from fastapi import FastAPI, Request
 from fastapi.templating import Jinja2Templates
 from langchain import ConversationChain
 from langchain.chat_models import ChatOpenAI
-from pydantic import BaseModel
 
-from lanarky.responses import StreamingResponse
+from lanarky.routing import LangchainRouter
 from lanarky.testing import mount_gradio_app
-from lanarky.websockets import WebsocketConnection
 
 load_dotenv()
 
-app = mount_gradio_app(FastAPI(title="ConversationChainDemo"))
-templates = Jinja2Templates(directory="templates")
-
-
-class QueryRequest(BaseModel):
-    query: str
-
-
-def conversation_chain_dependency() -> Callable[[], ConversationChain]:
-    @lru_cache(maxsize=1)
-    def dependency() -> ConversationChain:
-        return ConversationChain(
-            llm=ChatOpenAI(
-                temperature=0,
-                streaming=True,
-            ),
-            verbose=True,
-        )
-
-    return dependency
-
-
-conversation_chain = conversation_chain_dependency()
-
-
-@app.post("/chat")
-async def chat(
-    request: QueryRequest,
-    chain: ConversationChain = Depends(conversation_chain),
-) -> StreamingResponse:
-    return StreamingResponse.from_chain(
-        chain, request.query, media_type="text/event-stream"
+def create_chain():
+    return ConversationChain(
+        llm=ChatOpenAI(
+            temperature=0,
+            streaming=True,
+        ),
+        verbose=True,
     )
 
 
-@app.post("/chat_json")
-async def chat_json(
-    request: QueryRequest,
-    chain: ConversationChain = Depends(conversation_chain),
-) -> StreamingResponse:
-    return StreamingResponse.from_chain(
-        chain, request.query, as_json=True, media_type="text/event-stream"
-    )
+app = mount_gradio_app(FastAPI(title="ConversationChainDemo"))
+templates = Jinja2Templates(directory="templates")
+chain = create_chain()
 
 
 @app.get("/")
 async def get(request: Request):
     return templates.TemplateResponse("index.html", {"request": request})
 
 
-@app.websocket("/ws")
-async def websocket_endpoint(
-    websocket: WebSocket, chain: ConversationChain = Depends(conversation_chain)
-):
-    connection = WebsocketConnection.from_chain(chain=chain, websocket=websocket)
-    await connection.connect()
+langchain_router = LangchainRouter(
+    langchain_url="/chat", langchain_object=chain, streaming_mode=1
+)
+langchain_router.add_langchain_api_route(
+    "/chat_json", langchain_object=chain, streaming_mode=2
+)
+langchain_router.add_langchain_api_websocket_route("/ws", langchain_object=chain)
+
+app.include_router(langchain_router)
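
Read against the deleted routes, `streaming_mode=1` takes over the plain token-streaming `/chat` handler and `streaming_mode=2` the `as_json=True` variant, so the JSON endpoint is exercised the same way as before (same prompt as the README; the path and key come from the code above):

```bash
curl -N -X POST \
  -H "Accept: text/event-stream" -H "Content-Type: application/json" \
  -d '{"input": "write me a song about sparkling water" }' \
  http://localhost:8000/chat_json
```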
96 changes: 31 additions & 65 deletions examples/app/retrieval_qa_w_sources.py
@@ -1,86 +1,52 @@
-from functools import lru_cache
-from typing import Callable
-
 from dotenv import load_dotenv
-from fastapi import Depends, FastAPI, Request, WebSocket
+from fastapi import FastAPI, Request
 from fastapi.templating import Jinja2Templates
 from langchain.chains import RetrievalQAWithSourcesChain
 from langchain.chat_models import ChatOpenAI
-from pydantic import BaseModel
+from langchain.embeddings import OpenAIEmbeddings
+from langchain.vectorstores import FAISS
 
-from lanarky.responses import StreamingResponse
+from lanarky.routing import LangchainRouter
 from lanarky.testing import mount_gradio_app
-from lanarky.websockets import WebsocketConnection
 
 load_dotenv()
 
-app = mount_gradio_app(FastAPI(title="RetrievalQAWithSourcesChainDemo"))
-
-templates = Jinja2Templates(directory="templates")
-
-
-class QueryRequest(BaseModel):
-    query: str
-
-
-def retrieval_qa_chain_dependency() -> Callable[[], RetrievalQAWithSourcesChain]:
-    @lru_cache(maxsize=1)
-    def dependency() -> RetrievalQAWithSourcesChain:
-        from langchain.embeddings import OpenAIEmbeddings
-        from langchain.vectorstores import FAISS
-
-        db = FAISS.load_local(
-            folder_path="vector_stores/",
-            index_name="langchain-python",
-            embeddings=OpenAIEmbeddings(),
-        )
-
-        return RetrievalQAWithSourcesChain.from_chain_type(
-            llm=ChatOpenAI(
-                temperature=0,
-                streaming=True,
-            ),
-            chain_type="stuff",
-            retriever=db.as_retriever(),
-            return_source_documents=True,
-            verbose=True,
-        )
-
-    return dependency
-
-
-retrieval_qa_chain = retrieval_qa_chain_dependency()
-
+def create_chain():
+    db = FAISS.load_local(
+        folder_path="vector_stores/",
+        index_name="langchain-python",
+        embeddings=OpenAIEmbeddings(),
+    )
 
-@app.post("/chat")
-async def chat(
-    request: QueryRequest,
-    chain: RetrievalQAWithSourcesChain = Depends(retrieval_qa_chain),
-) -> StreamingResponse:
-    return StreamingResponse.from_chain(
-        chain, request.query, media_type="text/event-stream"
+    return RetrievalQAWithSourcesChain.from_chain_type(
+        llm=ChatOpenAI(
+            temperature=0,
+            streaming=True,
+        ),
+        chain_type="stuff",
+        retriever=db.as_retriever(),
+        return_source_documents=True,
+        verbose=True,
     )
 
 
-@app.post("/chat_json")
-async def chat_json(
-    request: QueryRequest,
-    chain: RetrievalQAWithSourcesChain = Depends(retrieval_qa_chain),
-) -> StreamingResponse:
-    return StreamingResponse.from_chain(
-        chain, request.query, as_json=True, media_type="text/event-stream"
-    )
+app = mount_gradio_app(FastAPI(title="RetrievalQAWithSourcesChainDemo"))
+templates = Jinja2Templates(directory="templates")
+chain = create_chain()
 
 
 @app.get("/")
 async def get(request: Request):
     return templates.TemplateResponse("index.html", {"request": request})
 
 
-@app.websocket("/ws")
-async def websocket_endpoint(
-    websocket: WebSocket,
-    chain: RetrievalQAWithSourcesChain = Depends(retrieval_qa_chain),
-):
-    connection = WebsocketConnection.from_chain(chain=chain, websocket=websocket)
-    await connection.connect()
+langchain_router = LangchainRouter(
+    langchain_url="/chat", langchain_object=chain, streaming_mode=1
+)
+langchain_router.add_langchain_api_route(
+    "/chat_json", langchain_object=chain, streaming_mode=2
+)
+langchain_router.add_langchain_api_websocket_route("/ws", langchain_object=chain)
+
+app.include_router(langchain_router)
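
Since the response models are generated from the chain as well, it helps to know what this chain emits: with `return_source_documents=True`, LangChain should report three output keys. A quick check, run from the `examples/` directory (needs the FAISS store and `OPENAI_API_KEY`; the printed list is what this configuration is expected to yield):

```python
from app.retrieval_qa_w_sources import chain

print(chain.output_keys)  # expected: ['answer', 'sources', 'source_documents']
```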
76 changes: 22 additions & 54 deletions examples/app/zero_shot_agent.py
@@ -1,74 +1,42 @@
-from functools import lru_cache
-from typing import Callable
-
 from dotenv import load_dotenv
-from fastapi import Depends, FastAPI, Request, WebSocket
+from fastapi import FastAPI, Request
 from fastapi.templating import Jinja2Templates
 from langchain.agents import AgentExecutor, AgentType, initialize_agent, load_tools
 from langchain.chat_models import ChatOpenAI
-from pydantic import BaseModel
 
-from lanarky.responses import StreamingResponse
+from lanarky.routing import LangchainRouter
 from lanarky.testing import mount_gradio_app
-from lanarky.websockets import WebsocketConnection
 
 load_dotenv()
 
-app = mount_gradio_app(FastAPI(title="ZeroShotAgentDemo"))
-templates = Jinja2Templates(directory="templates")
-
-
-class QueryRequest(BaseModel):
-    query: str
-
-
-def zero_shot_agent_dependency() -> Callable[[], AgentExecutor]:
-    @lru_cache(maxsize=1)
-    def dependency() -> AgentExecutor:
-        llm = ChatOpenAI(
-            temperature=0,
-            streaming=True,
-        )
-        tools = load_tools(["llm-math"], llm=llm)
-        agent = initialize_agent(
-            tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True
-        )
-        return agent
-
-    return dependency
-
-
-zero_shot_agent = zero_shot_agent_dependency()
-
-
-@app.post("/chat")
-async def chat(
-    request: QueryRequest,
-    agent: AgentExecutor = Depends(zero_shot_agent),
-) -> StreamingResponse:
-    return StreamingResponse.from_chain(
-        agent, request.query, media_type="text/event-stream"
+def create_chain() -> AgentExecutor:
+    llm = ChatOpenAI(
+        temperature=0,
+        streaming=True,
     )
+    tools = load_tools(["llm-math"], llm=llm)
+    return initialize_agent(
+        tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True
+    )
 
 
-@app.post("/chat_json")
-async def chat_json(
-    request: QueryRequest,
-    agent: AgentExecutor = Depends(zero_shot_agent),
-) -> StreamingResponse:
-    return StreamingResponse.from_chain(
-        agent, request.query, as_json=True, media_type="text/event-stream"
-    )
+app = mount_gradio_app(FastAPI(title="ZeroShotAgentDemo"))
+templates = Jinja2Templates(directory="templates")
+chain = create_chain()
 
 
 @app.get("/")
 async def get(request: Request):
     return templates.TemplateResponse("index.html", {"request": request})
 
 
-@app.websocket("/ws")
-async def websocket_endpoint(
-    websocket: WebSocket, agent: AgentExecutor = Depends(zero_shot_agent)
-):
-    connection = WebsocketConnection.from_chain(chain=agent, websocket=websocket)
-    await connection.connect()
+langchain_router = LangchainRouter(
+    langchain_url="/chat", langchain_object=chain, streaming_mode=1
+)
+langchain_router.add_langchain_api_route(
+    "/chat_json", langchain_object=chain, streaming_mode=2
+)
+langchain_router.add_langchain_api_websocket_route("/ws", langchain_object=chain)
+
+app.include_router(langchain_router)
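
The commit message mentions fixing a unit test; a minimal smoke test against the rewritten app could look like this (hypothetical test, not the repository's actual suite; needs `OPENAI_API_KEY` set):

```python
from fastapi.testclient import TestClient

from app.zero_shot_agent import app


def test_chat_endpoint():
    client = TestClient(app)
    response = client.post("/chat", json={"input": "what is the square root of 64?"})
    assert response.status_code == 200
```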
2 changes: 1 addition & 1 deletion examples/requirements.txt
@@ -104,7 +104,7 @@ jsonschema==4.17.3
     # via altair
 kiwisolver==1.4.4
     # via matplotlib
-lanarky==0.7.0
+lanarky==0.7.2
     # via -r requirements.in
 langchain==0.0.183
     # via
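
To try the updated examples outside the pinned requirements file, installing the matching release directly should suffice (assuming 0.7.2 is the published version the pin refers to):

```bash
pip install "lanarky==0.7.2"
```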