diff --git a/docs/config.qmd b/docs/config.qmd
index 120aec8933..ba23384f0c 100644
--- a/docs/config.qmd
+++ b/docs/config.qmd
@@ -127,34 +127,40 @@ datasets:
# - tokenizer_default_fallback_*: where * is the name of the chat template to fall back to when the tokenizer does not define one; otherwise the tokenizer's template is used. E.g. tokenizer_default_fallback_chatml.
# - jinja: Uses a custom jinja template for the chat template. The custom jinja template should be provided in the chat_template_jinja field.
chat_template: tokenizer_default
- # Custom jinja template for chat template. This will be only used if `chat_template` is set to `jinja` or empty (in which case chat_template is automatically set to `jinja`).
+
+ # Custom jinja chat template. Used only when `chat_template: jinja` or when `chat_template` is left empty (it is then set to `jinja` automatically).
chat_template_jinja:
- # The key in the data example that contains the messages. Default is "messages".
+
+ # Key containing the messages (default: "messages")
field_messages: messages
- # The key in the message turn that contains the role. Default is "role".
+ # Key for role in each message (default: "role")
message_field_role: role
- # The key in the message turn that contains the content. Default is "content".
+ # Key for content in each message (default: "content")
message_field_content: content
- # Optional[Dict[str, List]]. Roles mapping for the messages.
+
+ # Optional[Dict[str, List]]. Mapping of message role names to canonical roles. The defaults are:
roles:
user: ["human", "user"]
- assistant: ["gpt", "assistant", "ai"]
+ assistant: ["gpt", "assistant"]
system: ["system"]
+ tool: ["tool"]
- ## NOTE: Leaving the below empty will default to using the simple legacy tokenization strategy where only last message is trained on.
+ # IMPORTANT: The following fields determine which parts of the conversation to train on.
+ # Priority order: message_field_training > message_field_training_detail > train_on_inputs or role in roles_to_train
+ # See examples at `docs/dataset-formats/conversation.qmd`
+ # Note: If all four fields below are empty, training defaults to the last message only.
# Optional[List[str]]. Roles to train on. The tokens from these roles will be considered for the loss.
- roles_to_train: ["gpt", "assistant"]
+ roles_to_train: ["assistant"] # default
# Optional[str]. Which EOS tokens to train on in the conversation. Possible values are:
# - all: train on all EOS tokens
- # - turn: train on the EOS token at the end of each trainable turn
+ # - turn (default): train on the EOS token at the end of each trainable turn
# - last: train on the last EOS token in the conversation
train_on_eos: last
# The key in the message turn whose boolean value indicates whether the turn's tokens should be trained on. Useful to selectively train on certain turns beyond those matched by `roles_to_train`.
message_field_training: training
# The key in the message turn that contains the training details. Useful to selectively train on certain tokens in a turn.
# The value of the key is a List[Dict] containing `begin_offset` (start character index in content), `end_offset` (end character index in content), and `train` (boolean whether to train).
- # See example at `docs/dataset-formats/conversation.qmd`
message_field_training_detail: train_detail
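
To make the precedence above concrete, here is a minimal sketch of a dataset entry that states the new defaults explicitly alongside the per-turn override fields (the `path` is a placeholder; the comments restate the priority rules rather than add behavior):

```yaml
datasets:
  - path: ...
    type: chat_template
    # New defaults, written out explicitly: train on assistant turns and on
    # the EOS token that closes each trainable turn.
    roles_to_train: ["assistant"]
    train_on_eos: turn
    # Per-turn overrides: a boolean `training` key on a message takes priority
    # over `train_detail` offsets, which in turn take priority over
    # train_on_inputs / roles_to_train.
    message_field_training: training
    message_field_training_detail: train_detail
```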
diff --git a/docs/dataset-formats/conversation.qmd b/docs/dataset-formats/conversation.qmd
index fb9aed3ffa..9f6e8c3604 100644
--- a/docs/dataset-formats/conversation.qmd
+++ b/docs/dataset-formats/conversation.qmd
@@ -68,6 +68,8 @@ We recommend checking the examples below for other use cases.
datasets:
- path: ...
type: chat_template
+ roles_to_train:
+ train_on_eos:
```
2. Using the `gemma` chat template to override the tokenizer_config.json's chat template on OpenAI messages format, training on all assistant messages.
@@ -77,7 +79,7 @@ chat_template: gemma # this overwrites the tokenizer's chat_template
datasets:
- path: ...
type: chat_template
- roles_to_train: ["assistant"]
+ roles_to_train: ["assistant"] # default value
```
3. Using the tokenizer_config.json's chat template or `chatml` as fallback if the former's chat template does not exist, on OpenAI messages format, training on all assistant messages.
@@ -87,7 +89,6 @@ chat_template: tokenizer_default_fallback_chatml # this overwrites the tokenizer
datasets:
- path: ...
type: chat_template
- roles_to_train: ["assistant"]
```
4. Using a custom jinja template on OpenAI messages format, training on all assistant messages.
@@ -99,7 +100,6 @@ chat_template_jinja: "{{ bos_token }}{% for message in messages %}{% if (message
datasets:
- path: ...
type: chat_template
- roles_to_train: ["assistant"]
```
5. (Advanced) Using fine-grained control over tokens and turns to train in a conversation
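
For the fine-grained case, a hypothetical data row could look like the following (field names match the config defaults above; the offsets are illustrative character indices into `content`, with `begin_offset`/`end_offset` delimiting the span to train or mask):

```json
{
  "messages": [
    {"role": "user", "content": "How are you?"},
    {
      "role": "assistant",
      "content": "I'm doing well, thank you!",
      "train_detail": [
        {"begin_offset": 0, "end_offset": 14, "train": true},
        {"begin_offset": 15, "end_offset": 25, "train": false}
      ]
    }
  ]
}
```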
diff --git a/src/axolotl/prompt_strategies/chat_template.py b/src/axolotl/prompt_strategies/chat_template.py
index 35c9311678..5b12130d75 100644
--- a/src/axolotl/prompt_strategies/chat_template.py
+++ b/src/axolotl/prompt_strategies/chat_template.py
@@ -25,8 +25,8 @@ def __init__(
processor=None,
chat_template=None,
max_length=2048,
- message_field_role: str = "from",
- message_field_content: str = "value",
+ message_field_role: str = "role",
+ message_field_content: str = "content",
message_field_training: Optional[str] = None,
message_field_training_detail: Optional[str] = None,
roles: Optional[Dict[str, List[str]]] = None,
@@ -41,6 +41,7 @@ def __init__(
"assistant": "assistant",
"gpt": "assistant",
"system": "system",
+ "tool": "tool",
}
self.message_field_role = message_field_role
@@ -188,7 +189,7 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
Tokenizing strategy for instruction-based prompts.
"""
- _messages = "conversations"
+ _messages = "messages"
def __init__(
self,
@@ -279,12 +280,7 @@ def tokenize_prompt(self, prompt):
LOG.debug(f"Should train: {should_train}")
- turn_start_idx, turn_end_idx = self.find_turn(
- conversation_ids=input_ids, turn=index, turn_content=turn
- )
-
- if turn_start_idx == -1 or turn_end_idx == -1:
- LOG.warning(f"Failed to find boundaries for turn {index}")
+ turn_start_idx, turn_end_idx = self.find_turn(turns=turns, turn_idx=index)
LOG.debug(f"Turn indices: start={turn_start_idx}, end={turn_end_idx}")
@@ -313,8 +309,8 @@ def tokenize_prompt(self, prompt):
LOG.debug(f"Labels after processing turn {index}: {labels}")
# Handle EOS token
- eos_idx = self.find_eos_token(input_ids, turn_end_idx)
- if eos_idx == turn_end_idx:
+ eos_idx = self.find_first_eos_token(input_ids, start_idx=turn_end_idx)
+ if abs(eos_idx - turn_end_idx) <= 3: # Allow for some template padding
last_eos_idx = eos_idx
if self.train_on_eos == "all" or (
self.train_on_eos == "turn" and should_train
@@ -339,75 +335,120 @@ def tokenize_prompt(self, prompt):
"attention_mask": [1] * len(input_ids),
}
- def find_eos_token(self, input_ids, start_idx):
+ def find_first_eos_token(self, input_ids, start_idx):
eos_token_id = self.tokenizer.eos_token_id
for i in range(start_idx, len(input_ids)):
if input_ids[i] == eos_token_id:
return i
return -1
- def find_turn(self, conversation_ids: list[int], turn: int, turn_content: dict):
+ def find_turn(self, turns: list[dict], turn_idx: int):
"""
Locate the starting and ending indices of the specified turn in a conversation.
"""
- content = turn_content.get("content")
- content_ids = self.tokenizer.encode(content, add_special_tokens=False)
+ # pylint: disable=too-many-return-statements
- LOG.debug(f"content_ids (length {len(content_ids)}): {content_ids}")
+ if turn_idx >= len(turns):
+ raise ValueError(f"Turn index {turn_idx} out of range")
- if not content_ids:
- LOG.warning(f"Empty content for turn {turn}")
+ # Mistral templates render nothing when the conversation contains only a system message
+ if (
+ turn_idx == 0
+ and turns[0].get("role") == "system"
+ and "mistral" in self.tokenizer.name_or_path.lower()
+ ):
return -1, -1
- # For first turn, start from beginning
- if turn == 0:
- start_search_idx = 0
- else:
- # For subsequent turns, find the previous EOS token
- eos_token_id = self.tokenizer.eos_token_id
- eos_count = 0
- start_search_idx = 0
-
- for i, token_id in enumerate(conversation_ids):
- if token_id == eos_token_id:
- eos_count += 1
- if eos_count == turn: # Find the nth EOS token where n = turn
- start_search_idx = i + 1
- break
-
- # we can optimize this to only search for a few tokens from start_search_idx
- # but it would risk missing the content if it's not found within the first few tokens or
- # if start_search_idx cannot be found above.
- last_index = len(conversation_ids) - len(content_ids) + 1
-
- if last_index < start_search_idx:
+ empty_turn = {
+ "role": turns[turn_idx].get("role"),
+ "content": "[[dummy_message]]",
+ }
+
+ # Create conversation versions
+ turns_with_empty = turns[:turn_idx] + [empty_turn]
+ turns_with_content = turns[: turn_idx + 1]
+
+ # Render the conversation up to this turn, with the turn's content replaced by a dummy marker
+ dummy_ids = self.prompter.build_prompt(turns_with_empty) # type: ignore
+
+ # Render the conversation up to and including this turn's real content
+ full_ids = self.prompter.build_prompt(turns_with_content) # type: ignore
+
+ if not full_ids or not dummy_ids:
+ LOG.warning(f"Empty template generated for turn {turn_idx}")
+ return -1, -1
+
+ # Find first difference (start of content)
+ start_idx = None
+ min_len = min(len(dummy_ids), len(full_ids))
+ for i in range(min_len):
+ if dummy_ids[i] != full_ids[i]:
+ start_idx = i
+ break
+
+ if start_idx is None:
+ LOG.warning(f"Could not find content start boundary for turn {turn_idx}")
+ return -1, -1
+
+ # Find last difference (end of content)
+ end_idx = None
+ for i in range(min_len):
+ dummy_pos = len(dummy_ids) - 1 - i
+ full_pos = len(full_ids) - 1 - i
+ if dummy_ids[dummy_pos] != full_ids[full_pos]:
+ end_idx = full_pos + 1  # Add one so the slice includes the last differing token
+ break
+
+ if end_idx is None:
+ LOG.warning(f"Could not find content end boundary for turn {turn_idx}")
+ return -1, -1
+
+ if end_idx < start_idx:
+ LOG.warning(
+ f"Content end boundary is before start boundary for turn {turn_idx}"
+ )
+ return -1, -1
+
+ if end_idx == start_idx:
LOG.warning(
- f"last_index to search is less than start_search_idx for turn {turn}"
+ f"Content end boundary is the same as start boundary for turn {turn_idx}. This is likely an empty turn."
)
return -1, -1
- # Search for content starting from start_search_idx
- first_elem = content_ids[0]
- for i in range(start_search_idx, last_index):
- # Quick check of first element before doing full comparison
- if conversation_ids[i] == first_elem:
- # Check if the rest of the content matches
- if conversation_ids[i : i + len(content_ids)] == content_ids:
- LOG.debug(f"Found turn {turn} content at position {i}")
- return i, i + len(content_ids)
+ LOG.debug(f"Content boundaries: {start_idx}, {end_idx}")
+ LOG.debug(
+ f"Content tokens: {self.tokenizer.convert_ids_to_tokens(full_ids[start_idx:end_idx])}"
+ )
- return -1, -1
+ return start_idx, end_idx
def get_conversation_thread(self, prompt):
- turns = [
- {
- "role": self.prompter.roles[t[self.prompter.message_field_role]],
- "content": t[self.prompter.message_field_content],
- "training": t.get(self.prompter.message_field_training),
- "training_detail": t.get(self.prompter.message_field_training_detail),
- }
- for t in prompt[self.messages]
+ turns = []
+ optional_keys = [
+ "tool_calls", # tool that 'assistant' calls
+ "name", # name of tool given by 'tool'
+ "tool_call_id", # mistral/mixtral requires this
]
+ for message in prompt[self.messages]:
+ turn = {
+ "role": self.prompter.roles[message[self.prompter.message_field_role]],
+ "training": message.get(self.prompter.message_field_training),
+ "training_detail": message.get(
+ self.prompter.message_field_training_detail
+ ),
+ }
+
+ # Skip content when it is None; some tool-calling templates reject an explicit null content
+ content = message.get(self.prompter.message_field_content, None)
+ if content is not None:
+ turn["content"] = content
+
+ for key in optional_keys:
+ value = message.get(key, None)
+ if value is not None:
+ turn[key] = value
+
+ turns.append(turn)
if self.prompter.drop_system_message and turns[0]["role"] == "system":
turns = turns[1:]
@@ -446,8 +487,8 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None, processor=None
strategy_params = {
"train_on_inputs": cfg.train_on_inputs,
"sequence_len": cfg.sequence_len,
- "roles_to_train": ds_cfg.get("roles_to_train", []),
- "train_on_eos": ds_cfg.get("train_on_eos", None),
+ "roles_to_train": ds_cfg.get("roles_to_train", ["assistant"]),
+ "train_on_eos": ds_cfg.get("train_on_eos", "turn"),
}
strategy = ChatTemplateStrategy(
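
The heart of this change is the new `find_turn`: instead of searching the tokenized conversation for the turn's token subsequence (which broke whenever a template re-tokenized content across turn boundaries), it renders the conversation twice, once with the target turn's content swapped for a dummy marker, and takes the span where the two token streams differ. Below is a minimal standalone sketch of the same idea, using `apply_chat_template` in place of `ChatTemplatePrompter.build_prompt`; the helper name and example messages are illustrative, not part of this diff:

```python
from transformers import AutoTokenizer

def find_content_span(render, turns, turn_idx):
    """Locate the token span of turns[turn_idx]'s content by diffing two renders."""
    dummy = {"role": turns[turn_idx]["role"], "content": "[[dummy_message]]"}
    dummy_ids = render(turns[:turn_idx] + [dummy])  # turn content replaced by a marker
    full_ids = render(turns[: turn_idx + 1])        # turn content included

    n = min(len(dummy_ids), len(full_ids))
    # First index where the streams diverge -> start of the turn's content.
    start = next((i for i in range(n) if dummy_ids[i] != full_ids[i]), None)
    # First divergence scanning from the back -> end of the content (exclusive).
    end = next(
        (len(full_ids) - i for i in range(n)
         if dummy_ids[len(dummy_ids) - 1 - i] != full_ids[len(full_ids) - 1 - i]),
        None,
    )
    return (start, end) if start is not None and end is not None else (-1, -1)

tokenizer = AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3-8B-Instruct")
render = lambda msgs: tokenizer.apply_chat_template(msgs, tokenize=True)
turns = [
    {"role": "user", "content": "Hey, what's the temperature in Paris right now?"},
    {"role": "assistant", "content": "The temperature in Paris is 22.0 degrees Celsius."},
]
start, end = find_content_span(render, turns, turn_idx=1)
print(tokenizer.decode(render(turns)[start:end]))  # assistant content, headers excluded
```

Because the boundaries come from the rendered template rather than from re-encoding the raw content, this approach survives templates that trim, merge, or re-tokenize message text.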
diff --git a/src/axolotl/utils/chat_templates.py b/src/axolotl/utils/chat_templates.py
index ffe5e24853..682a0449e8 100644
--- a/src/axolotl/utils/chat_templates.py
+++ b/src/axolotl/utils/chat_templates.py
@@ -25,7 +25,7 @@
"llama3": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}",
"llama3_2_vision": '{{- bos_token }}\n{%- if custom_tools is defined %}\n {%- set tools = custom_tools %}\n{%- endif %}\n{%- if not tools_in_user_message is defined %}\n {%- set tools_in_user_message = true %}\n{%- endif %}\n{%- if not date_string is defined %}\n {%- if strftime_now is defined %}\n {%- set date_string = strftime_now("%d %b %Y") %}\n {%- else %}\n {%- set date_string = "26 Jul 2024" %}\n {%- endif %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n\n{#- This block extracts the system message, so we can slot it into the right place. #}\n{%- if messages[0][\'role\'] == \'system\' %}\n {%- set system_message = messages[0][\'content\']|trim %}\n {%- set messages = messages[1:] %}\n{%- else %}\n {%- set system_message = "" %}\n{%- endif %}\n\n{#- Find out if there are any images #}\n{% set image_ns = namespace(has_images=false) %} \n{%- for message in messages %}\n {%- for content in message[\'content\'] %}\n {%- if content[\'type\'] == \'image\' %}\n {%- set image_ns.has_images = true %}\n {%- endif %}\n {%- endfor %}\n{%- endfor %}\n\n{#- Error out if there are images and system message #}\n{%- if image_ns.has_images and not system_message == "" %}\n {{- raise_exception("Prompting with images is incompatible with system messages.") }}\n{%- endif %}\n\n{#- System message if there are no images #}\n{%- if not image_ns.has_images %}\n {{- "<|start_header_id|>system<|end_header_id|>\\n\\n" }}\n {%- if tools is not none %}\n {{- "Environment: ipython\\n" }}\n {%- endif %}\n {{- "Cutting Knowledge Date: December 2023\\n" }}\n {{- "Today Date: " + date_string + "\\n\\n" }}\n {%- if tools is not none and not tools_in_user_message %}\n {{- "You have access to the following functions. To call a function, please respond with JSON for a function call." 
}}\n {{- \'Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}.\' }}\n {{- "Do not use variables.\\n\\n" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- "\\n\\n" }}\n {%- endfor %}\n {%- endif %}\n {{- system_message }}\n {{- "<|eot_id|>" }}\n{%- endif %}\n\n{#- Custom tools are passed in a user message with some extra guidance #}\n{%- if tools_in_user_message and not tools is none %}\n {#- Extract the first user message so we can plug it in here #}\n {%- if messages | length != 0 %}\n {%- set first_user_message = messages[0][\'content\']|trim %}\n {%- set messages = messages[1:] %}\n {%- else %}\n {{- raise_exception("Cannot put tools in the first user message when there\'s no first user message!") }}\n{%- endif %}\n {{- \'<|start_header_id|>user<|end_header_id|>\\n\\n\' -}}\n {{- "Given the following functions, please respond with a JSON for a function call " }}\n {{- "with its proper arguments that best answers the given prompt.\\n\\n" }}\n {{- \'Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}.\' }}\n {{- "Do not use variables.\\n\\n" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- "\\n\\n" }}\n {%- endfor %}\n {{- first_user_message + "<|eot_id|>"}}\n{%- endif %}\n\n{%- for message in messages %}\n {%- if not (message.role == \'ipython\' or message.role == \'tool\' or \'tool_calls\' in message) %}\n {{- \'<|start_header_id|>\' + message[\'role\'] + \'<|end_header_id|>\\n\\n\' }}\n {%- if message[\'content\'] is string %}\n {{- message[\'content\'] }}\n {%- else %}\n {%- for content in message[\'content\'] %}\n {%- if content[\'type\'] == \'image\' %}\n {{- \'<|image|>\' }}\n {%- elif content[\'type\'] == \'text\' %}\n {{- content[\'text\'] }}\n {%- endif %}\n {%- endfor %}\n {%- endif %}\n {{- \'<|eot_id|>\' }}\n {%- elif \'tool_calls\' in message %}\n {%- if not message.tool_calls|length == 1 %}\n {{- raise_exception("This model only supports single tool-calls at once!") }}\n {%- endif %}\n {%- set tool_call = message.tool_calls[0].function %}\n {{- \'<|start_header_id|>assistant<|end_header_id|>\\n\\n\' -}}\n {{- \'{"name": "\' + tool_call.name + \'", \' }}\n {{- \'"parameters": \' }}\n {{- tool_call.arguments | tojson }}\n {{- "}" }}\n {{- "<|eot_id|>" }}\n {%- elif message.role == "tool" or message.role == "ipython" %}\n {{- "<|start_header_id|>ipython<|end_header_id|>\\n\\n" }}\n {%- if message.content is mapping or message.content is iterable %}\n {{- message.content | tojson }}\n {%- else %}\n {{- message.content }}\n {%- endif %}\n {{- "<|eot_id|>" }}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- \'<|start_header_id|>assistant<|end_header_id|>\\n\\n\' }}\n{%- endif %}\n',
"phi_3": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'system') %}{{'<|system|>' + '\n' + message['content'] + '<|end|>' + '\n'}}{% elif (message['role'] == 'user') %}{{'<|user|>' + '\n' + message['content'] + '<|end|>' + '\n' + '<|assistant|>' + '\n'}}{% elif message['role'] == 'assistant' %}{{message['content'] + '<|end|>' + '\n'}}{% endif %}{% endfor %}",
- "phi_35": "{% for message in messages %}{% if message['role'] == 'system' and message['content'] %}{{'<|system|>\n' + message['content'] + '<|end|>\n'}}{% elif message['role'] == 'user' %}{{'<|user|>\n' + message['content'] + '<|end|>\n'}}{% elif message['role'] == 'assistant' %}{{'<|assistant|>\n' + message['content'] + '<|end|>\n'}}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|assistant|>\n' }}{% else %}{{ eos_token }}{% endif %}",
+ "phi_35": "{% for message in messages %}{% if message['role'] == 'system' and message['content'] %}{{'<|system|>\n' + message['content'] + '<|end|>\n'}}{% elif message['role'] == 'user' %}{{'<|user|>\n' + message['content'] + '<|end|>\n'}}{% elif message['role'] == 'assistant' %}{{'<|assistant|>\n' + message['content'] + '<|end|>\n'}}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|assistant|>\n' }}{% endif %}",
"deepseek_v2": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{{ bos_token }}{% for message in messages %}{% if message['role'] == 'user' %}{{ '<|User|>' + message['content'] }}{% elif message['role'] == 'assistant' %}{{ '<|Assistant|>' + message['content'] + eos_token }}{% elif message['role'] == 'system' %}{{ message['content'] + '\n\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|Assistant|>' }}{% endif %}",
"jamba": '{# Variables #}\n{% set ns = namespace(message_count=0, is_last_checked_defined=False) %}\n{##}\n{% set bom_str = bom_str or "<|bom|>" %}\n{% set eom_str = eom_str or "<|eom|>" %}\n{% set default_system_message = "" %}\n{##}\n{% set documents_prefix = "" %}\n{% set documents_suffix = "" %}\n{% set tool_definitions_prefix = "" %}\n{% set tool_definitions_suffix = "" %}\n{% set active_modes_prefix = "" %}\n{% set active_modes_suffix = "" %}\n{##}\n{% set tool_calls_prefix = "" %}\n{% set tool_calls_suffix = "" %}\n{% set citations_prefix = "" %}\n{% set citations_suffix = "" %}\n{##}\n{% if add_generation_prompt is not defined %}\n {% set add_generation_prompt = True %}\n{% endif %}\n{% set role_to_predict = role_to_predict or "assistant" %}\n{% if messages|length > 0 and messages[0].role == "system" %}\n {% set system_message = messages[0].content %}\n {% set loop_messages = messages[1:] %}\n{% else %}\n {% set system_message = default_system_message %}\n {% set loop_messages = messages %}\n{% endif %}\n{##}\n{##}\n{# Macros #}\n{% macro handle_tool_definitions(tools) %}\n {{- tool_definitions_prefix -}}\n {{- "\\n# Tools" -}}\n {{- "\\n\\n## Functions" -}}\n {% for tool in tools %}\n {% set _ = is_param_set(tool, field="type") %}\n {% set is_tool_type_set = ns.is_last_checked_defined %}\n {% if is_tool_type_set %}\n {% if tool.type == "function" %}\n {% set tool = tool.function %}\n {% else %}\n {{ raise_exception("Currently, the only supported tool type is `function`") }}\n {% endif %}\n {% endif %}\n {{- "\\n\\n" + (tool|tojson(indent=2)) -}}\n {% endfor %}\n {{- "\\n" + tool_definitions_suffix -}}\n{% endmacro %}\n{##}\n{% macro handle_first_system_message(system_message, tools) %}\n {{- bom_str + handle_role("system") -}}\n {% set _ = is_param_set(system_message) %}\n {% set is_system_message_set = ns.is_last_checked_defined %}\n {% if is_system_message_set %}\n {{- system_message -}}\n {% endif %}\n {% set _ = is_param_set(tools, is_list=True) %}\n {% set is_tools_set = ns.is_last_checked_defined %}\n {% if is_tools_set %}\n {% if system_message %}\n {{- "\\n\\n" -}}\n {% endif %}\n {{- handle_tool_definitions(tools) -}}\n {% endif %}\n {% set ns.message_count = ns.message_count + 1 %}\n{% endmacro %}\n{##}\n{% macro handle_tool_calls(tool_calls) %}\n {{- tool_calls_prefix + "[\\n" -}}\n {% for tool_call in tool_calls %}\n {% set _ = is_param_set(tool_call, field="function") %}\n {% set is_tool_call_function_set = ns.is_last_checked_defined %}\n {% if is_tool_call_function_set %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {% set arguments = tool_call.arguments %}\n {% if arguments is not string %}\n {%- set arguments = arguments|tojson -%}\n {%- endif %}\n {{ "{\\"name\\": \\"" + tool_call.name + "\\", \\"arguments\\": " + arguments + "}" -}}\n {% if not loop.last %}\n {{- "," }}\n {% endif %}\n {% endfor %}\n {{- "\\n]" + tool_calls_suffix -}}\n{% endmacro %}\n{##}\n{% macro handle_documents(documents) %}\n {{- documents_prefix -}}\n {{- "\\n# Documents" -}}\n {{- "\\n\\nYou can use the following documents for reference:" -}}\n {% for doc in documents %}\n {{- "\\n\\n## Document ID: " + loop.index0|string -}}\n {% set _ = is_param_set(doc, field="title") %}\n {% set is_doc_title_set = ns.is_last_checked_defined %}\n {% if is_doc_title_set %}\n {{- "\\nTitle: " + doc.title -}}\n {% endif %}\n {% for key, value in doc.items() %}\n {% if key not in ["title", "text"] %}\n {{- "\\n" + key|title + ": " + value|string -}}\n {% endif %}\n {% endfor %}\n {{- 
"\\nText: " + doc.text -}}\n {% endfor %}\n {{- "\\n" + documents_suffix -}}\n{% endmacro %}\n{##}\n{% macro handle_knobs(knobs) %}\n {{- active_modes_prefix -}}\n {{- "\\n# Active Modes" -}}\n {{ "\\n\\nThe following modes configure the format or style of your responses. You should adhere to all currently" -}}\n {{ " active modes simultaneously." -}}\n {% if knobs.citation_mode == "fast" %}\n {{- "\\n\\n## Citation Mode" -}}\n {{- "\\n\\nProvide a list of references only for the documents you base your response on. Format your response" -}}\n {{ " with the original answer followed by a citation section. Use this template:" -}}\n {{ " `{answer}" + citations_prefix + "DOCUMENT_IDS" + citations_suffix + "`, where DOCUMENT_IDS are the relevant document numbers" -}}\n {{ " (e.g. [2, 5, 9]), or [] if the answer cannot be supported by the provided documents." -}}\n {% endif %}\n {% if knobs.response_format == "json_object" %}\n {{- "\\n\\n## JSON Mode" -}}\n {{ "\\n\\nProvide your response in JSON format. Adhere strictly to any schema given by the user." -}}\n {{ " If an appropriate JSON format exists, use it without modification." -}}\n {% endif %}\n {{- "\\n" + active_modes_suffix -}}\n{% endmacro %}\n{##}\n{% macro get_last_user_index(messages) %}\n {% set ns.last_user_index = 0 %}\n {% for message in messages %}\n {% if message.role == \'user\' %}\n {% set ns.last_user_index = loop.index0 %}\n {% endif %}\n {% endfor %}\n {{- ns.last_user_index -}}\n{% endmacro %}\n{##}\n{% macro handle_last_system_message(documents, knobs, use_documents, use_knobs) %}\n {{- bom_str + handle_role("system") -}}\n {% set macros_to_call = [] %}\n {% set params_for_macros = [] %}\n {% if use_documents %}\n {% set macros_to_call = macros_to_call + [handle_documents] %}\n {% set params_for_macros = params_for_macros + [[documents]] %}\n {% endif %}\n {% if use_knobs %}\n {% set macros_to_call = macros_to_call + [handle_knobs] %}\n {% set params_for_macros = params_for_macros + [[knobs]] %}\n {% endif %}\n {% for i in range(macros_to_call|length) %}\n {% if i > 0 %}\n {{- "\\n\\n" -}}\n {% endif %}\n {{- macros_to_call[i](*params_for_macros[i]) -}}\n {% endfor %}\n {% set ns.message_count = ns.message_count + 1 %}\n{% endmacro %}\n{##}\n{% macro handle_role(role, add_space=True) %}\n {{- "<|" + role + "|>" -}}\n {% if add_space %}\n {{- " " -}}\n {% endif %}\n{% endmacro %}\n{##}\n{% macro is_param_set(param, field=none, is_list=False) %}\n {% if field is not none %}\n {% if field in param %}\n {% set param = param[field] %}\n {% else %}\n {% set param = none %}\n {% endif %}\n {% endif %}\n {% set is_defined = param is defined and param is not none %}\n {% if is_list %}\n {% set ns.is_last_checked_defined = is_defined and param|length > 0 %}\n {% else %}\n {% set ns.is_last_checked_defined = is_defined %}\n {% endif %}\n{% endmacro %}\n{##}\n{##}\n{# Template #}\n{{- "<|startoftext|>" -}}\n{% set _ = is_param_set(system_message) %}\n{% set is_system_message_set = ns.is_last_checked_defined %}\n{% set _ = is_param_set(tools, is_list=True) %}\n{% set is_tools_set = ns.is_last_checked_defined %}\n{% set has_system_message = (is_system_message_set or is_tools_set) %}\n{% if has_system_message %}\n {{- handle_first_system_message(system_message, tools) -}}\n{% endif %}\n{% set last_user_index = get_last_user_index(loop_messages)|int %}\n{% for message in loop_messages %}\n {% if loop.index0 == last_user_index %}\n {% set _ = is_param_set(documents, is_list=True) %}\n {% set use_documents = ns.is_last_checked_defined 
%}\n {% set _ = is_param_set(knobs) %}\n {% set use_knobs = ns.is_last_checked_defined and knobs.is_set %}\n {% set add_last_system_message = use_documents or use_knobs %}\n {% if add_last_system_message %}\n {% if ns.message_count > 0 %}\n {{- eom_str -}}\n {% endif %}\n {{- handle_last_system_message(documents, knobs, use_documents, use_knobs) -}}\n {% endif %}\n {% endif %}\n {% set role = message.role %}\n {% set _ = is_param_set(message, field="name") %}\n {% set is_message_name_set = ns.is_last_checked_defined %}\n {% if is_message_name_set %}\n {% set message_prefix = handle_role(role) + "(" + message.name + ")" %}\n {% else %}\n {% set message_prefix = handle_role(role) %}\n {% endif %}\n {% set content = (message.content or "") %}\n {% if content is not string %}\n {% set content = content|tojson %}\n {% endif %}\n {% if ns.message_count > 0 %}\n {{- eom_str -}}\n {% endif %}\n {{- bom_str + message_prefix + content -}}\n {% set _ = is_param_set(message, field="tool_calls", is_list=True) %}\n {% set is_tool_calls_set = ns.is_last_checked_defined %}\n {% if role == "assistant" and is_tool_calls_set %}\n {{- handle_tool_calls(message.tool_calls) -}}\n {% endif %}\n {% set _ = is_param_set(message, field="citations", is_list=True) %}\n {% set is_citations_set = ns.is_last_checked_defined %}\n {% if role == "assistant" and is_citations_set %}\n {{- citations_prefix + message.citations|map(attribute="document_id")|list|string + citations_suffix -}}\n {% endif %}\n {% set ns.message_count = ns.message_count + 1 %}\n{% endfor %}\n{% if add_generation_prompt %}\n {% if ns.message_count > 0 %}\n {{- eom_str -}}\n {% endif %}\n {{- bom_str + handle_role(role_to_predict, add_space=False) -}}\n {% set _ = is_param_set(generation_preamble) %}\n {% set is_generation_preamble_set = ns.is_last_checked_defined %}\n {% if is_generation_preamble_set and generation_preamble.strip() != "" %}\n {{- " " + generation_preamble -}}\n {% endif %}\n {% set ns.message_count = ns.message_count + 1 %}\n{% else %}\n {% if ns.message_count > 0 %}\n {{- eom_str -}}\n {% endif %}\n{% endif %}\n',
"qwen_25": "{%- if tools %}\n {{- '<|im_start|>system\\n' }}\n {%- if messages[0]['role'] == 'system' %}\n {{- messages[0]['content'] }}\n {%- else %}\n {{- 'You are Qwen, created by Alibaba Cloud. You are a helpful assistant.' }}\n {%- endif %}\n {{- \"\\n\\n# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within XML tags:\\n\" }}\n {%- for tool in tools %}\n {{- \"\\n\" }}\n {{- tool | tojson }}\n {%- endfor %}\n {{- \"\\n\\n\\nFor each function call, return a json object with function name and arguments within XML tags:\\n\\n{\\\"name\\\": , \\\"arguments\\\": }\\n<|im_end|>\\n\" }}\n{%- else %}\n {%- if messages[0]['role'] == 'system' %}\n {{- '<|im_start|>system\\n' + messages[0]['content'] + '<|im_end|>\\n' }}\n {%- else %}\n {{- '<|im_start|>system\\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\\n' }}\n {%- endif %}\n{%- endif %}\n{%- for message in messages %}\n {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) or (message.role == \"assistant\" and not message.tool_calls) %}\n {{- '<|im_start|>' + message.role + '\\n' + message.content + '<|im_end|>' + '\\n' }}\n {%- elif message.role == \"assistant\" %}\n {{- '<|im_start|>' + message.role }}\n {%- if message.content %}\n {{- '\\n' + message.content }}\n {%- endif %}\n {%- for tool_call in message.tool_calls %}\n {%- if tool_call.function is defined %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- '\\n\\n{\"name\": \"' }}\n {{- tool_call.name }}\n {{- '\", \"arguments\": ' }}\n {{- tool_call.arguments | tojson }}\n {{- '}\\n' }}\n {%- endfor %}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == \"tool\" %}\n {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != \"tool\") %}\n {{- '<|im_start|>user' }}\n {%- endif %}\n {{- '\\n\\n' }}\n {{- message.content }}\n {{- '\\n' }}\n {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n {{- '<|im_end|>\\n' }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n{%- endif %}\n",
diff --git a/tests/prompt_strategies/conftest.py b/tests/prompt_strategies/conftest.py
index 00a2bf3004..fdfcbff438 100644
--- a/tests/prompt_strategies/conftest.py
+++ b/tests/prompt_strategies/conftest.py
@@ -7,6 +7,8 @@
from huggingface_hub import hf_hub_download
from transformers import AutoTokenizer
+from axolotl.utils.chat_templates import _CHAT_TEMPLATES
+
@pytest.fixture(name="assistant_dataset")
def fixture_assistant_dataset():
@@ -59,7 +61,52 @@ def fixture_basic_dataset():
)
-@pytest.fixture(name="llama3_tokenizer")
+@pytest.fixture(name="toolcalling_dataset")
+def fixture_toolcalling_dataset():
+ # pylint: disable=duplicate-code
+ return Dataset.from_list(
+ [
+ {
+ "messages": [
+ {
+ "role": "system",
+ "content": "You are a bot that responds to weather queries. You should reply with the unit used in the queried location.",
+ },
+ {
+ "role": "user",
+ "content": "Hey, what's the temperature in Paris right now?",
+ },
+ {
+ "role": "assistant",
+ "tool_calls": [
+ {
+ "type": "function",
+ "function": {
+ "name": "get_current_temperature",
+ "arguments": {
+ "location": "Paris, France",
+ "unit": "celsius",
+ },
+ },
+ }
+ ],
+ },
+ {
+ "role": "tool",
+ "name": "get_current_temperature",
+ "content": "22.0",
+ },
+ {
+ "role": "assistant",
+ "content": "The temperature in Paris is 22.0 degrees Celsius.",
+ },
+ ]
+ }
+ ]
+ )
+
+
+@pytest.fixture(name="llama3_tokenizer", scope="session", autouse=True)
def fixture_llama3_tokenizer():
hf_hub_download(
repo_id="NousResearch/Meta-Llama-3-8B-Instruct",
@@ -77,7 +124,53 @@ def fixture_llama3_tokenizer():
return tokenizer
-@pytest.fixture(name="phi35_tokenizer")
+@pytest.fixture(name="mistralv03_tokenizer", scope="session", autouse=True)
+def fixture_mistralv03_tokenizer():
+ tokenizer = AutoTokenizer.from_pretrained(
+ "mlx-community/Mistral-7B-Instruct-v0.3-4bit"
+ )
+ return tokenizer
+
+
+@pytest.fixture(name="phi35_tokenizer", scope="session", autouse=True)
def fixture_phi35_tokenizer():
tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3.5-mini-instruct")
return tokenizer
+
+
+@pytest.fixture(name="gemma2_tokenizer", scope="session", autouse=True)
+def fixture_gemma2_tokenizer():
+ tokenizer = AutoTokenizer.from_pretrained("mlx-community/gemma-2-9b-it-4bit")
+
+ return tokenizer
+
+
+@pytest.fixture(name="mistralv03_tokenizer_chat_template_jinja")
+def fixture_mistralv03_chat_template_jinja_w_system() -> str:
+ return '{%- if messages[0]["role"] == "system" %}\n {%- set system_message = messages[0]["content"] %}\n {%- set loop_messages = messages[1:] %}\n{%- else %}\n {%- set loop_messages = messages %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n{%- set user_messages = loop_messages | selectattr("role", "equalto", "user") | list %}\n\n{#- This block checks for alternating user/assistant messages, skipping tool calling messages #}\n{%- set ns = namespace() %}\n{%- set ns.index = 0 %}\n{%- for message in loop_messages %}\n {%- if not (message.role == "tool" or message.role == "tool_results" or (message.tool_calls is defined and message.tool_calls is not none)) %}\n {%- if (message["role"] == "user") != (ns.index % 2 == 0) %}\n {{- raise_exception("After the optional system message, conversation roles must alternate user/assistant/user/assistant/...") }}\n {%- endif %}\n {%- set ns.index = ns.index + 1 %}\n {%- endif %}\n{%- endfor %}\n\n{{- bos_token }}\n{%- for message in loop_messages %}\n {%- if message["role"] == "user" %}\n {%- if tools is not none and (message == user_messages[-1]) %}\n {{- "[AVAILABLE_TOOLS] [" }}\n {%- for tool in tools %}\n {%- set tool = tool.function %}\n {{- \'{"type": "function", "function": {\' }}\n {%- for key, val in tool.items() if key != "return" %}\n {%- if val is string %}\n {{- \'"\' + key + \'": "\' + val + \'"\' }}\n {%- else %}\n {{- \'"\' + key + \'": \' + val|tojson }}\n {%- endif %}\n {%- if not loop.last %}\n {{- ", " }}\n {%- endif %}\n {%- endfor %}\n {{- "}}" }}\n {%- if not loop.last %}\n {{- ", " }}\n {%- else %}\n {{- "]" }}\n {%- endif %}\n {%- endfor %}\n {{- "[/AVAILABLE_TOOLS]" }}\n {%- endif %}\n {%- if loop.first and system_message is defined %}\n {{- "[INST] " + system_message + "\\n\\n" + message["content"] + "[/INST]" }}\n {%- else %}\n {{- "[INST] " + message["content"] + "[/INST]" }}\n {%- endif %}\n {%- elif message.tool_calls is defined and message.tool_calls is not none %}\n {{- "[TOOL_CALLS] [" }}\n {%- for tool_call in message.tool_calls %}\n {%- set out = tool_call.function|tojson %}\n {{- out[:-1] }}\n {%- if not tool_call.id is defined or tool_call.id|length != 9 %}\n {{- raise_exception("Tool call IDs should be alphanumeric strings with length 9!") }}\n {%- endif %}\n {{- \', "id": "\' + tool_call.id + \'"}\' }}\n {%- if not loop.last %}\n {{- ", " }}\n {%- else %}\n {{- "]" + eos_token }}\n {%- endif %}\n {%- endfor %}\n {%- elif message["role"] == "assistant" %}\n {{- " " + message["content"]|trim + eos_token}}\n {%- elif message["role"] == "tool_results" or message["role"] == "tool" %}\n {%- if message.content is defined and message.content.content is defined %}\n {%- set content = message.content.content %}\n {%- else %}\n {%- set content = message.content %}\n {%- endif %}\n {{- \'[TOOL_RESULTS] {"content": \' + content|string + ", " }}\n {%- if not message.tool_call_id is defined or message.tool_call_id|length != 9 %}\n {{- raise_exception("Tool call IDs should be alphanumeric strings with length 9!") }}\n {%- endif %}\n {{- \'"call_id": "\' + message.tool_call_id + \'"}[/TOOL_RESULTS]\' }}\n {%- else %}\n {{- raise_exception("Only user and assistant roles are supported, with the exception of an initial optional system message!") }}\n {%- endif %}\n{%- endfor %}\n'
+
+
+@pytest.fixture(name="gemma2_tokenizer_chat_template_jinja")
+def fixture_gemma2_chat_template_jinja_w_system() -> str:
+ return "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '' + role + '\n' + message['content'] | trim + '\n' }}{% endfor %}{% if add_generation_prompt %}{{'model\n'}}{% endif %}"
+
+
+@pytest.fixture(name="llama3_2_vision_chat_template_jinja")
+def fixture_llama3_2_vision_with_hardcoded_date() -> str:
+ """Hardcodes the date in the template to avoid the need for date logic in the prompt"""
+
+ template = _CHAT_TEMPLATES["llama3_2_vision"]
+
+ old_date_logic = """{%- if not date_string is defined %}
+ {%- if strftime_now is defined %}
+ {%- set date_string = strftime_now("%d %b %Y") %}
+ {%- else %}
+ {%- set date_string = "26 Jul 2024" %}
+ {%- endif %}
+{%- endif %}"""
+
+ new_date_logic = """{%- set date_string = "17 Dec 2024" %}"""
+
+ modified_template = template.replace(old_date_logic, new_date_logic)
+
+ return modified_template
diff --git a/tests/prompt_strategies/test_chat_templates.py b/tests/prompt_strategies/test_chat_templates.py
index 4ec12b82cb..8ec4fa1191 100644
--- a/tests/prompt_strategies/test_chat_templates.py
+++ b/tests/prompt_strategies/test_chat_templates.py
@@ -140,7 +140,6 @@ def test_phi35(self, phi35_tokenizer, assistant_dataset):
1781, 26966, 32007, # user eot
32001, # assistant
1781, 26966, 32007, # assistant eot
- 32000, # eos
]
expected_labels = [
-100, # user
@@ -151,7 +150,6 @@ def test_phi35(self, phi35_tokenizer, assistant_dataset):
-100, -100, -100, # user eot
-100, # assistant
1781, 26966, 32007, # assistant eot
- 32000, # eos
]
# fmt: on
LOG.debug(f"Expected input_ids: {expected_input_ids}")
@@ -230,7 +228,10 @@ def test_llama3_assistant(self, llama3_tokenizer, sharegpt_dataset):
# pylint: disable=duplicate-code
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(
- llama3_tokenizer, chat_template=get_chat_template("llama3")
+ llama3_tokenizer,
+ chat_template=get_chat_template("llama3"),
+ message_field_role="from",
+ message_field_content="value",
),
tokenizer=llama3_tokenizer,
train_on_inputs=False,
@@ -238,6 +239,7 @@ def test_llama3_assistant(self, llama3_tokenizer, sharegpt_dataset):
sequence_len=512,
roles_to_train=["gpt"],
)
+ strategy.messages = "conversations"
res = strategy.tokenize_prompt(sharegpt_dataset[0])
input_ids = res["input_ids"]
labels = res["labels"]
@@ -283,7 +285,10 @@ def test_llama3_human(self, llama3_tokenizer, sharegpt_dataset):
# pylint: disable=duplicate-code
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(
- llama3_tokenizer, chat_template=get_chat_template("llama3")
+ llama3_tokenizer,
+ chat_template=get_chat_template("llama3"),
+ message_field_role="from",
+ message_field_content="value",
),
tokenizer=llama3_tokenizer,
train_on_inputs=False,
@@ -291,6 +296,7 @@ def test_llama3_human(self, llama3_tokenizer, sharegpt_dataset):
sequence_len=512,
roles_to_train=["human"],
)
+ strategy.messages = "conversations"
res = strategy.tokenize_prompt(sharegpt_dataset[0])
input_ids = res["input_ids"]
labels = res["labels"]
@@ -336,7 +342,10 @@ def test_llama3_system_human(self, llama3_tokenizer, basic_dataset):
# pylint: disable=duplicate-code
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(
- llama3_tokenizer, chat_template=get_chat_template("llama3")
+ llama3_tokenizer,
+ chat_template=get_chat_template("llama3"),
+ message_field_role="from",
+ message_field_content="value",
),
tokenizer=llama3_tokenizer,
train_on_inputs=False,
@@ -344,6 +353,7 @@ def test_llama3_system_human(self, llama3_tokenizer, basic_dataset):
sequence_len=512,
roles_to_train=["system", "human"],
)
+ strategy.messages = "conversations"
res = strategy.tokenize_prompt(basic_dataset[0])
input_ids = res["input_ids"]
labels = res["labels"]
@@ -389,5 +399,148 @@ def test_llama3_system_human(self, llama3_tokenizer, basic_dataset):
), f"Labels mismatch: {labels} != {expected_labels}"
+class TestAssistantToolCallingChatTemplateLlama32Vision:
+ """
+ Test class for assistant style datasets with tool_calling prompts using the llama-32_vision chat template.
+ """
+
+ def test_llama32vision_train_on_assistant(
+ self, llama3_tokenizer, toolcalling_dataset, llama3_2_vision_chat_template_jinja
+ ):
+ LOG.info(
+ "Testing assistant style datasets with tool_calling with llama-32 chat template, training on assistant"
+ )
+
+ strategy = ChatTemplateStrategy(
+ ChatTemplatePrompter(
+ llama3_tokenizer,
+ chat_template=get_chat_template(
+ "jinja", jinja_template=llama3_2_vision_chat_template_jinja
+ ),
+ message_field_role="role",
+ message_field_content="content",
+ ),
+ tokenizer=llama3_tokenizer,
+ train_on_inputs=False,
+ train_on_eos="turn",
+ sequence_len=512,
+ roles_to_train=["assistant"],
+ )
+
+ res = strategy.tokenize_prompt(toolcalling_dataset[0])
+
+ input_ids = res["input_ids"]
+ labels = res["labels"]
+
+ # fmt: off
+ expected_input_ids = [
+ 128000, # bos
+ 128006, 9125, 128007, 271, # system header
+ 38766, 1303, 33025, 2696, 25, 6790, 220, 2366, 18, 198, 15724, 2696, 25, 220, 1114, 3799, 220, 2366, 19, 271, # system date prompt
+ 2675, 527, 264, 11164, 430, 31680, 311, 9282, 20126, 13, 1472, 1288, 10052, 449, 279, 5089, 1511, 304, 279, 79002, 3813, 13, 128009, # system message
+ 128006, 882, 128007, 271, # user header
+ 19182, 11, 1148, 596, 279, 9499, 304, 12366, 1314, 1457, 30, 128009, # user message
+ 128006, 78191, 128007, 271, # assistant header
+ 5018, 609, 794, 330, 456, 11327, 54625, 498, 330, 14105, 794, 5324, 2588, 794, 330, 60704, 11, 9822, 498, 330, 3928, 794, 330, 66, 41347, 32075, 128009, # assistant message
+ 128006, 23799, 4690, 128007, 271, # tool header
+ 1, 1313, 13, 15, 1, 128009, # tool message
+ 128006, 78191, 128007, 271, # assistant header
+ 791, 9499, 304, 12366, 374, 220, 1313, 13, 15, 12628, 62447, 13, 128009 # assistant message
+ ]
+
+ expected_labels = [
+ IGNORE_TOKEN_ID, # bos
+ IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # system header
+ IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # system date prompt
+ IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # system message
+ IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # user header
+ IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # user message
+ IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # assistant header
+ 5018, 609, 794, 330, 456, 11327, 54625, 498, 330, 14105, 794, 5324, 2588, 794, 330, 60704, 11, 9822, 498, 330, 3928, 794, 330, 66, 41347, 32075, 128009, # assistant message
+ IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # tool header
+ IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # tool message
+ IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # assistant header
+ 791, 9499, 304, 12366, 374, 220, 1313, 13, 15, 12628, 62447, 13, 128009 # assistant message
+ ]
+ # fmt: on
+
+ assert (
+ input_ids == expected_input_ids
+ ), f"Input IDs mismatch: {input_ids} != {expected_input_ids}"
+
+ assert (
+ labels == expected_labels
+ ), f"Labels mismatch: {labels} != {expected_labels}"
+
+ def test_llama32vision_train_on_tools(
+ self, llama3_tokenizer, toolcalling_dataset, llama3_2_vision_chat_template_jinja
+ ):
+ LOG.info(
+ "Testing assistant style datasets with tool_calling with llama-32 chat template, training on tools"
+ )
+ # pylint: disable=duplicate-code
+
+ strategy = ChatTemplateStrategy(
+ ChatTemplatePrompter(
+ llama3_tokenizer,
+ chat_template=get_chat_template(
+ "jinja", jinja_template=llama3_2_vision_chat_template_jinja
+ ),
+ message_field_role="role",
+ message_field_content="content",
+ ),
+ tokenizer=llama3_tokenizer,
+ train_on_inputs=False,
+ train_on_eos="turn",
+ sequence_len=512,
+ roles_to_train=["assistant", "tool"],
+ )
+
+ res = strategy.tokenize_prompt(toolcalling_dataset[0])
+
+ input_ids = res["input_ids"]
+ labels = res["labels"]
+
+ # fmt: off
+ expected_input_ids = [
+ 128000, # bos
+ 128006, 9125, 128007, 271, # system header
+ 38766, 1303, 33025, 2696, 25, 6790, 220, 2366, 18, 198, 15724, 2696, 25, 220, 1114, 3799, 220, 2366, 19, 271, # system date prompt
+ 2675, 527, 264, 11164, 430, 31680, 311, 9282, 20126, 13, 1472, 1288, 10052, 449, 279, 5089, 1511, 304, 279, 79002, 3813, 13, 128009, # system message
+ 128006, 882, 128007, 271, # user header
+ 19182, 11, 1148, 596, 279, 9499, 304, 12366, 1314, 1457, 30, 128009, # user message
+ 128006, 78191, 128007, 271, # assistant header
+ 5018, 609, 794, 330, 456, 11327, 54625, 498, 330, 14105, 794, 5324, 2588, 794, 330, 60704, 11, 9822, 498, 330, 3928, 794, 330, 66, 41347, 32075, 128009, # assistant message
+ 128006, 23799, 4690, 128007, 271, # tool header
+ 1, 1313, 13, 15, 1, 128009, # tool message
+ 128006, 78191, 128007, 271, # assistant header
+ 791, 9499, 304, 12366, 374, 220, 1313, 13, 15, 12628, 62447, 13, 128009 # assistant message
+ ]
+
+ expected_labels = [
+ IGNORE_TOKEN_ID, # bos
+ IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # system header
+ IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # system date prompt
+ IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # system message
+ IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # user header
+ IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # user message
+ IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # assistant header
+ 5018, 609, 794, 330, 456, 11327, 54625, 498, 330, 14105, 794, 5324, 2588, 794, 330, 60704, 11, 9822, 498, 330, 3928, 794, 330, 66, 41347, 32075, 128009, # assistant message
+ IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # tool header
+ IGNORE_TOKEN_ID, 1313, 13, 15, IGNORE_TOKEN_ID, 128009, # tool message
+ IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # assistant header
+ 791, 9499, 304, 12366, 374, 220, 1313, 13, 15, 12628, 62447, 13, 128009 # assistant message
+ ]
+ # fmt: on
+
+ assert (
+ input_ids == expected_input_ids
+ ), f"Input IDs mismatch: {input_ids} != {expected_input_ids}"
+
+ assert (
+ labels == expected_labels
+ ), f"Labels mismatch: {labels} != {expected_labels}"
+
+
if __name__ == "__main__":
unittest.main()
diff --git a/tests/prompt_strategies/test_chat_templates_advanced.py b/tests/prompt_strategies/test_chat_templates_advanced.py
index be8e3ccdf9..7d09b059cc 100644
--- a/tests/prompt_strategies/test_chat_templates_advanced.py
+++ b/tests/prompt_strategies/test_chat_templates_advanced.py
@@ -4,8 +4,12 @@
import logging
import unittest
+from copy import deepcopy
+import pytest
from datasets import Dataset
+from tokenizers import AddedToken
+from transformers import PreTrainedTokenizer
from axolotl.prompt_strategies.chat_template import (
ChatTemplatePrompter,
@@ -17,7 +21,30 @@
logging.basicConfig(level=logging.DEBUG)
LOG = logging.getLogger("axolotl")
-
+PARAMETRIZE_KEYS = "tokenizer, chat_template, chat_template_jinja, eos_token"
+PARAMETRIZE_PARAMS = [
+ ("llama3_tokenizer", "llama3", None, None),
+ ("llama3_tokenizer", "chatml", None, "<|im_end|>"),
+ (
+ "mistralv03_tokenizer",
+ "jinja",
+ "mistralv03_tokenizer_chat_template_jinja",
+ "[/INST]",
+ ),
+ (
+ "gemma2_tokenizer",
+ "jinja",
+ "gemma2_tokenizer_chat_template_jinja",
+ "",
+ ),
+ ("phi35_tokenizer", "phi_35", None, "<|end|>"),
+]
+
+
+@pytest.mark.parametrize(
+ PARAMETRIZE_KEYS,
+ PARAMETRIZE_PARAMS,
+)
class TestChatTemplateConfigurations:
"""
Test class for various configurations of ChatTemplateStrategy.
@@ -31,167 +58,318 @@ def find_sublist(full_list, sub_list):
return index
return -1
- def test_train_on_inputs_true(self, llama3_tokenizer, basic_dataset):
+ @staticmethod
+ def setup_tokenizer(
+ tokenizer_name,
+ chat_template,
+ chat_template_jinja=None,
+ eos_token=None,
+ request=None,
+ ) -> tuple[PreTrainedTokenizer, str]:
+ """
+ Helper function to set up the tokenizer and chat template for the test.
+ """
+ tokenizer = deepcopy(request.getfixturevalue(tokenizer_name))
+ if chat_template == "jinja":
+ chat_template_jinja = request.getfixturevalue(chat_template_jinja)
+ if eos_token:
+ tokenizer.add_special_tokens(
+ {
+ "eos_token": AddedToken(
+ eos_token, rstrip=False, lstrip=False, normalized=False
+ )
+ }
+ )
+ if tokenizer.__class__.__name__ in (
+ "LlamaTokenizerFast",
+ "CodeLlamaTokenizerFast",
+ ):
+ tokenizer.update_post_processor()
+ return tokenizer, chat_template_jinja
+
+ def _should_skip_turn(self, tokenizer, turn, turn_idx, start_idx, end_idx):
+ """Helper method to determine if a turn should be skipped in testing.
+ Mistral's template does not render a leading system message on its own, so such turns are skipped in testing.
+ """
+ if (
+ turn_idx == 0
+ and turn.get("from") in ["system", "context"]
+ and "mistral" in tokenizer.name_or_path.lower()
+ ):
+ assert (
+ start_idx == -1 and end_idx == -1
+ ), "Expected system message to be skipped"
+ return True
+ return False
+
+ def test_train_on_inputs_true(
+ self,
+ tokenizer,
+ chat_template,
+ chat_template_jinja,
+ eos_token,
+ basic_dataset,
+ request,
+ ):
LOG.info("Testing with train_on_inputs=True")
+
+ tokenizer, chat_template_jinja = self.setup_tokenizer(
+ tokenizer, chat_template, chat_template_jinja, eos_token, request
+ )
+
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(
- llama3_tokenizer, chat_template=get_chat_template("llama3")
+ tokenizer,
+ chat_template=get_chat_template(
+ chat_template, jinja_template=chat_template_jinja
+ ),
+ message_field_role="from",
+ message_field_content="value",
),
- tokenizer=llama3_tokenizer,
+ tokenizer=tokenizer,
train_on_inputs=True,
sequence_len=512,
roles_to_train=["assistant"],
)
+ strategy.messages = "conversations"
res = strategy.tokenize_prompt(basic_dataset[0])
+ turns = strategy.get_conversation_thread(basic_dataset[0])
labels = res["labels"]
input_ids = res["input_ids"]
- # Verify that assistant responses are labeled
- assistant_responses = ["Hi there!", "I'm doing well, thank you!"]
- for response in assistant_responses:
- response_ids = llama3_tokenizer.encode(response, add_special_tokens=False)
- start_idx = self.find_sublist(input_ids, response_ids)
- LOG.debug(
- f"Assistant response '{response}' expected IDs: {response_ids}, found at: {start_idx}"
+ # With train_on_inputs=True, every turn should be labeled
+ for i, turn in enumerate(basic_dataset[0]["conversations"]):
+ start_idx, end_idx = strategy.find_turn(turns=turns, turn_idx=i)
+
+ if self._should_skip_turn(tokenizer, turn, i, start_idx, end_idx):
+ continue
+
+ decoded_response = tokenizer.decode(input_ids[start_idx:end_idx])
+ response = turn["value"]
+
+ assert response in decoded_response, (
+ f"Response {response} not found in index {start_idx}:{end_idx} "
+ f"decoded:{decoded_response}"
)
- assert start_idx != -1, f"Could not find '{response}' in input_ids"
+
assert all(
- label != IGNORE_TOKEN_ID
- for label in labels[start_idx : start_idx + len(response_ids)]
- ), f"Expected labels for assistant response '{response}' to be set, but got {labels[start_idx:start_idx+len(response_ids)]}"
-
- # Check the behavior of human inputs
- human_inputs = ["Hello", "How are you?"]
- for input_text in human_inputs:
- input_ids = llama3_tokenizer.encode(input_text, add_special_tokens=False)
- start_idx = self.find_sublist(input_ids, input_ids)
- labeled = all(
- label != IGNORE_TOKEN_ID
- for label in labels[start_idx : start_idx + len(input_ids)]
- )
- LOG.debug(
- f"Human input '{input_text}' is {'labeled' if labeled else 'not labeled'}, expected IDs: {input_ids}, found at: {start_idx}"
- )
+ label != IGNORE_TOKEN_ID for label in labels[start_idx:end_idx]
+ ), f"Expected labels for input '{response}' to be ignored, but got {labels[start_idx:end_idx]}"
LOG.debug("Full labels: %s", labels)
LOG.debug("Full input_ids: %s", input_ids)
- def test_train_on_inputs_false(self, llama3_tokenizer, basic_dataset):
- LOG.info("Testing with train_on_inputs=False")
+ def test_train_on_inputs_false(
+ self,
+ tokenizer,
+ chat_template,
+ chat_template_jinja,
+ eos_token,
+ basic_dataset,
+ request,
+ ):
+ LOG.info("Testing with train_on_inputs=False, on assistant only")
+
+ tokenizer, chat_template_jinja = self.setup_tokenizer(
+ tokenizer, chat_template, chat_template_jinja, eos_token, request
+ )
+
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(
- llama3_tokenizer, chat_template=get_chat_template("llama3")
+ tokenizer,
+ chat_template=get_chat_template(
+ chat_template, jinja_template=chat_template_jinja
+ ),
+ message_field_role="from",
+ message_field_content="value",
),
- tokenizer=llama3_tokenizer,
+ tokenizer=tokenizer,
train_on_inputs=False,
sequence_len=512,
roles_to_train=["assistant"],
)
+ strategy.messages = "conversations"
res = strategy.tokenize_prompt(basic_dataset[0])
+ turns = strategy.get_conversation_thread(basic_dataset[0])
labels = res["labels"]
input_ids = res["input_ids"]
- # Verify that only assistant responses are labeled
- assistant_responses = ["Hi there!", "I'm doing well, thank you!"]
- for response in assistant_responses:
- response_ids = llama3_tokenizer.encode(response, add_special_tokens=False)
- start_idx = self.find_sublist(input_ids, response_ids)
- LOG.debug(
- f"Assistant response '{response}' expected IDs: {response_ids}, found at: {start_idx}"
- )
- assert start_idx != -1, f"Could not find '{response}' in input_ids"
- assert all(
- label != IGNORE_TOKEN_ID
- for label in labels[start_idx : start_idx + len(response_ids)]
- ), f"Expected labels for assistant response '{response}' to be set, but got {labels[start_idx:start_idx+len(response_ids)]}"
-
- # Verify that human inputs are not labeled
- human_inputs = ["Hello", "How are you?"]
- for input_text in human_inputs:
- input_ids = llama3_tokenizer.encode(input_text, add_special_tokens=False)
- start_idx = self.find_sublist(input_ids, input_ids)
- LOG.debug(
- f"Human input '{input_text}' expected IDs: {input_ids}, found at: {start_idx}"
+ # Process all turns and verify correct labeling based on role
+ for i, turn in enumerate(basic_dataset[0]["conversations"]):
+ start_idx, end_idx = strategy.find_turn(turns=turns, turn_idx=i)
+
+ if self._should_skip_turn(tokenizer, turn, i, start_idx, end_idx):
+ continue
+
+ decoded_response = tokenizer.decode(input_ids[start_idx:end_idx])
+ response = turn["value"]
+
+ assert response in decoded_response, (
+ f"Response {response} not found in index {start_idx}:{end_idx} "
+ f"decoded:{decoded_response}"
)
- assert start_idx != -1, f"Could not find '{input_text}' in input_ids"
- assert all(
- label == IGNORE_TOKEN_ID
- for label in labels[start_idx : start_idx + len(input_ids)]
- ), f"Expected labels for human input '{input_text}' to be IGNORE_TOKEN_ID, but got {labels[start_idx:start_idx+len(input_ids)]}"
- def test_roles_to_train_assistant_only(self, llama3_tokenizer, basic_dataset):
- LOG.info("Testing roles_to_train with assistant only")
+ # Verify that assistant responses are labeled and other inputs are not
+ is_assistant = turn["from"] == "assistant"
+ if is_assistant:
+ assert all(
+ label != IGNORE_TOKEN_ID for label in labels[start_idx:end_idx]
+ ), f"Expected labels for assistant response '{response}' to be set, but got {labels[start_idx:end_idx]}"
+ else:
+ assert all(
+ label == IGNORE_TOKEN_ID for label in labels[start_idx:end_idx]
+ ), f"Expected labels for human input '{response}' to be IGNORE_TOKEN_ID, but got {labels[start_idx:end_idx]}"
+
+ def test_roles_to_train_human_assistant_only(
+ self,
+ tokenizer,
+ chat_template,
+ chat_template_jinja,
+ eos_token,
+ basic_dataset,
+ request,
+ ):
+ LOG.info("Testing roles_to_train with human assistant only")
+
+ tokenizer, chat_template_jinja = self.setup_tokenizer(
+ tokenizer, chat_template, chat_template_jinja, eos_token, request
+ )
+
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(
- llama3_tokenizer, chat_template=get_chat_template("llama3")
+ tokenizer,
+ chat_template=get_chat_template(
+ chat_template, jinja_template=chat_template_jinja
+ ),
+ message_field_role="from",
+ message_field_content="value",
),
- tokenizer=llama3_tokenizer,
+ tokenizer=tokenizer,
train_on_inputs=False,
sequence_len=512,
- roles_to_train=["assistant"],
+ roles_to_train=["assistant", "human"],
)
+ strategy.messages = "conversations"
res = strategy.tokenize_prompt(basic_dataset[0])
labels = res["labels"]
input_ids = res["input_ids"]
- # Verify that only assistant responses are labeled
- assistant_responses = ["Hi there!", "I'm doing well, thank you!"]
- for response in assistant_responses:
- response_ids = llama3_tokenizer.encode(response, add_special_tokens=False)
- start_idx = self.find_sublist(input_ids, response_ids)
- LOG.debug(
- f"Assistant response '{response}' expected IDs: {response_ids}, found at: {start_idx}"
+ strategy.messages = "conversations"
+ res = strategy.tokenize_prompt(basic_dataset[0])
+ turns = strategy.get_conversation_thread(basic_dataset[0])
+ labels = res["labels"]
+ input_ids = res["input_ids"]
+
+ # Process all turns and verify correct labeling based on role
+ for i, turn in enumerate(basic_dataset[0]["conversations"]):
+ start_idx, end_idx = strategy.find_turn(turns=turns, turn_idx=i)
+
+ if self._should_skip_turn(tokenizer, turn, i, start_idx, end_idx):
+ continue
+
+ decoded_response = tokenizer.decode(input_ids[start_idx:end_idx])
+ response = turn["value"]
+
+ assert response in decoded_response, (
+ f"Response {response} not found in index {start_idx}:{end_idx} "
+ f"decoded:{decoded_response}"
)
- assert all(
- label != IGNORE_TOKEN_ID
- for label in labels[start_idx : start_idx + len(response_ids)]
- ), f"Expected labels for assistant response '{response}' to be set, but got {labels[start_idx:start_idx+len(response_ids)]}"
- def test_roles_to_train_all(self, llama3_tokenizer, basic_dataset):
+ # Verify that non-system responses are labeled and system are not
+ should_be_labelled = turn["from"] != "system"
+ if should_be_labelled:
+ assert all(
+ label != IGNORE_TOKEN_ID for label in labels[start_idx:end_idx]
+ ), f"Expected labels for assistant response '{response}' to be set, but got {labels[start_idx:end_idx]}"
+ else:
+ assert all(
+ label == IGNORE_TOKEN_ID for label in labels[start_idx:end_idx]
+ ), f"Expected labels for human input '{response}' to be IGNORE_TOKEN_ID, but got {labels[start_idx:end_idx]}"
+
+ def test_roles_to_train_all(
+ self,
+ tokenizer,
+ chat_template,
+ chat_template_jinja,
+ eos_token,
+ basic_dataset,
+ request,
+ ):
LOG.info("Testing roles_to_train with all roles")
+
+ tokenizer, chat_template_jinja = self.setup_tokenizer(
+ tokenizer, chat_template, chat_template_jinja, eos_token, request
+ )
+
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(
- llama3_tokenizer, chat_template=get_chat_template("llama3")
+ tokenizer,
+ chat_template=get_chat_template(
+ chat_template, jinja_template=chat_template_jinja
+ ),
+ message_field_role="from",
+ message_field_content="value",
),
- tokenizer=llama3_tokenizer,
+ tokenizer=tokenizer,
train_on_inputs=True,
sequence_len=512,
roles_to_train=["human", "assistant"],
)
+ strategy.messages = "conversations"
res = strategy.tokenize_prompt(basic_dataset[0])
+ turns = strategy.get_conversation_thread(basic_dataset[0])
labels = res["labels"]
input_ids = res["input_ids"]
# Verify that all responses are labeled (except for special tokens)
- all_responses = [
- "Hello",
- "Hi there!",
- "How are you?",
- "I'm doing well, thank you!",
- ]
- for response in all_responses:
- response_ids = llama3_tokenizer.encode(response, add_special_tokens=False)
- start_idx = self.find_sublist(input_ids, response_ids)
- LOG.debug(
- f"Response '{response}' expected IDs: {response_ids}, found at: {start_idx}"
- )
- assert all(
- label != IGNORE_TOKEN_ID
- for label in labels[start_idx : start_idx + len(response_ids)]
- ), f"Expected labels for response '{response}' to be set, but got {labels[start_idx:start_idx+len(response_ids)]}"
+ for i, turn in enumerate(basic_dataset[0]["conversations"]):
+ response = turn["value"]
+
+ start_idx, end_idx = strategy.find_turn(turns=turns, turn_idx=i)
- def test_empty_roles_to_train(self, llama3_tokenizer, basic_dataset):
+ if self._should_skip_turn(tokenizer, turn, i, start_idx, end_idx):
+ continue
+
+ decoded_response = tokenizer.decode(input_ids[start_idx:end_idx])
+ assert (
+ response in decoded_response
+ ), f"Response {response} not found in index {start_idx}:{end_idx} decoded:{decoded_response}"
+
+ assert all(
+ label != IGNORE_TOKEN_ID for label in labels[start_idx:end_idx]
+ ), f"Expected labels for response '{response}' to be set, but got {labels[start_idx:end_idx]}"
+
+ def test_empty_roles_to_train(
+ self,
+ tokenizer,
+ chat_template,
+ chat_template_jinja,
+ eos_token,
+ basic_dataset,
+ request,
+ ):
LOG.info("Testing with empty roles_to_train")
+
+ tokenizer, chat_template_jinja = self.setup_tokenizer(
+ tokenizer, chat_template, chat_template_jinja, eos_token, request
+ )
+
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(
- llama3_tokenizer, chat_template=get_chat_template("llama3")
+ tokenizer,
+ chat_template=get_chat_template(
+ chat_template, jinja_template=chat_template_jinja
+ ),
+ message_field_role="from",
+ message_field_content="value",
),
- tokenizer=llama3_tokenizer,
+ tokenizer=tokenizer,
train_on_inputs=False,
sequence_len=512,
roles_to_train=[],
train_on_eos="none", # Add this line
)
+ strategy.messages = "conversations"
res = strategy.tokenize_prompt(basic_dataset[0])
labels = res["labels"]
@@ -201,23 +379,42 @@ def test_empty_roles_to_train(self, llama3_tokenizer, basic_dataset):
label == IGNORE_TOKEN_ID for label in labels
), "Expected all labels to be IGNORE_TOKEN_ID when roles_to_train is empty"
- def test_train_on_eos_all(self, llama3_tokenizer, basic_dataset):
+ def test_train_on_eos_all(
+ self,
+ tokenizer,
+ chat_template,
+ chat_template_jinja,
+ eos_token,
+ basic_dataset,
+ request,
+ ):
LOG.info("Testing with train_on_eos='all'")
+
+ tokenizer, chat_template_jinja = self.setup_tokenizer(
+ tokenizer, chat_template, chat_template_jinja, eos_token, request
+ )
+
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(
- llama3_tokenizer, chat_template=get_chat_template("llama3")
+ tokenizer,
+ chat_template=get_chat_template(
+ chat_template, jinja_template=chat_template_jinja
+ ),
+ message_field_role="from",
+ message_field_content="value",
),
- tokenizer=llama3_tokenizer,
+ tokenizer=tokenizer,
train_on_inputs=False,
sequence_len=512,
roles_to_train=["assistant"],
train_on_eos="all",
)
+ strategy.messages = "conversations"
res = strategy.tokenize_prompt(basic_dataset[0])
labels = res["labels"]
input_ids = res["input_ids"]
- eos_token_id = llama3_tokenizer.eos_token_id
+ eos_token_id = tokenizer.eos_token_id
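+ # Collect every EOS position; with train_on_eos="all", each one should carry a label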
eos_indices = [
i for i, token_id in enumerate(input_ids) if token_id == eos_token_id
]
@@ -228,73 +425,122 @@ def test_train_on_eos_all(self, llama3_tokenizer, basic_dataset):
labels[eos_idx] != IGNORE_TOKEN_ID
), f"Expected EOS token at index {eos_idx} to be labeled"
- def test_train_on_eos_turn(self, llama3_tokenizer, basic_dataset):
+ def test_train_on_eos_turn(
+ self,
+ tokenizer,
+ chat_template,
+ chat_template_jinja,
+ eos_token,
+ basic_dataset,
+ request,
+ ):
LOG.info("Testing with train_on_eos='turn'")
+
+ tokenizer, chat_template_jinja = self.setup_tokenizer(
+ tokenizer, chat_template, chat_template_jinja, eos_token, request
+ )
+
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(
- llama3_tokenizer, chat_template=get_chat_template("llama3")
+ tokenizer,
+ chat_template=get_chat_template(
+ chat_template, jinja_template=chat_template_jinja
+ ),
+ message_field_role="from",
+ message_field_content="value",
),
- tokenizer=llama3_tokenizer,
+ tokenizer=tokenizer,
train_on_inputs=False,
sequence_len=512,
roles_to_train=["assistant"],
train_on_eos="turn",
)
+ strategy.messages = "conversations"
res = strategy.tokenize_prompt(basic_dataset[0])
+ turns = strategy.get_conversation_thread(basic_dataset[0])
labels = res["labels"]
input_ids = res["input_ids"]
- eos_token_id = llama3_tokenizer.eos_token_id
- assistant_responses = ["Hi there!", "I'm doing well, thank you!"]
+ eos_token_id = tokenizer.eos_token_id
+ # Process all turns and verify EOS token labeling
+ for i, turn in enumerate(basic_dataset[0]["conversations"]):
+ start_idx, end_idx = strategy.find_turn(turns=turns, turn_idx=i)
+
+ if self._should_skip_turn(tokenizer, turn, i, start_idx, end_idx):
+ continue
+
+ decoded_response = tokenizer.decode(input_ids[start_idx:end_idx])
+ response = turn["value"]
- for response in assistant_responses:
- response_ids = llama3_tokenizer.encode(response, add_special_tokens=False)
- start_idx = self.find_sublist(input_ids, response_ids)
- assert start_idx != -1, f"Could not find '{response}' in input_ids"
+ assert response in decoded_response, (
+ f"Response {response} not found in index {start_idx}:{end_idx} "
+ f"decoded:{decoded_response}"
+ )
- eos_idx = start_idx + len(response_ids)
+ # Find the EOS token after this turn
+ eos_idx = end_idx
while eos_idx < len(input_ids) and input_ids[eos_idx] != eos_token_id:
eos_idx += 1
assert eos_idx < len(
input_ids
), f"Could not find EOS token after '{response}'"
- assert (
- labels[eos_idx] != IGNORE_TOKEN_ID
- ), f"Expected EOS token after assistant response '{response}' to be labeled"
-
- # Check that EOS tokens after human inputs are not labeled
- human_inputs = ["Hello", "How are you?"]
- for input_text in human_inputs:
- input_ids = llama3_tokenizer.encode(input_text, add_special_tokens=False)
- start_idx = self.find_sublist(input_ids, input_ids)
- assert start_idx != -1, f"Could not find '{input_text}' in input_ids"
- eos_idx = start_idx + len(input_ids)
- while eos_idx < len(input_ids) and input_ids[eos_idx] != eos_token_id:
- eos_idx += 1
+ LOG.debug(
+ f"Turn {i}: role={turn['from']}, content='{turn['value']}', start_idx={start_idx}, end_idx={end_idx}, eos_idx={eos_idx}"
+ )
- assert (
- labels[eos_idx] == IGNORE_TOKEN_ID
- ), f"Expected EOS token after human input '{input_text}' to not be labeled"
+ LOG.debug(
+ f"Labels for turn {i}: {labels[start_idx:end_idx]}, EOS label: {labels[eos_idx]}"
+ )
- def test_train_on_eos_last(self, llama3_tokenizer, basic_dataset):
+ # Verify EOS token labeling based on role
+ is_assistant = turn["from"] == "assistant"
+ if is_assistant:
+ assert (
+ labels[eos_idx] != IGNORE_TOKEN_ID
+ ), f"Expected EOS token after assistant response '{response}' to be labeled"
+ else:
+ assert (
+ labels[eos_idx] == IGNORE_TOKEN_ID
+ ), f"Expected EOS token after non-assistant input '{response}' to not be labeled"
+
+ def test_train_on_eos_last(
+ self,
+ tokenizer,
+ chat_template,
+ chat_template_jinja,
+ eos_token,
+ basic_dataset,
+ request,
+ ):
LOG.info("Testing with train_on_eos='last'")
+
+ tokenizer, chat_template_jinja = self.setup_tokenizer(
+ tokenizer, chat_template, chat_template_jinja, eos_token, request
+ )
+
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(
- llama3_tokenizer, chat_template=get_chat_template("llama3")
+ tokenizer,
+ chat_template=get_chat_template(
+ chat_template, jinja_template=chat_template_jinja
+ ),
+ message_field_role="from",
+ message_field_content="value",
),
- tokenizer=llama3_tokenizer,
+ tokenizer=tokenizer,
train_on_inputs=False,
sequence_len=512,
roles_to_train=["assistant"],
train_on_eos="last",
)
+ strategy.messages = "conversations"
res = strategy.tokenize_prompt(basic_dataset[0])
labels = res["labels"]
input_ids = res["input_ids"]
- eos_token_id = llama3_tokenizer.eos_token_id
+ eos_token_id = tokenizer.eos_token_id
eos_indices = [
i for i, token_id in enumerate(input_ids) if token_id == eos_token_id
]
@@ -311,23 +557,42 @@ def test_train_on_eos_last(self, llama3_tokenizer, basic_dataset):
labels[last_eos_idx] != IGNORE_TOKEN_ID
), f"Expected last EOS token at index {last_eos_idx} to be labeled"
- def test_train_on_eos_none(self, llama3_tokenizer, basic_dataset):
+ def test_train_on_eos_none(
+ self,
+ tokenizer,
+ chat_template,
+ chat_template_jinja,
+ eos_token,
+ basic_dataset,
+ request,
+ ):
LOG.info("Testing with train_on_eos='none'")
+
+ tokenizer, chat_template_jinja = self.setup_tokenizer(
+ tokenizer, chat_template, chat_template_jinja, eos_token, request
+ )
+
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(
- llama3_tokenizer, chat_template=get_chat_template("llama3")
+ tokenizer,
+ chat_template=get_chat_template(
+ chat_template, jinja_template=chat_template_jinja
+ ),
+ message_field_role="from",
+ message_field_content="value",
),
- tokenizer=llama3_tokenizer,
+ tokenizer=tokenizer,
train_on_inputs=False,
sequence_len=512,
roles_to_train=["assistant"],
train_on_eos="none",
)
+ strategy.messages = "conversations"
res = strategy.tokenize_prompt(basic_dataset[0])
labels = res["labels"]
input_ids = res["input_ids"]
- eos_token_id = llama3_tokenizer.eos_token_id
+ eos_token_id = tokenizer.eos_token_id
eos_indices = [
i for i, token_id in enumerate(input_ids) if token_id == eos_token_id
]
@@ -338,43 +603,75 @@ def test_train_on_eos_none(self, llama3_tokenizer, basic_dataset):
labels[eos_idx] == IGNORE_TOKEN_ID
), f"Expected EOS token at index {eos_idx} to not be labeled"
- def test_drop_system_message(self, llama3_tokenizer, basic_dataset):
+ def test_drop_system_message(
+ self,
+ tokenizer,
+ chat_template,
+ chat_template_jinja,
+ eos_token,
+ basic_dataset,
+ request,
+ ):
LOG.info("Testing with drop_system_message=True")
+ tokenizer, chat_template_jinja = self.setup_tokenizer(
+ tokenizer, chat_template, chat_template_jinja, eos_token, request
+ )
+
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(
- llama3_tokenizer,
- chat_template=get_chat_template("llama3"),
+ tokenizer,
+ chat_template=get_chat_template(
+ chat_template, jinja_template=chat_template_jinja
+ ),
drop_system_message=True,
+ message_field_role="from",
+ message_field_content="value",
),
- tokenizer=llama3_tokenizer,
+ tokenizer=tokenizer,
train_on_inputs=False,
sequence_len=512,
roles_to_train=["assistant"],
)
+ strategy.messages = "conversations"
res = strategy.tokenize_prompt(basic_dataset[0])
input_ids = res["input_ids"]
# Check if system message is not present in input_ids
system_message = "You are an AI assistant."
- system_ids = llama3_tokenizer.encode(system_message, add_special_tokens=False)
+ decoded_message = tokenizer.decode(input_ids)
assert (
- self.find_sublist(input_ids, system_ids) == -1
+ system_message not in decoded_message
), "Expected system message to be dropped"
- def test_custom_roles(self, llama3_tokenizer):
+ def test_custom_roles(
+ self,
+ tokenizer,
+ chat_template,
+ chat_template_jinja,
+ eos_token,
+ request,
+ ):
LOG.info("Testing with custom roles mapping")
custom_roles = {
"user": ["human", "user"],
"assistant": ["ai", "assistant"],
"system": ["context"],
}
+ tokenizer, chat_template_jinja = self.setup_tokenizer(
+ tokenizer, chat_template, chat_template_jinja, eos_token, request
+ )
+
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(
- llama3_tokenizer,
- chat_template=get_chat_template("llama3"),
+ tokenizer,
+ chat_template=get_chat_template(
+ chat_template, jinja_template=chat_template_jinja
+ ),
roles=custom_roles,
+ message_field_role="from",
+ message_field_content="value",
),
- tokenizer=llama3_tokenizer,
+ tokenizer=tokenizer,
train_on_inputs=False,
sequence_len=512,
roles_to_train=["ai"],
@@ -389,46 +686,65 @@ def test_custom_roles(self, llama3_tokenizer):
{"from": "ai", "value": "I'm doing well, thank you!"},
]
- modified_dataset = Dataset.from_dict(
- {"conversations": [modified_conversations]}
- )
+ modified_dataset = Dataset.from_dict({"messages": [modified_conversations]})
res = strategy.tokenize_prompt(modified_dataset[0])
+ turns = strategy.get_conversation_thread(modified_dataset[0])
labels = res["labels"]
input_ids = res["input_ids"]
- # Check if AI responses are labeled correctly
- ai_responses = ["Hi there!", "I'm doing well, thank you!"]
- for response in ai_responses:
- response_ids = llama3_tokenizer.encode(response, add_special_tokens=False)
- start_idx = self.find_sublist(input_ids, response_ids)
- assert start_idx != -1, f"Could not find response '{response}' in input_ids"
- assert all(
- label != IGNORE_TOKEN_ID
- for label in labels[start_idx : start_idx + len(response_ids)]
- ), f"Expected labels for AI response '{response}' to be set"
-
- # Check if human messages are not labeled
- human_messages = ["Hello", "How are you?"]
- for message in human_messages:
- message_ids = llama3_tokenizer.encode(message, add_special_tokens=False)
- start_idx = self.find_sublist(input_ids, message_ids)
- assert start_idx != -1, f"Could not find message '{message}' in input_ids"
- assert all(
- label == IGNORE_TOKEN_ID
- for label in labels[start_idx : start_idx + len(message_ids)]
- ), f"Expected labels for human message '{message}' to be IGNORE_TOKEN_ID"
+ # Process all turns and verify labeling
+ for i, turn in enumerate(modified_dataset[0]["messages"]):
+ start_idx, end_idx = strategy.find_turn(turns=turns, turn_idx=i)
+
+ if self._should_skip_turn(tokenizer, turn, i, start_idx, end_idx):
+ continue
+
+ decoded_response = tokenizer.decode(input_ids[start_idx:end_idx])
+ response = turn["value"]
- def test_message_field_training(self, llama3_tokenizer):
+ assert response in decoded_response, (
+ f"Response {response} not found in index {start_idx}:{end_idx} "
+ f"decoded:{decoded_response}"
+ )
+
+ # Check if responses are labeled correctly based on role
+ is_ai = turn["from"] == "ai"
+ if is_ai:
+ assert all(
+ label != IGNORE_TOKEN_ID for label in labels[start_idx:end_idx]
+ ), f"Expected labels for AI response '{response}' to be set"
+ else:
+ assert all(
+ label == IGNORE_TOKEN_ID for label in labels[start_idx:end_idx]
+ ), f"Expected labels for non-AI message '{response}' to be IGNORE_TOKEN_ID"
+
+ def test_message_field_training(
+ self,
+ tokenizer,
+ chat_template,
+ chat_template_jinja,
+ eos_token,
+ request,
+ ):
LOG.info("Testing with message_field_training")
+
+ tokenizer, chat_template_jinja = self.setup_tokenizer(
+ tokenizer, chat_template, chat_template_jinja, eos_token, request
+ )
+
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(
- llama3_tokenizer,
- chat_template=get_chat_template("llama3"),
+ tokenizer,
+ chat_template=get_chat_template(
+ chat_template, jinja_template=chat_template_jinja
+ ),
message_field_training="train",
message_field_training_detail="train_detail",
+ message_field_role="from",
+ message_field_content="value",
),
- tokenizer=llama3_tokenizer,
+ tokenizer=tokenizer,
train_on_inputs=False,
sequence_len=512,
roles_to_train=[],
@@ -457,62 +773,65 @@ def test_message_field_training(self, llama3_tokenizer):
{"from": "assistant", "value": "Hi there!", "train": True},
]
- modified_dataset = Dataset.from_dict({"conversations": [modified_conversation]})
+ modified_dataset = Dataset.from_dict({"messages": [modified_conversation]})
res = strategy.tokenize_prompt(modified_dataset[0])
+ turns = strategy.get_conversation_thread(modified_dataset[0])
labels = res["labels"]
input_ids = res["input_ids"]
- # Function to find all occurrences of a sublist
- def find_all_sublists(full_list, sub_list):
- indices = []
- for index in range(len(full_list) - len(sub_list) + 1):
- if full_list[index : index + len(sub_list)] == sub_list:
- indices.append(index)
- return indices
-
- # Keep track of which occurrences we've processed
- processed_occurrences = {}
- # Check if messages are labeled correctly based on train or train_detail
- for i, turn in enumerate(modified_conversation):
- turn_tokens = llama3_tokenizer.encode(
- turn["value"], add_special_tokens=False
- )
- occurrences = find_all_sublists(input_ids, turn_tokens)
- turn_key = turn["value"]
- if turn_key not in processed_occurrences:
- processed_occurrences[turn_key] = 0
- current_occurrence = processed_occurrences[turn_key]
+ def verify_labels(labels_span, should_train, context_message):
+ """Helper to verify if a span of labels matches expected training state"""
+ if should_train:
+ assert all(
+ label != IGNORE_TOKEN_ID for label in labels_span
+ ), f"Expected all labels for {context_message} to be set, but got {labels_span}"
+ else:
+ assert all(
+ label == IGNORE_TOKEN_ID for label in labels_span
+ ), f"Expected all labels for {context_message} to be {IGNORE_TOKEN_ID}, but got {labels_span}"
- if current_occurrence >= len(occurrences):
- assert (
- False
- ), f"Not enough occurrences found for message: {turn['value']}"
+ # Process all turns and verify labeling
+ for i, turn in enumerate(modified_dataset[0]["messages"]):
+ start_idx, end_idx = strategy.find_turn(turns=turns, turn_idx=i)
+
+ if self._should_skip_turn(tokenizer, turn, i, start_idx, end_idx):
+ continue
+
+ decoded_response = tokenizer.decode(input_ids[start_idx:end_idx])
+ response = turn["value"]
- start_idx = occurrences[current_occurrence]
- processed_occurrences[turn_key] += 1
- end_idx = start_idx + len(turn_tokens)
+ assert response in decoded_response, (
+ f"Response {response} not found in index {start_idx}:{end_idx} "
+ f"decoded:{decoded_response}"
+ )
LOG.debug(
- f"Processing turn {i}: role={turn['from']}, content='{turn['value']}', start_idx={start_idx}, end_idx={end_idx}"
+ f"Processing turn {i}: role={turn['from']}, content='{turn['value']}', "
+ f"start_idx={start_idx}, end_idx={end_idx}"
)
- if "train_detail" in turn:
- # Get token offsets
- tokenized_output = llama3_tokenizer(
+ if turn.get("train_detail", None) is not None:
+ # Handle detailed token-level training control
+ tokenized_output = tokenizer(
turn["value"], return_offsets_mapping=True, add_special_tokens=False
)
+ assert tokenized_output["input_ids"] == input_ids[start_idx:end_idx], (
+ f"Tokenized input mismatch for turn: {turn['value']}\n"
+ f"Expected: {input_ids[start_idx:end_idx]}\nActual: {tokenized_output['input_ids']}\n"
+ f"This will likely be a mismatch between template content and encoded content"
+ )
+
token_offsets = tokenized_output["offset_mapping"]
- # Adjust token offsets as done in the implementation
- for i in range(len(token_offsets) - 1):
- token_offsets[i] = (
- token_offsets[i][0],
- token_offsets[i + 1][0] - 1,
+ # Extend each token's end offset up to the character before the next token's start
+ for j in range(len(token_offsets) - 1):
+ token_offsets[j] = (
+ token_offsets[j][0],
+ token_offsets[j + 1][0] - 1,
)
token_offsets[-1] = (token_offsets[-1][0], len(turn["value"]) - 1)
- # Adjust train_details
adjusted_train_details = strategy.prompter.adjust_train_details(
turn["train_detail"], token_offsets
)
@@ -520,12 +839,20 @@ def find_all_sublists(full_list, sub_list):
LOG.debug(f"Original train_details: {turn['train_detail']}")
LOG.debug(f"Adjusted train_details: {adjusted_train_details}")
- # Handle train_detail
- token_offsets = strategy.prompter.get_offsets_for_train_detail(
+ # Get and verify token offsets
+ turn_tokens = input_ids[start_idx:end_idx]
+ token_offsets_unmasked = strategy.prompter.get_offsets_for_train_detail(
text=turn["value"],
train_details=adjusted_train_details,
mask_untrainable=False,
)
+
+ for j, offset in enumerate(token_offsets_unmasked):
+ assert token_offsets[j][0] == offset, (
+ f"Token start offsets mismatch for turn: {turn['value']}\n"
+ f"Expected: {token_offsets[j][0]}\nActual: {offset}"
+ )
+
token_offsets_masked = strategy.prompter.get_offsets_for_train_detail(
text=turn["value"],
train_details=adjusted_train_details,
@@ -533,6 +860,7 @@ def find_all_sublists(full_list, sub_list):
)
LOG.debug(f"Token offsets: {token_offsets_masked}")
+ # Verify expected labels against actual labels
expected_labels = [IGNORE_TOKEN_ID] * len(turn_tokens)
for i, offset in enumerate(token_offsets_masked):
if offset != IGNORE_TOKEN_ID:
@@ -544,17 +872,17 @@ def find_all_sublists(full_list, sub_list):
actual_labels == expected_labels
), f"Labels mismatch for turn: {turn['value']}\nExpected: {expected_labels}\nActual: {actual_labels}"
+ # Verify each detail section
for detail in adjusted_train_details:
- # Find the token indices that correspond to the character offsets
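+ # Map the detail's character offsets back to token indices within the turn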
detail_start = start_idx + next(
- i
- for i, offset in enumerate(token_offsets)
+ j
+ for j, offset in enumerate(token_offsets_unmasked)
if offset >= detail["begin_offset"]
)
detail_end = start_idx + next(
(
- i
- for i, offset in enumerate(token_offsets)
+ j
+ for j, offset in enumerate(token_offsets_unmasked)
if offset > detail["end_offset"]
),
len(token_offsets),
@@ -564,70 +892,21 @@ def find_all_sublists(full_list, sub_list):
detail["begin_offset"] : detail["end_offset"] + 1
]
detail_labels = labels[detail_start:detail_end]
- detail_input_ids = input_ids[detail_start:detail_end]
- LOG.debug(
- f"Detail: '{detail_text}', Start: {detail_start}, End: {detail_end}"
- )
- LOG.debug(f"Detail input_ids: {detail_input_ids}")
- LOG.debug(f"Detail labels: {detail_labels}")
- LOG.debug(
- f"Decoded detail: {llama3_tokenizer.decode(detail_input_ids)}"
- )
- LOG.debug(
- f"Token offsets for this detail: {token_offsets[detail_start-start_idx:detail_end-start_idx]}"
+ context = (
+ f"detail (ind {detail_start}:{detail_end}): '{detail_text}'\n"
+ f"decoded: '{tokenizer.decode(input_ids[detail_start:detail_end])}')"
)
-
- if detail["train"]:
- assert all(
- label != IGNORE_TOKEN_ID for label in detail_labels
- ), (
- f"Expected labels for trainable detail '{detail_text}' to be set, but some were IGNORE_TOKEN_ID. "
- f"Labels({detail_start}:{detail_end}): {detail_labels}, "
- f"InputIDs: {detail_input_ids}, "
- f"Decoded: '{llama3_tokenizer.decode(detail_input_ids)}'"
- )
- else:
- assert all(
- label == IGNORE_TOKEN_ID for label in detail_labels
- ), (
- f"Expected all labels for non-trainable detail '{detail_text}' to be IGNORE_TOKEN_ID, but some were not. "
- f"Labels({detail_start}:{detail_end}): {detail_labels}, "
- f"InputIDs: {detail_input_ids}, "
- f"Decoded: '{llama3_tokenizer.decode(detail_input_ids)}'"
- )
+ verify_labels(detail_labels, detail["train"], context)
else:
+ # Handle regular turn-level training control
should_train = turn.get("train", False)
turn_labels = labels[start_idx:end_idx]
-
- LOG.debug(f"Should train: {should_train}")
- LOG.debug(f"Turn indices: start={start_idx}, end={end_idx}")
- LOG.debug(f"Turn labels: {turn_labels}")
- LOG.debug(f"Turn input IDs: {input_ids[start_idx:end_idx]}")
- LOG.debug(
- f"Decoded turn: {llama3_tokenizer.decode(input_ids[start_idx:end_idx])}"
- )
-
- if should_train:
- assert all(label != IGNORE_TOKEN_ID for label in turn_labels), (
- f"Expected all labels for '{turn['value']}' to be set\n"
- f"Labels({start_idx}:{end_idx}): {turn_labels}, "
- f"InputIDs: {input_ids[start_idx:end_idx]}, "
- f"Decoded: '{llama3_tokenizer.decode(input_ids[start_idx:end_idx])}'"
- )
- else:
- assert all(label == IGNORE_TOKEN_ID for label in turn_labels), (
- f"Expected all labels for '{turn['value']}' to be IGNORE_TOKEN_ID\n"
- f"Labels({start_idx}:{end_idx}): {turn_labels}, "
- f"InputIDs: {input_ids[start_idx:end_idx]}, "
- f"Decoded: '{llama3_tokenizer.decode(input_ids[start_idx:end_idx])}'"
- )
-
- LOG.debug(
- f"Processed turn: {turn['from']}, content: '{turn['value']}', "
- f"start_idx: {start_idx}, end_idx: {end_idx}, "
- f"labels: {labels[start_idx:end_idx]}"
+ context = (
+ f"turn (ind {start_idx}:{end_idx}): '{turn['value']}'\n"
+ f"decoded: '{decoded_response}')"
)
+ verify_labels(turn_labels, should_train, context)
LOG.debug(f"Final labels: {labels}")
LOG.debug(f"Final input_ids: {input_ids}")