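"""Utility helpers for ToG: dense and BM25 retrieval, LLM calls, answer
parsing and scoring, and dataset loading."""
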
from prompt_list import *
import json
import re
import time
import os

# Clear any proxy settings that could interfere with API requests.
os.environ.pop("http_proxy", None)
os.environ.pop("https_proxy", None)
os.environ.pop("all_proxy", None)

from openai import OpenAI
from sentence_transformers import util
from rank_bm25 import BM25Okapi

client = OpenAI(
    api_key="",  # set your OpenAI API key here
    base_url="https://api.openai.com/v1",
)

def retrieve_top_docs(query, docs, model, width=3):
    """
    Retrieve the top-n most relevant documents for the given query.

    Parameters:
    - query (str): The input query.
    - docs (list of str): The list of documents to search from.
    - model: The SentenceTransformer model used to embed the query and documents.
    - width (int): The number of top documents to return.

    Returns:
    - list of str: The top-n documents.
    - list of float: Their corresponding similarity scores.
    """
    query_emb = model.encode(query)
    doc_emb = model.encode(docs)
    scores = util.dot_score(query_emb, doc_emb)[0].cpu().tolist()

    doc_score_pairs = sorted(zip(docs, scores), key=lambda x: x[1], reverse=True)
    top_docs = [pair[0] for pair in doc_score_pairs[:width]]
    top_scores = [pair[1] for pair in doc_score_pairs[:width]]
    return top_docs, top_scores
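
# Illustrative usage of retrieve_top_docs (the model name below is only an
# example; any SentenceTransformer model works, and candidate_docs is any
# list of strings):
#   from sentence_transformers import SentenceTransformer
#   model = SentenceTransformer("all-MiniLM-L6-v2")
#   top_docs, top_scores = retrieve_top_docs("who founded Tesla", candidate_docs, model)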

def compute_bm25_similarity(query, corpus, width=3):
    """
    Computes the BM25 similarity between a query and a list of relations,
    and returns the top-n relations with the highest similarity along with
    their scores.

    Args:
    - query (str): Input query.
    - corpus (list of str): List of relations.
    - width (int): Number of top relations to return.

    Returns:
    - list, list: The top-n relations and their respective scores.
    """
    tokenized_corpus = [doc.split(" ") for doc in corpus]
    bm25 = BM25Okapi(tokenized_corpus)
    tokenized_query = query.split(" ")

    doc_scores = bm25.get_scores(tokenized_query)
    relations = bm25.get_top_n(tokenized_query, corpus, n=width)
    doc_scores = sorted(doc_scores, reverse=True)[:width]
    return relations, doc_scores
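
# Illustrative usage of compute_bm25_similarity on a small relation list
# (the relation names are examples):
#   rels, scores = compute_bm25_similarity(
#       "who is the spouse of Barack Obama",
#       ["people.person.spouse_s", "people.person.nationality"],
#       width=2)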

def if_all_zero(topn_scores):
    return all(score == 0 for score in topn_scores)

def clean_relations_bm25_sent(topn_relations, topn_scores, entity_id, head_relations):
    relations = []
    # If every BM25 score is zero, fall back to a uniform distribution.
    if if_all_zero(topn_scores):
        topn_scores = [float(1 / len(topn_scores))] * len(topn_scores)
    for relation, score in zip(topn_relations, topn_scores):
        relations.append({
            "entity": entity_id,
            "relation": relation,
            "score": score,
            "head": relation in head_relations,
        })
    return True, relations

def run_llm(prompt, temperature, max_tokens, opeani_api_keys, engine="gpt-3.5-turbo"):
    # Note: the opeani_api_keys argument is unused here; requests go through
    # the module-level `client` configured above.
    messages = [
        {"role": "system", "content": "You are an AI assistant that helps people find information."},
        {"role": "user", "content": prompt},
    ]
    print("start openai")
    while True:
        try:
            response = client.chat.completions.create(
                model=engine,
                messages=messages,
                temperature=temperature,
                max_tokens=max_tokens,
                frequency_penalty=0,
                presence_penalty=0)
            result = response.choices[0].message.content
            break
        except Exception as e:
            print("openai error, retry:", e)
            time.sleep(2)
    print("end openai")
    return result
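
# Illustrative one-off call (temperature 0, up to 256 tokens; the API-keys
# argument is ignored, see the note in run_llm):
#   answer = run_llm("Q: Who wrote Hamlet?\nA:", 0.0, 256, "", "gpt-3.5-turbo")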

def all_unknown_entity(entity_candidates):
    return all(candidate == "UnName_Entity" for candidate in entity_candidates)


def del_unknown_entity(entity_candidates):
    if len(entity_candidates) == 1 and entity_candidates[0] == "UnName_Entity":
        return entity_candidates
    entity_candidates = [candidate for candidate in entity_candidates if candidate != "UnName_Entity"]
    return entity_candidates

def clean_scores(string, entity_candidates):
    scores = re.findall(r'\d+\.\d+', string)
    scores = [float(number) for number in scores]
    if len(scores) == len(entity_candidates):
        return scores
    # Fall back to uniform scores when the LLM output cannot be parsed.
    print("All entities are created equal.")
    return [1 / len(entity_candidates)] * len(entity_candidates)

def save_2_jsonl(question, answer, cluster_chain_of_entities, file_name):
    record = {"question": question, "results": answer, "reasoning_chains": cluster_chain_of_entities}
    with open("ToG_{}.jsonl".format(file_name), "a") as outfile:
        json_str = json.dumps(record)
        outfile.write(json_str + "\n")
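
# Each call appends one JSON line to ToG_<file_name>.jsonl, e.g.:
#   {"question": "...", "results": "...", "reasoning_chains": [[...]]}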

def extract_answer(text):
    start_index = text.find("{")
    end_index = text.find("}")
    if start_index != -1 and end_index != -1:
        return text[start_index + 1:end_index].strip()
    else:
        return ""

def if_true(prompt):
    return prompt.lower().strip().replace(" ", "") == "yes"

def generate_without_explored_paths(question, args):
    prompt = cot_prompt + "\n\nQ: " + question + "\nA:"
    response = run_llm(prompt, args.temperature_reasoning, args.max_length, args.opeani_api_keys, args.LLM_type)
    return response

def if_finish_list(lst):
    if all(elem == "[FINISH_ID]" for elem in lst):
        return True, []
    else:
        new_lst = [elem for elem in lst if elem != "[FINISH_ID]"]
        return False, new_lst

def construct_relation_prune_prompt(question, entity_name, total_relations, args):
    return extract_relation_prompt % (args.width, args.width) + question + '\nTopic Entity: ' + entity_name + '\nRelations: ' + '; '.join(total_relations) + "\nA: "


def construct_entity_score_prompt(question, relation, entity_candidates):
    return score_entity_candidates_prompt.format(question, relation) + "; ".join(entity_candidates) + '\nScore: '

def prepare_dataset(dataset_name):
    if dataset_name == 'cwq':
        with open('./data/cwq.json', encoding='utf-8') as f:
            datas = json.load(f)
        question_string = 'question'
    elif dataset_name == 'webqsp':
        with open('./data/WebQSP.json', encoding='utf-8') as f:
            datas = json.load(f)
        question_string = 'RawQuestion'
    else:
        print("dataset not found, you should pick from {cwq, webqsp, grailqa, simpleqa, qald, webquestions, trex, zeroshotre, creak}.")
        exit(-1)
    return datas, question_string
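
# Minimal smoke test (an illustrative sketch, not part of the ToG pipeline):
# it loads a dataset and sends the first raw question to the LLM with no graph
# exploration. Assumes ./data/cwq.json exists and that an API key has been set
# in `client` above.
if __name__ == "__main__":
    datas, question_string = prepare_dataset("cwq")
    first_question = datas[0][question_string]
    print(run_llm(first_question, 0.0, 256, "", "gpt-3.5-turbo"))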