Skip to content

Commit

Permalink
Update conversation_chain.py
Browse files Browse the repository at this point in the history
  • Loading branch information
MahaleVivek authored Jun 1, 2023
1 parent 51d7595 commit e856c64
Showing 1 changed file with 1 addition and 27 deletions.
28 changes: 1 addition & 27 deletions examples/app/conversation_chain.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
from functools import lru_cache
from typing import Callable


from dotenv import load_dotenv
from fastapi import Depends, FastAPI, Request, WebSocket
from fastapi.templating import Jinja2Templates
Expand All @@ -10,47 +8,35 @@
from pydantic import BaseModel
import uvicorn


from lanarky.responses import StreamingResponse
from lanarky.testing import mount_gradio_app
from lanarky.websockets import WebsocketConnection




load_dotenv()


app = mount_gradio_app(FastAPI(title="ConversationChainDemo"))
templates = Jinja2Templates(directory="templates")




class QueryRequest(BaseModel):
query: str




def conversation_chain_dependency() -> Callable[[], ConversationChain]:
@lru_cache(maxsize=1)
def dependency() -> ConversationChain:


template = """You are a muscle university fitness chatbot having a conversation with a fitness enthusiast who is trying improve his fitness journey.
you can answer the question related to fitness/ workout plans/ diet/ grocery for healthy diet. add bullet points. you can answer the user queries
in Hinglish and use jeet selal's style to respond. restrcit the answers to fitness/ workout plans/ diet/ grocery for healthy diet. if you don't know the answer,
don't try to make up the answer, add detailed answers and add bullet points if needed.
{history}
Human: {input}
Chatbot:"""


PROMPT = PromptTemplate(
input_variables=["input", "history"],
template=template
Expand All @@ -64,18 +50,12 @@ def dependency() -> ConversationChain:
verbose=True,
prompt=PROMPT,
)


return dependency




conversation_chain = conversation_chain_dependency()




@app.post("/chat")
async def chat(
request: QueryRequest,
Expand All @@ -86,15 +66,11 @@ async def chat(
)




@app.get("/")
async def get(request: Request):
return templates.TemplateResponse("index.html", {"request": request})




@app.websocket("/ws")
async def websocket_endpoint(
websocket: WebSocket, chain: ConversationChain = Depends(conversation_chain)
Expand All @@ -103,7 +79,5 @@ async def websocket_endpoint(
await connection.connect()




if __name__ == "__main__":
uvicorn.run(app, host="0.0.0.0", port=8000)
uvicorn.run(app, host="0.0.0.0", port=8000)

0 comments on commit e856c64

Please sign in to comment.