Commit
fix: add name_prefix for creating dynamic request and response models (#70)

* ⚡ use name_prefix to create different pydantic models
* ✏️ fix docs
* 🔨 update examples
  - 📝 update examples readme
  - ⬆️ bump dependencies
* ✅ fix unit test
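The name_prefix change itself lives in the router internals rather than in the example diffs below, but the idea is straightforward: when request and response models are generated dynamically per route, models built from the same template share a class name and can collide, e.g. in the OpenAPI schema. A minimal sketch of the technique, assuming a pydantic `create_model` factory — illustrative only, not lanarky's actual implementation:

```python
# Illustrative sketch only -- not lanarky's actual implementation.
# A name_prefix keeps dynamically created pydantic models distinct
# when several routes would otherwise generate identically named ones.
from typing import List

from pydantic import create_model


def create_request_model(name_prefix: str, input_keys: List[str]):
    # One required string field per chain input key (assumed shape).
    fields = {key: (str, ...) for key in input_keys}
    return create_model(f"{name_prefix}Request", **fields)


ChatRequest = create_request_model("Chat", ["query"])
ChatJsonRequest = create_request_model("ChatJson", ["query"])
assert ChatRequest.__name__ != ChatJsonRequest.__name__
```

With a distinct prefix per route, each generated model gets its own name instead of shadowing its siblings.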
Showing 11 changed files with 105 additions and 189 deletions.
ConversationChain example:

```diff
@@ -1,72 +1,41 @@
-from functools import lru_cache
-from typing import Callable
-
 from dotenv import load_dotenv
-from fastapi import Depends, FastAPI, Request, WebSocket
+from fastapi import FastAPI, Request
 from fastapi.templating import Jinja2Templates
 from langchain import ConversationChain
 from langchain.chat_models import ChatOpenAI
-from pydantic import BaseModel
 
-from lanarky.responses import StreamingResponse
+from lanarky.routing import LangchainRouter
 from lanarky.testing import mount_gradio_app
-from lanarky.websockets import WebsocketConnection
 
 load_dotenv()
 
-app = mount_gradio_app(FastAPI(title="ConversationChainDemo"))
-templates = Jinja2Templates(directory="templates")
-
-
-class QueryRequest(BaseModel):
-    query: str
-
 
-def conversation_chain_dependency() -> Callable[[], ConversationChain]:
-    @lru_cache(maxsize=1)
-    def dependency() -> ConversationChain:
-        return ConversationChain(
-            llm=ChatOpenAI(
-                temperature=0,
-                streaming=True,
-            ),
-            verbose=True,
-        )
-
-    return dependency
-
-
-conversation_chain = conversation_chain_dependency()
-
-
-@app.post("/chat")
-async def chat(
-    request: QueryRequest,
-    chain: ConversationChain = Depends(conversation_chain),
-) -> StreamingResponse:
-    return StreamingResponse.from_chain(
-        chain, request.query, media_type="text/event-stream"
+def create_chain():
+    return ConversationChain(
+        llm=ChatOpenAI(
+            temperature=0,
+            streaming=True,
+        ),
+        verbose=True,
     )
 
 
-@app.post("/chat_json")
-async def chat_json(
-    request: QueryRequest,
-    chain: ConversationChain = Depends(conversation_chain),
-) -> StreamingResponse:
-    return StreamingResponse.from_chain(
-        chain, request.query, as_json=True, media_type="text/event-stream"
-    )
+app = mount_gradio_app(FastAPI(title="ConversationChainDemo"))
+templates = Jinja2Templates(directory="templates")
+chain = create_chain()
 
 
 @app.get("/")
 async def get(request: Request):
     return templates.TemplateResponse("index.html", {"request": request})
 
 
-@app.websocket("/ws")
-async def websocket_endpoint(
-    websocket: WebSocket, chain: ConversationChain = Depends(conversation_chain)
-):
-    connection = WebsocketConnection.from_chain(chain=chain, websocket=websocket)
-    await connection.connect()
+langchain_router = LangchainRouter(
+    langchain_url="/chat", langchain_object=chain, streaming_mode=1
+)
+langchain_router.add_langchain_api_route(
+    "/chat_json", langchain_object=chain, streaming_mode=2
+)
+langchain_router.add_langchain_api_websocket_route("/ws", langchain_object=chain)
+
+app.include_router(langchain_router)
```
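The hand-written `/chat`, `/chat_json`, and `/ws` handlers collapse into a single `LangchainRouter`. Judging by the handlers they replace, `streaming_mode=1` appears to correspond to the old plain token stream and `streaming_mode=2` to the `as_json=True` variant. A hypothetical client for the streaming route — the `input` field name is an assumption, presuming the dynamic request model mirrors `ConversationChain`'s input key, where the old hand-written `QueryRequest` used `query`:

```python
# Hypothetical client for the /chat route created by LangchainRouter.
# The JSON field name is an assumption: the dynamic request model is
# presumed to mirror ConversationChain's input key ("input"); the old
# hand-written QueryRequest used "query" instead.
import requests

with requests.post(
    "http://localhost:8000/chat",
    json={"input": "Hi there!"},
    stream=True,
) as response:
    response.raise_for_status()
    # Print tokens as the server streams them back.
    for chunk in response.iter_content(chunk_size=None, decode_unicode=True):
        print(chunk, end="", flush=True)
```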
RetrievalQAWithSourcesChain example:

```diff
@@ -1,86 +1,52 @@
-from functools import lru_cache
-from typing import Callable
-
 from dotenv import load_dotenv
-from fastapi import Depends, FastAPI, Request, WebSocket
+from fastapi import FastAPI, Request
 from fastapi.templating import Jinja2Templates
 from langchain.chains import RetrievalQAWithSourcesChain
 from langchain.chat_models import ChatOpenAI
-from pydantic import BaseModel
+from langchain.embeddings import OpenAIEmbeddings
+from langchain.vectorstores import FAISS
 
-from lanarky.responses import StreamingResponse
+from lanarky.routing import LangchainRouter
 from lanarky.testing import mount_gradio_app
-from lanarky.websockets import WebsocketConnection
 
 load_dotenv()
 
-app = mount_gradio_app(FastAPI(title="RetrievalQAWithSourcesChainDemo"))
-
-templates = Jinja2Templates(directory="templates")
-
-
-class QueryRequest(BaseModel):
-    query: str
-
 
-def retrieval_qa_chain_dependency() -> Callable[[], RetrievalQAWithSourcesChain]:
-    @lru_cache(maxsize=1)
-    def dependency() -> RetrievalQAWithSourcesChain:
-        from langchain.embeddings import OpenAIEmbeddings
-        from langchain.vectorstores import FAISS
-
-        db = FAISS.load_local(
-            folder_path="vector_stores/",
-            index_name="langchain-python",
-            embeddings=OpenAIEmbeddings(),
-        )
+def create_chain():
+    db = FAISS.load_local(
+        folder_path="vector_stores/",
+        index_name="langchain-python",
+        embeddings=OpenAIEmbeddings(),
+    )
 
-        return RetrievalQAWithSourcesChain.from_chain_type(
-            llm=ChatOpenAI(
-                temperature=0,
-                streaming=True,
-            ),
-            chain_type="stuff",
-            retriever=db.as_retriever(),
-            return_source_documents=True,
-            verbose=True,
-        )
-
-    return dependency
-
-
-retrieval_qa_chain = retrieval_qa_chain_dependency()
-
-
-@app.post("/chat")
-async def chat(
-    request: QueryRequest,
-    chain: RetrievalQAWithSourcesChain = Depends(retrieval_qa_chain),
-) -> StreamingResponse:
-    return StreamingResponse.from_chain(
-        chain, request.query, media_type="text/event-stream"
+    return RetrievalQAWithSourcesChain.from_chain_type(
+        llm=ChatOpenAI(
+            temperature=0,
+            streaming=True,
+        ),
+        chain_type="stuff",
+        retriever=db.as_retriever(),
+        return_source_documents=True,
+        verbose=True,
     )
 
 
-@app.post("/chat_json")
-async def chat_json(
-    request: QueryRequest,
-    chain: RetrievalQAWithSourcesChain = Depends(retrieval_qa_chain),
-) -> StreamingResponse:
-    return StreamingResponse.from_chain(
-        chain, request.query, as_json=True, media_type="text/event-stream"
-    )
+app = mount_gradio_app(FastAPI(title="RetrievalQAWithSourcesChainDemo"))
+templates = Jinja2Templates(directory="templates")
+chain = create_chain()
 
 
 @app.get("/")
 async def get(request: Request):
     return templates.TemplateResponse("index.html", {"request": request})
 
 
-@app.websocket("/ws")
-async def websocket_endpoint(
-    websocket: WebSocket,
-    chain: RetrievalQAWithSourcesChain = Depends(retrieval_qa_chain),
-):
-    connection = WebsocketConnection.from_chain(chain=chain, websocket=websocket)
-    await connection.connect()
+langchain_router = LangchainRouter(
+    langchain_url="/chat", langchain_object=chain, streaming_mode=1
+)
+langchain_router.add_langchain_api_route(
+    "/chat_json", langchain_object=chain, streaming_mode=2
+)
+langchain_router.add_langchain_api_websocket_route("/ws", langchain_object=chain)
+
+app.include_router(langchain_router)
```
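Note that `create_chain` now runs at import time, so the FAISS index under `vector_stores/` must exist before the app starts. A hypothetical one-off build script — the source texts are placeholders; only `folder_path` and `index_name` come from the example itself:

```python
# Hypothetical one-off script to build the index that create_chain loads.
# The documents are placeholders; only folder_path and index_name are
# taken from the example.
from langchain.embeddings import OpenAIEmbeddings
from langchain.vectorstores import FAISS

texts = [
    "LangChain is a framework for building applications with LLMs.",
    "FAISS provides efficient similarity search over dense vectors.",
]
db = FAISS.from_texts(texts, embedding=OpenAIEmbeddings())
db.save_local(folder_path="vector_stores/", index_name="langchain-python")
```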
ZeroShotAgent example:

```diff
@@ -1,74 +1,42 @@
-from functools import lru_cache
-from typing import Callable
-
 from dotenv import load_dotenv
-from fastapi import Depends, FastAPI, Request, WebSocket
+from fastapi import FastAPI, Request
 from fastapi.templating import Jinja2Templates
 from langchain.agents import AgentExecutor, AgentType, initialize_agent, load_tools
 from langchain.chat_models import ChatOpenAI
-from pydantic import BaseModel
 
-from lanarky.responses import StreamingResponse
+from lanarky.routing import LangchainRouter
 from lanarky.testing import mount_gradio_app
-from lanarky.websockets import WebsocketConnection
 
 load_dotenv()
 
-app = mount_gradio_app(FastAPI(title="ZeroShotAgentDemo"))
-templates = Jinja2Templates(directory="templates")
-
-
-class QueryRequest(BaseModel):
-    query: str
-
 
-def zero_shot_agent_dependency() -> Callable[[], AgentExecutor]:
-    @lru_cache(maxsize=1)
-    def dependency() -> AgentExecutor:
-        llm = ChatOpenAI(
-            temperature=0,
-            streaming=True,
-        )
-        tools = load_tools(["llm-math"], llm=llm)
-        agent = initialize_agent(
-            tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True
-        )
-        return agent
-
-    return dependency
-
-
-zero_shot_agent = zero_shot_agent_dependency()
-
-
-@app.post("/chat")
-async def chat(
-    request: QueryRequest,
-    agent: AgentExecutor = Depends(zero_shot_agent),
-) -> StreamingResponse:
-    return StreamingResponse.from_chain(
-        agent, request.query, media_type="text/event-stream"
+def create_chain() -> AgentExecutor:
+    llm = ChatOpenAI(
+        temperature=0,
+        streaming=True,
+    )
+    tools = load_tools(["llm-math"], llm=llm)
+    return initialize_agent(
+        tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True
    )
 
 
-@app.post("/chat_json")
-async def chat_json(
-    request: QueryRequest,
-    agent: AgentExecutor = Depends(zero_shot_agent),
-) -> StreamingResponse:
-    return StreamingResponse.from_chain(
-        agent, request.query, as_json=True, media_type="text/event-stream"
-    )
+app = mount_gradio_app(FastAPI(title="ZeroShotAgentDemo"))
+templates = Jinja2Templates(directory="templates")
+chain = create_chain()
 
 
 @app.get("/")
 async def get(request: Request):
     return templates.TemplateResponse("index.html", {"request": request})
 
 
-@app.websocket("/ws")
-async def websocket_endpoint(
-    websocket: WebSocket, agent: AgentExecutor = Depends(zero_shot_agent)
-):
-    connection = WebsocketConnection.from_chain(chain=agent, websocket=websocket)
-    await connection.connect()
+langchain_router = LangchainRouter(
+    langchain_url="/chat", langchain_object=chain, streaming_mode=1
+)
+langchain_router.add_langchain_api_route(
+    "/chat_json", langchain_object=chain, streaming_mode=2
+)
+langchain_router.add_langchain_api_websocket_route("/ws", langchain_object=chain)
+
+app.include_router(langchain_router)
```
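After the migration, all three examples expose the same layout: a streaming `/chat`, a JSON-streaming `/chat_json`, and a websocket `/ws`. A minimal sketch for exercising the websocket route with the third-party `websockets` package; the message schema lanarky exchanges on this socket is not shown in the diff, so the plain-text frame below is an assumption:

```python
# Minimal sketch using the third-party "websockets" package. Sending a
# plain-text frame is an assumption; the diff does not show the message
# schema lanarky's websocket connection expects.
import asyncio

import websockets


async def main() -> None:
    async with websockets.connect("ws://localhost:8000/ws") as ws:
        await ws.send("What is 5 + 7?")
        async for frame in ws:  # print whatever the server streams back
            print(frame)


asyncio.run(main())
```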