From 57f6a11e58a256ad7fc6da091e7a0673c2fba289 Mon Sep 17 00:00:00 2001 From: swissnyf Date: Sat, 24 Feb 2024 14:44:16 +0530 Subject: [PATCH] fixed topgun --- pipeline/top_gun.py | 104 +++++++++++++++++++++++++++++--------------- 1 file changed, 69 insertions(+), 35 deletions(-) diff --git a/pipeline/top_gun.py b/pipeline/top_gun.py index 22b7607..c1014d5 100644 --- a/pipeline/top_gun.py +++ b/pipeline/top_gun.py @@ -172,17 +172,24 @@ class TopGun(BasePipeline): tool_defs = [] tool_dict = {} - def __init__(self, filter_method, llm): + def __init__(self, filter_method, llm, codesynth_cache=False): self.filter_method = filter_method self.llm = llm + self.codesynth_cache = codesynth_cache + + def save_cache(self): + self.code_synth.save_cache() def set_tools(self, tool_descs:List[str], tool_names:List[str]): for tool_desc, tool_name in tqdm(zip(tool_descs,tool_names)): self.add_tool(tool_desc, tool_name) - def init_tools(self): + def init_tools(self, mapping=None): exec_string_tool_def = "\n\n\n".join(self._topgun_corpus) - exec(f"""{exec_string_tool_def}""") + if mapping: + for k,v in mapping.items(): + exec_string_tool_def = exec_string_tool_def.replace(v,k) + exec(f"""{exec_string_tool_def}""", globals()) def __set_tools(self, tool_defs:List[str] ): self.tool_defs = tool_defs @@ -193,12 +200,10 @@ def add_tool(self, new_tool_desc, tool_name, max_retries=10): retries = 0 while not completed and retries