Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BFCL] Adding actionGemma model handler #610

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -532,7 +532,13 @@
"https://huggingface.co/Salesforce/xLAM-7b-fc-r",
"Salesforce",
"cc-by-nc-4.0",
]
],
"KishoreK/ActionGemma-9B": [
"ActionGemma-9B (FC)",
"https://huggingface.co/KishoreK/ActionGemma-9B",
"KishoreK",
"MIT",
],
}

INPUT_PRICE_PER_MILLION_TOKEN = {
Expand Down Expand Up @@ -673,7 +679,8 @@
"ibm-granite/granite-20b-functioncalling",
"THUDM/glm-4-9b-chat",
"Salesforce/xLAM-1b-fc-r",
"Salesforce/xLAM-7b-fc-r"
"Salesforce/xLAM-7b-fc-r",
"KishoreK/ActionGemma-9B",
]

# Price got from AZure, 22.032 per hour for 8 V100, Pay As You Go Total Price
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
import json

from model_handler.oss_handler import OSSHandler
from model_handler.model_style import ModelStyle

SYSTEM_PROMPT ="""<bos>
<start_of_turn>system
You are an AI assistant for function calling.
For politically sensitive questions, security and privacy issues,
and other non-computer science questions, you will refuse to answer\n
""".strip()

TASK_INSTRUCTION = """
You are an expert in composing functions. You are given a question and a set of possible functions.
Based on the question, you will need to make one or more function/tool calls to achieve the purpose.
If none of the function can be used, point it out and refuse to answer.
If the given question lacks the parameters required by the function, also point it out.
""".strip()


FORMAT_INSTRUCTION = """
The output MUST strictly adhere to the following JSON format, and NO other text MUST be included.
The example format is as follows. Please make sure the parameter type is correct. If no function call is needed, please make tool_calls an empty list '[]'
```
{
"tool_calls": [
{"name": "func_name1", "arguments": {"argument1": "value1", "argument2": "value2"}},
... (more tool calls as required)
]
}
```
"""

