-
Notifications
You must be signed in to change notification settings - Fork 0
/
streamlit_app.py
96 lines (70 loc) · 3.13 KB
/
streamlit_app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import os
import openai
import pinecone
from PIL import Image
import streamlit as st
from sentence_transformers import SentenceTransformer
# Name of the Pinecone index queried below for transcript matches.
INDEX_NAME = "shubhams-index"
# Sentence-transformers model used to embed the user's query — presumably the
# same model that produced the vectors stored in the index; verify upstream.
MODEL_NAME = "multi-qa-MiniLM-L6-cos-v1"
# Chat-message template for the LLM call; the [CONTEXT] and [QUERY] placeholders
# are substituted per request in format_prompt().
DEFAULT_PROMPT = [{"role": "user", "content": "Use the following context to answer the user query. If the user query is a question, provide an answer using the context. If the user query is a statement or a phrase, provide the best response using the context.\n\nContext:\n\n[CONTEXT]\n\nQuery: [QUERY]\n\nResponse:"}]
@st.cache_resource
def init_pinecone():
    """Initialise the Pinecone client from Streamlit secrets (cached per session)."""
    pinecone.init(
        environment="us-west4-gcp",
        api_key=st.secrets["PINECONE_API_KEY"],
    )
@st.cache_resource
def init_openai():
    """Configure the OpenAI module-level API key from Streamlit secrets (runs once)."""
    key = st.secrets["OPENAI_API_KEY"]
    openai.api_key = key
@st.cache_resource
def load_model(model_name):
    """Load (and cache) the sentence-transformer used to embed search queries."""
    return SentenceTransformer(model_name)


# Loaded once at import time; reused by query_index below.
model = load_model(MODEL_NAME)
@st.cache_data
def query_index(query, episode_title=None, num_results=5):
    """Embed *query*, search the Pinecone index, and return the top matches.

    Returns a list of (episode_number, episode_title, text) tuples taken from
    each match's metadata. When *episode_title* is given, results are filtered
    to that episode only.
    """
    init_pinecone()
    index = pinecone.Index(INDEX_NAME)
    embedding = model.encode(query, show_progress_bar=False).tolist()
    if episode_title:
        episode_filter = {"episode_title": {"$eq": episode_title}}
    else:
        episode_filter = None
    response = index.query(embedding, top_k=num_results, include_metadata=True, filter=episode_filter)
    matches = []
    for match in response['matches']:
        meta = match['metadata']
        matches.append((meta['episode_number'], meta['episode_title'], meta['text']))
    return matches
@st.cache_data
def format_prompt(results, query):
    """Build the chat prompt by filling DEFAULT_PROMPT's placeholders.

    *results* is a list of (episode_number, episode_title, text) tuples; each
    becomes one context paragraph. Returns a new message list — the template
    is copied so the module-level DEFAULT_PROMPT is never mutated.
    """
    paragraphs = [f"{number} ({title}): {text}\n\n" for number, title, text in results]
    context = "".join(paragraphs)
    messages = [dict(message) for message in DEFAULT_PROMPT]
    filled = messages[-1]["content"].replace("[CONTEXT]", context).replace("[QUERY]", query)
    messages[-1]["content"] = filled
    return messages
@st.cache_data
def get_model_response(prompt):
    """Send *prompt* (a chat message list) to gpt-3.5-turbo and return the reply.

    The response text is whitespace-stripped before being returned; identical
    prompts are served from Streamlit's data cache.
    """
    init_openai()
    completion = openai.ChatCompletion.create(
        model='gpt-3.5-turbo',
        messages=prompt,
        temperature=0.7,
        max_tokens=256,
        top_p=1,
        frequency_penalty=0,
        presence_penalty=0,
    )
    reply = completion['choices'][0]['message']
    return reply['content'].strip()
# --- Page layout ------------------------------------------------------------
st.title("Office Ladies Podcast Search")
image = Image.open('office-ladies-podcast-image.jpeg')
st.image(image)
st.markdown('<p style="text-align: left;">By Shubham Pawar<span style="float:right;"><a href="https://medium.com/@pawarshubham28794/searching-for-more-enhancing-the-office-ladies-podcast-with-semantic-search-14d914e04b5d">How was this app made?</a></span></p>', unsafe_allow_html=True)

query = st.text_input('Enter your search query:', placeholder='Search for a phrase or an answer in the podcast transcripts', label_visibility='hidden')

# --- Search flow ------------------------------------------------------------
if st.button('Search', type='primary'):
    # Guard: without this, an empty query is still embedded and sent through
    # a billed OpenAI call, returning meaningless matches.
    if not query.strip():
        st.warning('Please enter a search query first.')
    else:
        results = query_index(query)
        prompt = format_prompt(results, query)
        response = get_model_response(prompt)
        # LLM answer first, then the transcript excerpts that supported it.
        st.markdown(f":green[{response}]")
        for episode_number, episode_title, text in results:
            st.subheader(f":blue[{episode_number}: {episode_title}]")
            st.write(text)