Skip to content

Commit

Permalink
fixed topgun
Browse files Browse the repository at this point in the history
  • Loading branch information
swissnyf committed Feb 24, 2024
1 parent 54c47fd commit 57f6a11
Showing 1 changed file with 69 additions and 35 deletions.
104 changes: 69 additions & 35 deletions pipeline/top_gun.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,17 +172,24 @@ class TopGun(BasePipeline):
tool_defs = []
tool_dict = {}

def __init__(self, filter_method, llm):
def __init__(self, filter_method, llm, codesynth_cache=False):
self.filter_method = filter_method
self.llm = llm
self.codesynth_cache = codesynth_cache

def save_cache(self):
self.code_synth.save_cache()

def set_tools(self, tool_descs:List[str], tool_names:List[str]):
for tool_desc, tool_name in tqdm(zip(tool_descs,tool_names)):
self.add_tool(tool_desc, tool_name)

def init_tools(self):
def init_tools(self, mapping=None):
exec_string_tool_def = "\n\n\n".join(self._topgun_corpus)
exec(f"""{exec_string_tool_def}""")
if mapping:
for k,v in mapping.items():
exec_string_tool_def = exec_string_tool_def.replace(v,k)
exec(f"""{exec_string_tool_def}""", globals())

def __set_tools(self, tool_defs:List[str] ):
self.tool_defs = tool_defs
Expand All @@ -193,12 +200,10 @@ def add_tool(self, new_tool_desc, tool_name, max_retries=10):
retries = 0
while not completed and retries<max_retries:
try:
tool_def = self.code_synth.forward(tool_name, new_tool_desc, self.llm)
tool_def = self.code_synth.forward(tool_name, new_tool_desc, self.llm, cache=self.codesynth_cache)
formated_tool_def = self.format_tool_def(tool_def, tool_name)
completed = True
except Exception as e:
# if(retries == 1):
# print(tool_name)
print(f"Retrying with new tool def for {tool_name}, {e}")
retries+=1
try:
Expand All @@ -221,44 +226,73 @@ def format_tool_def(self, new_tool_def, tool_name):
new_tool_def = f"{new_tool_def[:idx-1]}@update_traverse_dict\n{new_tool_def[idx:]}"
return new_tool_def

def parser(self, response):

def parser(self, response, mapping = None):
pattern = r'```python\n(.*?)\n```'
matches = re.findall(pattern, response, re.DOTALL)
if True:
global function_tree, traverse_dict
traverse_dict = {}
function_tree = []
if matches:
response = matches[-1]
corpus_lib = "\n\n\n".join(self._topgun_corpus)
exec(f"""{corpus_lib}\n\n\n{response}""")
answer, prev_list = parse_output(traverse_dict, function_tree)
return answer
try:
if True:
global function_tree, traverse_dict
traverse_dict = {}
function_tree = []
if matches:
response = matches[-1]

corpus_lib = "\n\n\n".join(self._topgun_corpus)
# print(f"""{corpus_lib}\n\n\n{response}""")

if mapping is None:
exec(f"""{corpus_lib}\n\n\n{response}""", globals())
else:
response_replaced = response
corpus_lib_replaced = corpus_lib
for k,v in mapping.items():
response_replaced = response_replaced.replace(v,k)
corpus_lib_replaced = corpus_lib_replaced.replace(v,k)

exec(f"""{corpus_lib_replaced }\n\n{response_replaced}""", globals())

answer, prev_list = parse_output(traverse_dict, function_tree)
return answer

except Exception as ex:
temp_out = StringIO()
print(traceback.format_exc())
# print("\nException:",e, "\n")
# self.feedback = traceback.format_exc()
self.feedback = temp_out.getvalue()
return None

return response


def query(self, query, max_retries=10):
def query(self, query, mapping=None, max_retries=5):


# possible_tools_names = self.filter_method.filter(query)
# filtered_tools_corpus = [self.tool_dict[t] for t in possible_tools_names]
possible_tools_names = list(self.tool_dict.keys())
filtered_tools_corpus = [self.tool_dict[t] for t in possible_tools_names]
query_code_interpreter = self.QUERY_CODE_INTERPRETER_PROMPT.format("\n\n\n".join(filtered_tools_corpus).replace("@update_traverse_dict", ""), query)
query_code_interpreter = self.QUERY_CODE_INTERPRETER_PROMPT.format("\n\n\n".join(self._topgun_corpus).replace("@update_traverse_dict", ""), query)

completed = False
retries = 0
response_final = ''

response = self.llm.complete(query_code_interpreter)
yield response.text
parsed_response = self.parser(response.text, mapping)

while not completed and retries<max_retries:
try:
response = self.llm.complete(query_code_interpreter)
print(response)
if parsed_response is None:
retries+=1
# print(response.text, "\n")
response = self.llm.complete(self.REFLEXION_PROMPT.format(query_code_interpreter, response.text, self.feedback))
yield response.text
response_final = response.text
response = self.parser(response.text)
yield f"\n```javascript\n{json.dumps(response, indent=4)}\n```\n"
print(json.dumps(response, indent=4))
print("-------------------------------------------------------------------")
parsed_response = self.parser(response.text, mapping)
yield f"\n```javascript\n{json.dumps(parsed_response, indent=4)}\n```\n"
else:
yield f"\n```javascript\n{json.dumps(parsed_response, indent=4)}\n```\n"
completed = True
except Exception as e:
print(f"Retrying with new plan, exception: {e}")
retries+=1
return json.dumps(response_final, indent=4)
self.feedback = None

if completed:
return json.dumps(parsed_response, indent=4) # ToDO: Check
else:
return response.text

0 comments on commit 57f6a11

Please sign in to comment.