class actionGemmaHandler(OSSHandler):
    """Handler for the KishoreK/ActionGemma-9B open-source FC model.

    Builds Gemma-turn-formatted prompts carrying an xLAM-style tool list,
    and decodes the model's ``{"tool_calls": [...]}`` JSON output into the
    AST / executable formats expected by the BFCL evaluator.
    """

    def __init__(self, model_name, temperature=0.001, top_p=1, max_tokens=512, dtype="bfloat16") -> None:
        super().__init__(model_name, temperature, top_p, max_tokens, dtype)
        self.model_style = ModelStyle.OSSMODEL

    # NOTE: deliberately defined WITHOUT `self` — it is passed as a plain
    # function through the `format_prompt_func` default of `inference` below
    # (same pattern as the xLAM handler).
    def _format_prompt(query, functions, test_category):
        def convert_to_xlam_tool(tools):
            """Convert the Gorilla function-call schema to the xLAM schema."""
            if isinstance(tools, dict):
                xlam_tools = {
                    "name": tools["name"],
                    "description": tools["description"],
                    "parameters": tools["parameters"].get("properties", {}),
                }
                # Flag required parameters on the property entry itself.
                # Guard against a `required` name with no matching property,
                # which would otherwise raise KeyError on malformed schemas.
                for param in tools["parameters"].get("required", []):
                    if param in xlam_tools["parameters"]:
                        xlam_tools["parameters"][param]["required"] = True
            elif isinstance(tools, list):
                xlam_tools = [convert_to_xlam_tool(tool) for tool in tools]
            else:
                # Unknown shape: pass through unchanged.
                xlam_tools = tools
            return xlam_tools

        tools = convert_to_xlam_tool(functions)
        if isinstance(tools, dict):
            tools = [tools]

        content = f"\n{TASK_INSTRUCTION}\n<end_of_turn>\n"
        # FORMAT_INSTRUCTION intentionally omitted for now:
        # content += f"{FORMAT_INSTRUCTION}\n<end_of_turn>\n\n"
        # <unused0>/<unused1> delimit the JSON tool list in the prompt.
        content += "<unused0>\n" + json.dumps(tools) + "\n<unused1>\n\n"

        content += f"<start_of_turn>user\n{query}<end_of_turn>\n\n"
        return SYSTEM_PROMPT + f"\n{content}\n<start_of_turn>assistant"

    def inference(
        self, test_question, num_gpus, gpu_memory_utilization, format_prompt_func=_format_prompt
    ):
        """Run OSSHandler inference with the ActionGemma prompt formatter."""
        print(f"[INFO] >> {num_gpus}, {gpu_memory_utilization}")
        return super().inference(
            test_question, num_gpus, gpu_memory_utilization, format_prompt_func
        )

    def decode_ast(self, result, language="Python"):
        """Decode raw model output into a list of {name: arguments} dicts."""
        return self.convert_to_dict(result)

    @staticmethod
    def xlam_json_to_python_tool_calls(tool_calls):
        """
        Converts a list of function calls in xLAM JSON format to Python format.

        Parameters:
        tool_calls (list): A list of dictionaries, where each dictionary represents a function call in xLAM JSON format.

        Returns:
        python_format (list): A list of strings, where each string is a function call in Python format.
        """
        if not isinstance(tool_calls, list):
            tool_calls = [tool_calls]

        python_format = []
        for tool_call in tool_calls:
            if isinstance(tool_call, dict):
                name = tool_call.get('name', "")
                arguments = tool_call.get('arguments', {})
                args_str = ', '.join([f"{key}={repr(value)}" for key, value in arguments.items()])
                python_format.append(f"{name}({args_str})")
            else:
                # Non-dict entries are skipped (logged, not raised) so one
                # malformed call does not discard the rest of the batch.
                print(f"Invalid format: {tool_call}")

        return python_format

    def decode_execute(self, result):
        """Decode raw model output into a list of Python-call strings.

        Returns the raw string unchanged when it is not valid JSON, so the
        evaluator can score it as a formatting failure.
        """
        try:
            result_json = json.loads(result)
        except json.JSONDecodeError:
            # Narrow catch (was a bare `except:`): only malformed JSON should
            # fall through to the raw-string path.
            return result
        if isinstance(result_json, list):
            tool_calls = result_json
        elif isinstance(result_json, dict):
            tool_calls = result_json.get('tool_calls', [])
        else:
            tool_calls = []
        return self.xlam_json_to_python_tool_calls(tool_calls)

    def convert_to_dict(self, input_str):
        """
        Convert a JSON-formatted string into a dictionary of tool calls and their arguments.

        Parameters:
        - input_str (str): A JSON-formatted string containing 'tool_calls' with 'name' and 'arguments'.

        Returns:
        - list[dict]: A list of dictionaries with tool call names as keys and their arguments as values.
        """
        try:
            data = json.loads(input_str)
        except json.JSONDecodeError:
            # Not JSON at all — return the raw string for downstream scoring.
            return input_str

        tool_calls = data if isinstance(data, list) else data.get('tool_calls', [])

        return [
            {tool_call.get('name', ''): tool_call.get('arguments', {})}
            for tool_call in tool_calls if isinstance(tool_call, dict)
        ]
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from model_handler.glm_handler import GLMHandler
from model_handler.yi_handler import YiHandler
from model_handler.xlam_handler import xLAMHandler
from model_handler.actionGemma_handler import actionGemmaHandler

handler_map = {
"gorilla-openfunctions-v0": GorillaHandler,
Expand Down Expand Up @@ -95,5 +96,6 @@
"THUDM/glm-4-9b-chat": GLMHandler,
"yi-large-fc": YiHandler,
"Salesforce/xLAM-1b-fc-r": xLAMHandler,
"Salesforce/xLAM-7b-fc-r": xLAMHandler
"Salesforce/xLAM-7b-fc-r": xLAMHandler,
"KishoreK/ActionGemma-9B":actionGemmaHandler,
}