From 255bdcdd4bd6c3e1eff97fb16ff8f3aeac0ff06c Mon Sep 17 00:00:00 2001 From: Supreeth Manyam Date: Mon, 20 Jan 2025 14:53:43 -0800 Subject: [PATCH] Add masking strategies to message transforms --- torchtune/data/_messages.py | 292 ++++++++++++++++++++++++++---------- 1 file changed, 209 insertions(+), 83 deletions(-) diff --git a/torchtune/data/_messages.py b/torchtune/data/_messages.py index a4e00834c2..c5cfb4ef64 100644 --- a/torchtune/data/_messages.py +++ b/torchtune/data/_messages.py @@ -6,6 +6,7 @@ from pathlib import Path from typing import Any, Dict, List, Literal, Mapping, Optional, Union +from warnings import warn from torchtune.data._utils import format_content_with_images, load_image @@ -153,8 +154,13 @@ class InputOutputToMessages(Transform): | "user prompt" | "model response" | Args: - train_on_input (bool): Whether the model is trained on the user prompt or not. - Default is False. + masking_strategy (Optional[str]): masking strategy to use for model training. + Must be one of: `train_on_all`, `train_on_assistant`, `train_on_last`. + Default is "train_on_all". + - ``train_on_all``: both user and assistant messages are unmasked + - ``train_on_assistant``: user messages are masked, only assistant messages are unmasked + - ``train_on_last``: only the last assistant message is unmasked + Note: Multimodal user messages are always masked. column_map (Optional[Dict[str, str]]): a mapping to change the expected "input" and "output" column names to the actual column names in the dataset. Keys should be "input" and "output" and values should be the actual column names. Default is None, @@ -166,6 +172,9 @@ class InputOutputToMessages(Transform): was ``"images/1.jpg"``, the final image path that will be loaded is ``"/home/user/dataset/images/1.jpg"``. If None, assume images are available in current working directory or are located on a remote url. For text-only, leave as None. Default is None. + train_on_input (Optional[bool]): whether the model is trained on the user prompt or not. + Deprecated parameter and will be removed in a future release. + Default is None. Raises: ValueError: @@ -176,12 +185,24 @@ class InputOutputToMessages(Transform): def __init__( self, - train_on_input: bool = False, + masking_strategy: Optional[str] = "train_on_all", column_map: Optional[Dict[str, str]] = None, new_system_prompt: Optional[str] = None, image_dir: Optional[Path] = None, + train_on_input: Optional[bool] = None, ): - self.train_on_input = train_on_input + if train_on_input is not None: + warn( + "train_on_input is deprecated and will be removed in a future release. " + "Please use masking_strategy instead.", + DeprecationWarning, + stacklevel=2, + ) + if masking_strategy is None: + masking_strategy = ( + "train_on_all" if train_on_input else "train_on_assistant" + ) + self.masking_strategy = masking_strategy self.new_system_prompt = new_system_prompt self.column_map = column_map @@ -242,13 +263,11 @@ def __call__(self, sample: Mapping[str, Any]) -> Mapping[str, Any]: Message( role="user", content=content, - masked=not self.train_on_input, eot=True, ), Message( role="assistant", content=output_content, - masked=False, eot=True, ), ] @@ -258,6 +277,7 @@ def __call__(self, sample: Mapping[str, Any]) -> Mapping[str, Any]: role="system", content=self.new_system_prompt, masked=True, eot=True ) ] + messages + mask_messages(messages, self.masking_strategy) return {"messages": messages} @@ -288,8 +308,12 @@ class ChosenRejectedToMessages(Transform): turns of user and assistant messages. Args: - train_on_input (bool): Whether the model is trained on the user prompt or not. - Default is False. + masking_strategy (Optional[str]): masking strategy to use for model training. + Must be one of: `train_on_all`, `train_on_assistant`, `train_on_last`. + Default is "train_on_all". + - ``train_on_all``: both user and assistant messages are unmasked + - ``train_on_assistant``: user messages are masked, only assistant messages are unmasked + - ``train_on_last``: only the last assistant message is unmasked column_map (Optional[Dict[str, str]]): a mapping to change the expected "chosen" and "rejected" column names to the actual column names in the dataset. Keys should be "chosen" and "rejected" and values should be the actual column names. @@ -297,6 +321,9 @@ class ChosenRejectedToMessages(Transform): new_system_prompt (Optional[str]): if specified, prepend a system message. This can serve as instructions to guide the model response. Setting this will OVERRIDE any system messages already present in the dataset. Default is None. + train_on_input (Optional[bool]): whether the model is trained on the user prompt or not. + Deprecated parameter and will be removed in a future release. + Default is None. Raises: ValueError: If ``column_map`` is provided and ``chosen`` not in ``column_map``, or @@ -305,11 +332,23 @@ class ChosenRejectedToMessages(Transform): def __init__( self, - train_on_input: bool = False, + masking_strategy: Optional[str] = "train_on_all", column_map: Optional[Dict[str, str]] = None, new_system_prompt: Optional[str] = None, + train_on_input: Optional[bool] = None, ): - self.train_on_input = train_on_input + if train_on_input is not None: + warn( + "train_on_input is deprecated and will be removed in a future release. " + "Please use masking_strategy instead.", + DeprecationWarning, + stacklevel=2, + ) + if masking_strategy is None: + masking_strategy = ( + "train_on_all" if train_on_input else "train_on_assistant" + ) + self.masking_strategy = masking_strategy self.new_system_prompt = new_system_prompt if column_map: if "chosen" not in column_map: @@ -329,18 +368,12 @@ def __call__(self, sample: Mapping[str, Any]) -> Mapping[str, Any]: for message in sample[self._column_map["chosen"]]: if message["role"] == "system" and self.new_system_prompt is not None: continue - message["masked"] = (message["role"] != "assistant") and ( - not self.train_on_input - ) chosen_messages.append(Message.from_dict(message)) rejected_messages = [] for message in sample[self._column_map["rejected"]]: if message["role"] == "system" and self.new_system_prompt is not None: continue - message["masked"] = (message["role"] != "assistant") and ( - not self.train_on_input - ) rejected_messages.append(Message.from_dict(message)) if self.new_system_prompt is not None: @@ -354,7 +387,8 @@ def __call__(self, sample: Mapping[str, Any]) -> Mapping[str, Any]: role="system", content=self.new_system_prompt, masked=True, eot=True ) ] + rejected_messages - + mask_messages(chosen_messages, self.masking_strategy) + mask_messages(rejected_messages, self.masking_strategy) return {"chosen": chosen_messages, "rejected": rejected_messages} @@ -389,8 +423,13 @@ class ShareGPTToMessages(Transform): ] Args: - train_on_input (bool): whether the prompt should remain unmasked. For multimodal datasets, ``train_on_input`` - is always False and this value is ignored. Default: False + masking_strategy (Optional[str]): masking strategy to use for model training. + Must be one of: `train_on_all`, `train_on_assistant`, `train_on_last`. + Default is "train_on_all". + - ``train_on_all``: both user and assistant messages are unmasked + - ``train_on_assistant``: user messages are masked, only assistant messages are unmasked + - ``train_on_last``: only the last assistant message is unmasked + Note: Multimodal user messages are always masked. column_map (Optional[Dict[str, str]]): a mapping from the expected columns ("conversations") to the new column names in the dataset. Key should be "conversations" and value should be the new column name. If None, keep the default "conversations". @@ -406,6 +445,9 @@ class ShareGPTToMessages(Transform): image_tag (Optional[str]): placeholder tags in the text content of each message to be replaced by image special tokens. If images are present and this is None, then will prepend image tokens to the first user message in the sample by default. If text-only, this field is ignored. Default is ``""``. + train_on_input (Optional[bool]): whether the model is trained on the user prompt or not. + Deprecated parameter and will be removed in a future release. + Default is None. Raises: ValueError: If ``column_map`` is provided and ``conversations`` not in ``column_map``. @@ -413,13 +455,25 @@ class ShareGPTToMessages(Transform): def __init__( self, - train_on_input: bool = False, + masking_strategy: Optional[str] = "train_on_all", column_map: Optional[Dict[str, str]] = None, new_system_prompt: Optional[str] = None, image_dir: Optional[Path] = None, image_tag: Optional[str] = "", + train_on_input: Optional[bool] = None, ): - self.train_on_input = train_on_input + if train_on_input is not None: + warn( + "train_on_input is deprecated and will be removed in a future release. " + "Please use masking_strategy instead.", + DeprecationWarning, + stacklevel=2, + ) + if masking_strategy is None: + masking_strategy = ( + "train_on_all" if train_on_input else "train_on_assistant" + ) + self.masking_strategy = masking_strategy self.new_system_prompt = new_system_prompt if column_map: if "conversations" not in column_map: @@ -482,14 +536,8 @@ def __call__(self, sample: Mapping[str, Any]) -> Mapping[str, Any]: images=[pil_image], ) image_loaded = True - - # If multimodal and user message, always mask - # Otherwise, if user message, mask if train_on_input is False - masked = (role != "assistant") and ( - not self.train_on_input or is_multimodal - ) - messages.append(Message(role=role, content=content, masked=masked)) - + messages.append(Message(role=role, content=content)) + mask_messages(messages, self.masking_strategy) return {"messages": messages} @@ -544,7 +592,13 @@ class OpenAIToMessages(Transform): ] Args: - train_on_input (bool): whether the prompt should remain unmasked. Default: False + masking_strategy (Optional[str]): masking strategy to use for model training. + Must be one of: `train_on_all`, `train_on_assistant`, `train_on_last`. + Default is "train_on_all". + - ``train_on_all``: both user and assistant messages are unmasked + - ``train_on_assistant``: user messages are masked, only assistant messages are unmasked + - ``train_on_last``: only the last assistant message is unmasked + Note: Multimodal user messages are always masked. column_map (Optional[Dict[str, str]]): a mapping from the expected columns ("messages") to the new column names in the dataset. Key should be "messages" and value should be the new column name. If None, keep the default "messages". @@ -552,6 +606,9 @@ class OpenAIToMessages(Transform): new_system_prompt (Optional[str]): if specified, prepend a system message. This can serve as instructions to guide the model response. Setting this will OVERRIDE any system messages already present in the dataset. Default is None. + train_on_input (Optional[bool]): whether the model is trained on the user prompt or not. + Deprecated parameter and will be removed in a future release. + Default is None. Raises: ValueError: If ``column_map`` is provided and ``messages`` not in ``column_map``. @@ -559,11 +616,23 @@ class OpenAIToMessages(Transform): def __init__( self, - train_on_input: bool = False, + masking_strategy: Optional[str] = "train_on_all", column_map: Optional[Dict[str, str]] = None, new_system_prompt: Optional[str] = None, + train_on_input: Optional[bool] = None, ): - self.train_on_input = train_on_input + if train_on_input is not None: + warn( + "train_on_input is deprecated and will be removed in a future release. " + "Please use masking_strategy instead.", + DeprecationWarning, + stacklevel=2, + ) + if masking_strategy is None: + masking_strategy = ( + "train_on_all" if train_on_input else "train_on_assistant" + ) + self.masking_strategy = masking_strategy self.new_system_prompt = new_system_prompt if column_map: if "messages" not in column_map: @@ -614,7 +683,6 @@ def __call__(self, sample: Mapping[str, Any]) -> Mapping[str, Any]: for message in sample[self._column_map["messages"]]: if message["role"] == "system" and self.new_system_prompt is not None: continue - masked = (message["role"] != "assistant") and (not self.train_on_input) if isinstance(message["content"], list): content = self._convert_from_openai_content(message["content"]) elif isinstance(message["content"], str): @@ -623,55 +691,12 @@ def __call__(self, sample: Mapping[str, Any]) -> Mapping[str, Any]: Message( role=message["role"], content=content, - masked=masked, ), ) - + mask_messages(updated_messages, self.masking_strategy) return {"messages": updated_messages} -def validate_messages( - messages: List[Message], -) -> None: - """ - Given a list of messages, ensure that messages form a valid - back-and-forth conversation. An error will be raised if: - - - There is a system message that's not the first message - - There are two consecutive user messages - - An assistant message comes before the first user message - - The message is empty - - Messages are shorter than length of 2 (min. one user-assistant turn) - - - Args: - messages (List[Message]): the messages to validate. - - Raises: - ValueError: If the messages are invalid. - """ - if len(messages) < 2: - raise ValueError( - f"Messages must be at least length 2, but got {len(messages)} messages" - ) - - last_turn = "assistant" - for i, message in enumerate(messages): - if message.role == "assistant" and last_turn != "user": - raise ValueError( - f"Assistant message before expected user message at index {i} in messages" - ) - if message.role == "user" and last_turn == "user": - raise ValueError( - f"Two consecutive user messages at index {i} and {i - 1} in messages" - ) - if message.role == "system" and i > 0: - raise ValueError( - f"System message at index {i} in messages, but system messages must come first" - ) - last_turn = message.role - - class AlpacaToMessages(Transform): """ Message transform class for Alpaca-style datasets with "instruction", "input", and "output" @@ -682,17 +707,38 @@ class AlpacaToMessages(Transform): due to this custom logic. Args: - train_on_input (bool): Whether the model is trained on the user prompt or not. - Default is True. + masking_strategy (Optional[str]): masking strategy to use for model training. + Must be one of: `train_on_all`, `train_on_assistant`, `train_on_last`. + Default is "train_on_all". + - ``train_on_all``: both user and assistant messages are unmasked + - ``train_on_assistant``: user messages are masked, only assistant messages are unmasked + - ``train_on_last``: only the last assistant message is unmasked column_map (Optional[Dict[str, str]]): a mapping to change the expected "instruction", "input", and "output" column names to the actual column names in the dataset. Default is None, keeping the default column names. + train_on_input (Optional[bool]): whether the model is trained on the user prompt or not. + Deprecated parameter and will be removed in a future release. + Default is None. """ def __init__( - self, train_on_input: bool = True, column_map: Optional[Dict[str, str]] = None + self, + masking_strategy: Optional[str] = "train_on_all", + column_map: Optional[Dict[str, str]] = None, + train_on_input: Optional[bool] = None, ): - self.train_on_input = train_on_input + if train_on_input is not None: + warn( + "train_on_input is deprecated and will be removed in a future release. " + "Please use masking_strategy instead.", + DeprecationWarning, + stacklevel=2, + ) + if masking_strategy is None: + masking_strategy = ( + "train_on_all" if train_on_input else "train_on_assistant" + ) + self.masking_strategy = masking_strategy self.column_map = column_map self.template = { "prompt_input": ( @@ -726,14 +772,94 @@ def __call__(self, sample: Mapping[str, Any]) -> Mapping[str, Any]: Message( role="user", content=prompt, - masked=not self.train_on_input, eot=True, ), Message( role="assistant", content=sample[key_output], - masked=False, eot=True, ), ] + mask_messages(messages, self.masking_strategy) return {"messages": messages} + + +def validate_messages( + messages: List[Message], +) -> None: + """ + Given a list of messages, ensure that messages form a valid + back-and-forth conversation. An error will be raised if: + + - There is a system message that's not the first message + - There are two consecutive user messages + - An assistant message comes before the first user message + - The message is empty + - Messages are shorter than length of 2 (min. one user-assistant turn) + + + Args: + messages (List[Message]): the messages to validate. + + Raises: + ValueError: If the messages are invalid. + """ + if len(messages) < 2: + raise ValueError( + f"Messages must be at least length 2, but got {len(messages)} messages" + ) + + last_turn = "assistant" + for i, message in enumerate(messages): + if message.role == "assistant" and last_turn != "user": + raise ValueError( + f"Assistant message before expected user message at index {i} in messages" + ) + if message.role == "user" and last_turn == "user": + raise ValueError( + f"Two consecutive user messages at index {i} and {i - 1} in messages" + ) + if message.role == "system" and i > 0: + raise ValueError( + f"System message at index {i} in messages, but system messages must come first" + ) + last_turn = message.role + + +def mask_messages(messages: List[Message], masking_strategy: str) -> None: + """ + Set the masked attribute for each message in the list based on the specified masking strategy. + + Args: + messages (List[Message]): a list of messages to mask. + masking_strategy (str): masking strategy to use. + Must be one of: `train_on_all`, `train_on_assistant`, `train_on_last`. + - ``train_on_all``: both user and assistant messages are unmasked + - ``train_on_assistant``: user messages are masked, only assistant messages are unmasked + - ``train_on_last``: only the last assistant message is unmasked + + Raises: + ValueError: If the masking strategy is not one of the supported strategies: + `train_on_all`, `train_on_assistant`, `train_on_last`. + """ + if masking_strategy not in ["train_on_all", "train_on_assistant", "train_on_last"]: + raise ValueError( + f"masking_strategy must be one of 'train_on_all', 'train_on_assistant', 'train_on_last', got {masking_strategy}" + ) + marked_last_assistant_message = False + for message in reversed(messages): + # System messages are always masked + if message.role == "system": + message.masked = True + continue + if masking_strategy == "train_on_last": + if message.role == "assistant" and not marked_last_assistant_message: + message.masked = False + marked_last_assistant_message = True + else: + message.masked = True + # Multimodal user messages are always masked + elif masking_strategy == "train_on_all": + message.masked = message.role == "user" and message.contains_media + else: # train_on_assistant + message.masked = message.role != "assistant"