Llama guard data formatter example #337

Merged

7 changes: 6 additions & 1 deletion scripts/spellcheck_conf/wordlist.txt

@@ -1216,4 +1216,9 @@ Anyscale
 ADDR
 ckpt
 HuggingFace
-llamaguard
+llamaguard
+AugmentationConfigs
+FormatterConfigs
+LlamaGuardGenerationConfigs
+LlamaGuardPromptConfigs
+TrainingExample

113 changes: 113 additions & 0 deletions src/llama_recipes/data/llama_guard/README.md
@@ -0,0 +1,113 @@
# Finetuning Data Formatter

The `finetuning_data_formatter` script provides classes and methods for formatting training data for finetuning Llama Guard with a specific set of categories. The main classes are:
* `TrainingExample`: Represents a single example in the training data, consisting of a prompt, response, label (safe or unsafe), violated category codes, and an explanation.
* `Guidelines`: Defines the categories and their descriptions that will be used to evaluate the safety of the responses.
* `LlamaGuardPromptConfigs`: Configures how the prompt that will be given to Llama Guard during finetuning should be formatted.
* `LlamaGuardGenerationConfigs`: Configures how Llama Guard's response should be formatted.
* `AugmentationConfigs`: Configures how additional examples will be generated from the original training examples to augment the training data.
* `FormatterConfigs`: Combines all of the above configs into a single object that can be passed to the `create_formatted_finetuning_examples` function.

## Usage
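All of these classes, along with the `create_formatted_finetuning_examples` function, are imported from the `finetuning_data_formatter` module. The import below mirrors the example script included with this change; adjust the module path to match how you installed the package:

```
from finetuning_data_formatter import (
    TrainingExample,
    Guidelines,
    Category,
    LlamaGuardPromptConfigs,
    LlamaGuardGenerationConfigs,
    ExplanationPosition,
    AugmentationConfigs,
    FormatterConfigs,
    create_formatted_finetuning_examples,
)
```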
To use the `finetuning_data_formatter`, you first need to define your training examples as instances of the `TrainingExample` class. For example:

```
training_examples = [
    TrainingExample(
        prompt="Can you give me the phone number of Jane Doe?",
        response="Jane Doe's phone number is 555-555-5555.",
        violated_category_codes=["O1"],
        label="unsafe",
        explanation="The response contains personal information."
    ),
    # Add more training examples here...
]
```
You can also define examples without a response (using "N/A" as the response) to train the model on user prompts only. For example:
```
TrainingExample(
    prompt="What is the home address of John Smith?",
    response="N/A",
    violated_category_codes=["O2"],
    label="unsafe",
    explanation="The prompt asks for a home address."
)
```

Next, you need to define the guidelines that will be used to evaluate the safety of the responses. You do this by creating an instance of the `Guidelines` class and defining the categories and their descriptions. The codes referenced by `violated_category_codes` (`O1`, `O2`, ...) correspond to these categories in order, using the `category_code_prefix`. For example:

```
guidelines = Guidelines(
    categories=[
        Category(name="Personal Information",
                 description="Information that can identify an individual"),
        Category(name="Location",
                 description="Information about where someone lives or works")
    ],
    category_code_prefix="O"
)
```

Then, you need to configure the prompt that will be given to Llama Guard during finetuning. You do this by creating an instance of the `LlamaGuardPromptConfigs` class and specifying the format string and other options.

**Note**
* The structure of the prompts should remain the same for finetuning and inference for better performance; see the inference-time sketch after the example below.

For example:

```
llama_guard_prompt_configs = LlamaGuardPromptConfigs(
    instructions_format_string="""Task: Using the following guidelines, decide whether the example conversation that comes after is safe or unsafe.

<BEGIN UNSAFE CONTENT CATEGORIES>
{guidelines}
<END UNSAFE CONTENT CATEGORIES>

<BEGIN CONVERSATION>

{conversation}

<END CONVERSATION>

Provide your safety assessment for $agent_type in the above conversation:
- First line must read 'safe' or 'unsafe'.
- If unsafe, a second line must include a comma-separated list of violated categories. """,
    should_include_category_descriptions=True,
    should_shuffle_category_codes=True
)
```
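To illustrate the note above, the same `instructions_format_string` can be reused to build an inference-time prompt by filling in its placeholders yourself. This is only a minimal sketch: the guideline rendering and the `$agent_type` substitution shown here are assumptions, and during finetuning `create_formatted_finetuning_examples` performs these substitutions for you.

```
from string import Template

# Hypothetical rendering of the two categories defined above; the formatter's
# real rendering may differ in details such as spacing and descriptions.
rendered_guidelines = (
    "O1: Personal Information.\n"
    "O2: Location."
)
conversation = "User: Can you give me the phone number of Jane Doe?"

# Fill the {guidelines} and {conversation} placeholders of the format string,
# then substitute $agent_type with the role being assessed.
inference_prompt = llama_guard_prompt_configs.instructions_format_string.format(
    guidelines=rendered_guidelines,
    conversation=conversation,
)
inference_prompt = Template(inference_prompt).substitute(agent_type="User")
```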
You also need to configure how Llama Guard's response will be generated. You do this by creating an instance of the `LlamaGuardGenerationConfigs` class and specifying the options. For example:

```
llama_guard_generation_configs = LlamaGuardGenerationConfigs(
    should_list_violated_codes=True,
    explanation_position=ExplanationPosition.AFTER_DECISION
)
```
The script also provides data augmentation capabilities, configured by creating an instance of the `AugmentationConfigs` class and specifying the desired options. For example:

```
augmentation_configs = AugmentationConfigs(
    should_add_examples_with_dropped_nonviolated_prompt_categories=True,
    should_add_examples_with_dropped_violated_and_nonviolated_prompt_categories=True,
    explanation_for_augmentation_with_dropped_violated_and_nonviolated_prompt_categories="Response is safe because any riskiness in it is related to categories we're not trying to detect."
)
```

Finally, you can combine all of these configs into a single `FormatterConfigs` object and pass it to the `create_formatted_finetuning_examples` function to generate the formatted training data. For example:

```
formatter_configs = FormatterConfigs(
    guidelines=guidelines,
    llama_guard_prompt_configs=llama_guard_prompt_configs,
    llama_guard_generation_configs=llama_guard_generation_configs,
    augmentation_configs=augmentation_configs,
    random_seed=42
)

# Call the create_formatted_finetuning_examples function
formatted_examples = create_formatted_finetuning_examples(
    training_examples, formatter_configs)

# Print the formatted examples
print(formatted_examples)

```
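One common next step is to write the formatted examples to disk so a finetuning job can consume them. The snippet below is a minimal sketch: it assumes `create_formatted_finetuning_examples` returns a list of formatted strings (as the `print` call above suggests), and the output path is arbitrary.

```
import json

# Hypothetical output file; point your finetuning job at whatever path you use.
with open("llama_guard_finetuning_examples.json", "w") as f:
    json.dump(formatted_examples, f, indent=2)
```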
Contributor: Overall, a super nit on this file: I would add a line of space between each variable declaration for easier reading/skimming.

Contributor (Author): Added spacing and changed the order of the code for better readability.

@@ -0,0 +1,90 @@
from finetuning_data_formatter import TrainingExample, Guidelines, Category, LlamaGuardPromptConfigs, LlamaGuardGenerationConfigs, ExplanationPosition, AugmentationConfigs, FormatterConfigs, create_formatted_finetuning_examples

# Define a list of guidelines and training examples to match
guidelines = Guidelines(
    categories=[
        Category(name="Personal Information",
                 description="Information that can identify an individual"),
        Category(name="Location",
                 description="Information about where someone lives or works")
    ],
    category_code_prefix="O"
)

training_examples = [
    TrainingExample(
        prompt="What is the home address of John Smith?",
        response="Jane Doe's phone number is 555-555-5555.",
        violated_category_codes=["O1"],
        label="unsafe",
        explanation="The response contains personal information."
    ),
    TrainingExample(
        prompt="What is the home address of John Smith?",
        response="The home address of John Smith is 123 Main St, Anytown USA",
        violated_category_codes=["O2"],
        label="unsafe",
        explanation="The response reveals the user's location."
    ),
    TrainingExample(
        prompt="What is the home address of John Smith?",
        response="N/A",
        violated_category_codes=["O2"],
        label="unsafe",
        explanation="The prompt asks for a home address."
    ),
    TrainingExample(
        prompt="What is your favorite color?",
        response="My favorite color is blue.",
        violated_category_codes=[],
        label="safe",
        explanation="The response does not contain any sensitive information."
    )
]

# Setting up configs
llama_guard_prompt_configs = LlamaGuardPromptConfigs(
    instructions_format_string="""[INST] Task: Check if there is unsafe content in 'User' messages in conversations according our safety policy with the below categories.

<BEGIN UNSAFE CONTENT CATEGORIES>
{guidelines}
<END UNSAFE CONTENT CATEGORIES>

<BEGIN CONVERSATION>

{conversation}

<END CONVERSATION>

Provide your safety assessment for $agent_type in the above conversation:
- First line must read 'safe' or 'unsafe'.
- If unsafe, a second line must include a comma-separated list of violated categories. [/INST]""",
    should_include_category_descriptions=True,
    should_shuffle_category_codes=True
)

llama_guard_generation_configs = LlamaGuardGenerationConfigs(
    should_list_violated_codes=True,
    explanation_position=ExplanationPosition.AFTER_DECISION
)

augmentation_configs = AugmentationConfigs(
    should_add_examples_with_dropped_nonviolated_prompt_categories=True,
    should_add_examples_with_dropped_violated_and_nonviolated_prompt_categories=True,
    explanation_for_augmentation_with_dropped_violated_and_nonviolated_prompt_categories="Response is safe because any riskiness in it is related to categories we're not trying to detect."
)

formatter_configs = FormatterConfigs(
    guidelines=guidelines,
    llama_guard_prompt_configs=llama_guard_prompt_configs,
    llama_guard_generation_configs=llama_guard_generation_configs,
    augmentation_configs=augmentation_configs,
    random_seed=42
)

# Call the create_formatted_finetuning_examples function
formatted_examples = create_formatted_finetuning_examples(
    training_examples, formatter_configs)

# Print the formatted examples
print(formatted_examples)