Commit

precommit
vwxyzjn committed Dec 19, 2023
1 parent 89ea1c5 commit fceceaf
Showing 11 changed files with 832 additions and 281 deletions.
53 changes: 39 additions & 14 deletions lm_human_preference_details/summarize/reward.py
@@ -1,7 +1,7 @@
from collections import defaultdict
import os
import random
import time
from collections import defaultdict
from dataclasses import asdict, dataclass, field
from types import SimpleNamespace
from typing import List, Literal, Optional
@@ -15,16 +15,23 @@
import tyro
from accelerate import Accelerator
from accelerate.state import AcceleratorState
from accelerate.utils import DistributedDataParallelKwargs, gather_object
from accelerate.utils import gather_object
from datasets import load_dataset
from rich.console import Console
from rich.pretty import pprint
from rich.table import Table
from torch import optim
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from transformers import AutoConfig, AutoModel, AutoTokenizer, get_scheduler, PreTrainedModel, PretrainedConfig
from transformers import (
AutoConfig,
AutoModel,
AutoTokenizer,
PretrainedConfig,
PreTrainedModel,
get_scheduler,
)


@dataclass
@@ -126,7 +133,9 @@ class Args:
# other args
base_model: str = "EleutherAI/pythia-160m"
"""the name of the pretrained model to use"""
dropout_layer_keys: List[str] = field(default_factory=lambda: ["attn_pdrop", "embd_pdrop", "resid_pdrop", "summary_first_dropout"])
dropout_layer_keys: List[str] = field(
default_factory=lambda: ["attn_pdrop", "embd_pdrop", "resid_pdrop", "summary_first_dropout"]
)
"""Which layers to apply dropout to"""
output_dir: str = "models/reward_policy"
"""Where to save the model"""
@@ -164,13 +173,16 @@ def layer_init(layer, std=np.sqrt(2), bias_const=0.0):


class ScalarModelConfig(PretrainedConfig):
model_type = 'scalar_model'
model_type = "scalar_model"

def __init__(self, base_model: str = "gpt2", **kwargs):
super().__init__(**kwargs)
self.base_model = base_model


class ScalarModel(PreTrainedModel):
config_class = ScalarModelConfig

def __init__(self, config: ScalarModelConfig):
super().__init__(config)
self.config = config
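
This hunk turns the scalar reward model into a proper `PretrainedConfig`/`PreTrainedModel` pair (`model_type = "scalar_model"`, `config_class = ScalarModelConfig`), so it can round-trip through `save_pretrained`/`from_pretrained`. A sketch of how such a pair is typically completed and registered; the backbone, scalar head, and `register` calls below are assumptions, since the diff truncates the class body:

```python
import torch.nn as nn
from transformers import AutoConfig, AutoModel, PretrainedConfig, PreTrainedModel


class ScalarModelConfig(PretrainedConfig):
    model_type = "scalar_model"

    def __init__(self, base_model: str = "gpt2", **kwargs):
        super().__init__(**kwargs)
        self.base_model = base_model


class ScalarModel(PreTrainedModel):
    config_class = ScalarModelConfig

    def __init__(self, config: ScalarModelConfig):
        super().__init__(config)
        # Backbone transformer plus a single scalar head (assumed layout).
        self.lm_backbone = AutoModel.from_pretrained(config.base_model)
        self.scalar_head = nn.Linear(self.lm_backbone.config.hidden_size, 1)

    def forward(self, input_ids, attention_mask=None):
        out = self.lm_backbone(
            input_ids, attention_mask=attention_mask,
            output_hidden_states=True, return_dict=True,
        )
        return self.scalar_head(out.hidden_states[-1])  # (batch, seq, 1)


# Registering the pair lets AutoConfig/AutoModel resolve "scalar_model" checkpoints.
AutoConfig.register("scalar_model", ScalarModelConfig)
AutoModel.register(ScalarModelConfig, ScalarModel)
```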
@@ -203,10 +215,7 @@ def get_reward(model, query_responses, tokenizer):
return_dict=True,
output_hidden_states=True,
)
sequence_lengths = (
torch.eq(query_responses, tokenizer.pad_token_id).long().argmax(-1) - 1).to(
query_responses.device
)
sequence_lengths = (torch.eq(query_responses, tokenizer.pad_token_id).long().argmax(-1) - 1).to(query_responses.device)
# https://github.com/huggingface/transformers/blob/dc68a39c8111217683bf49a4912d0c9018bab33d/src/transformers/models/gpt2/modeling_gpt2.py#L1454
return reward_logits[torch.arange(reward_logits.size(0), device=reward_logits.device), sequence_lengths]
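
The condensed `sequence_lengths` expression finds, for each row, the index of the last non-padding token (first pad position minus one, assuming right padding), and the return line picks the reward logit at exactly that position. A tiny self-contained illustration with made-up values:

```python
import torch

pad_token_id = 0
query_responses = torch.tensor([
    [5, 7, 9, 0, 0],   # 3 real tokens -> last real index 2
    [4, 4, 4, 4, 0],   # 4 real tokens -> last real index 3
])
reward_logits = torch.arange(10, dtype=torch.float32).reshape(2, 5)

# First pad position minus one == index of the last real token (right padding).
sequence_lengths = torch.eq(query_responses, pad_token_id).long().argmax(-1) - 1
rewards = reward_logits[torch.arange(reward_logits.size(0)), sequence_lengths]
print(sequence_lengths.tolist())  # [2, 3]
print(rewards.tolist())           # [2.0, 8.0]
```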

@@ -258,10 +267,26 @@ def evaluate(args, accelerator, tokenizer, model, dataloader):
dataset = load_dataset(args.label_dataset, "comparisons", split="train")
dataset = dataset.shuffle(seed=local_seed)
dataset = dataset.select(range(args.labels.num_train))
dataset = dataset.with_format("torch", columns=["query_token", "choice", "response0_token", "response1_token", "batch", "split"])
dataset = dataset.with_format(
"torch", columns=["query_token", "choice", "response0_token", "response1_token", "batch", "split"]
)
dataloader = DataLoader(dataset, batch_size=args.local_micro_batch_size)
validation_dataset = load_dataset(args.label_dataset, "comparisons", split="validation").flatten()
validation_dataset = validation_dataset.with_format("torch", columns=["query_token", "choice", "response0_token", "response1_token", "batch", "split", "extra.confidence", "response0_policy", "response1_policy", "policies"])
validation_dataset = validation_dataset.with_format(
"torch",
columns=[
"query_token",
"choice",
"response0_token",
"response1_token",
"batch",
"split",
"extra.confidence",
"response0_policy",
"response1_policy",
"policies",
],
)
validation_dataloader = DataLoader(validation_dataset, batch_size=args.local_eval_batch_size)
accelerator.print("The number of samples in dataset", len(dataset))
accelerator.print("The number of samples in validation_dataset", len(validation_dataset))
@@ -426,7 +451,7 @@ def evaluate(args, accelerator, tokenizer, model, dataloader):
"mean": norm_df["predicted_reward"].mean(),
"std": norm_df["predicted_reward"].std(),
"max": norm_df["predicted_reward"].max(),
"min": norm_df["predicted_reward"].min()
"min": norm_df["predicted_reward"].min(),
}
for stat_name, stat_value in stats.items():
writer.add_scalar(f"eval/normalized_{stat_name}", stat_value, global_step)
@@ -436,7 +461,7 @@ def evaluate(args, accelerator, tokenizer, model, dataloader):
if args.output_dir and args.num_train_epochs > 0:
os.makedirs(os.path.dirname(args.output_dir), exist_ok=True)
time_tensor = torch.tensor([int(time.time())], device=device)
time_int = accelerator.gather(time_tensor)[0].item() # avoid different timestamps across processes
time_int = accelerator.gather(time_tensor)[0].item() # avoid different timestamps across processes
repo_name = f"{args.base_model.replace('/', '_')}__{args.exp_name}__tldr"
repo_id = f"{args.hf_entity}/{repo_name}" if args.hf_entity else repo_name

20 changes: 8 additions & 12 deletions lm_human_preference_details/summarize/sft.py
@@ -1,5 +1,4 @@
import collections
import functools
import os
import random
import time
@@ -13,7 +12,6 @@
import torch
import torch.optim as optim
import tyro
from tqdm import tqdm
from accelerate import Accelerator
from datasets import load_dataset
from rich.console import Console
@@ -23,6 +21,7 @@
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from transformers import (
AutoConfig,
AutoModelForCausalLM,
@@ -122,7 +121,9 @@ class Args:
# other args
base_model: str = "EleutherAI/pythia-160m"
"""the name of the pretrained model to use"""
dropout_layer_keys: List[str] = field(default_factory=lambda: ["attn_pdrop", "embd_pdrop", "resid_pdrop", "summary_first_dropout"])
dropout_layer_keys: List[str] = field(
default_factory=lambda: ["attn_pdrop", "embd_pdrop", "resid_pdrop", "summary_first_dropout"]
)
"""Which layers to apply dropout to"""
output_dir: str = "models/sft_model"
"""Where to save the model"""
@@ -255,9 +256,7 @@ def forward(model, query_responses, tokenizer):
num_training_steps=args.num_updates * args.num_train_epochs,
)

model, optimizer, dataloader, scheduler = accelerator.prepare(
model, optimizer, dataloader, scheduler
)
model, optimizer, dataloader, scheduler = accelerator.prepare(model, optimizer, dataloader, scheduler)
validation_dataloader = accelerator.prepare(validation_dataloader)
# WARNING: even with `max_new_tokens` and `min_new_tokens` set to the same value, the number of tokens generated
# may not be the same. TODO: investigate further, we just want to generate a fixed number of tokens
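
Because of the warning above, downstream code that assumes a fixed response length may want to pad or truncate generations defensively before concatenating or decoding them. One way to enforce that, sketched here as a hypothetical guard (not part of the commit):

```python
import torch

def force_fixed_length(responses: torch.Tensor, response_length: int, pad_token_id: int) -> torch.Tensor:
    # Truncate anything longer than the target, then right-pad anything shorter.
    responses = responses[:, :response_length]
    if responses.size(1) < response_length:
        pad = torch.full(
            (responses.size(0), response_length - responses.size(1)),
            pad_token_id, dtype=responses.dtype, device=responses.device,
        )
        responses = torch.cat((responses, pad), dim=1)
    return responses
```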
@@ -305,7 +304,6 @@ def forward(model, query_responses, tokenizer):
writer.add_scalar("loss", accelerator.gather(loss_stats).mean().item(), update)
writer.add_scalar("lr", scheduler.get_last_lr()[0], update)
accelerator.print(f"{loss.item()=}, {scheduler.get_last_lr()=}, {update=}")
break

if args.run_eval:
model.eval()
@@ -319,9 +317,7 @@ def forward(model, query_responses, tokenizer):
with torch.no_grad():
validation_reference_responses = validation_data["reference_response_token"].to(device, non_blocking=True)
validation_queries = validation_data["query_token"].to(device, non_blocking=True)
validation_query_reference_responses = torch.cat(
(validation_queries, validation_reference_responses), dim=1
)
validation_query_reference_responses = torch.cat((validation_queries, validation_reference_responses), dim=1)

validation_output = forward(model, validation_query_reference_responses, tokenizer)
validation_labels = validation_query_reference_responses.masked_fill(
@@ -353,7 +349,7 @@ def forward(model, query_responses, tokenizer):
skip_special_tokens=True,
)
decode_validation_responses = tokenizer.batch_decode(
accelerator.gather(generated_responses[:, -args.task.response_length:]),
accelerator.gather(generated_responses[:, -args.task.response_length :]),
skip_special_tokens=True,
)
rouge_score = rouge.compute(
@@ -393,7 +389,7 @@ def forward(model, query_responses, tokenizer):
if args.output_dir:
os.makedirs(os.path.dirname(args.output_dir), exist_ok=True)
time_tensor = torch.tensor([int(time.time())], device=device)
time_int = accelerator.gather(time_tensor)[0].item() # avoid different timestamps across processes
time_int = accelerator.gather(time_tensor)[0].item() # avoid different timestamps across processes
repo_name = f"{args.base_model.replace('/', '_')}__{args.exp_name}__tldr"
repo_id = f"{args.hf_entity}/{repo_name}" if args.hf_entity else repo_name

40 changes: 23 additions & 17 deletions lm_human_preference_details/tldr_dataset.py
@@ -1,15 +1,16 @@
from dataclasses import dataclass
import multiprocessing
import os
from dataclasses import dataclass
from typing import Dict, Optional

from datasets import load_dataset
from rich.pretty import pprint
from transformers import AutoTokenizer
import tyro
import multiprocessing
import matplotlib.pyplot as plt
import pandas as pd
import tyro
from datasets import load_dataset
from huggingface_hub import HfApi
from rich.pretty import pprint
from transformers import AutoTokenizer

api = HfApi()


@@ -20,11 +21,13 @@
--max-sft-response-length=53 \
--max-rm-response-length=169
"""


@dataclass
class Args:
base_model: str = "gpt2" # EleutherAI/pythia-160m
max_sft_response_length: int = 48 # 53
max_rm_response_length: int = 153 # 169
base_model: str = "gpt2" # EleutherAI/pythia-160m
max_sft_response_length: int = 48 # 53
max_rm_response_length: int = 153 # 169
hf_entity: str = None


@@ -36,7 +39,7 @@ class TaskQueryHParams:
] = "SUBREDDIT: r/{subreddit}\n\nTITLE: {title}\n\nPOST: {post}\n\nTL;DR:" # if underlying dataset yields dicts, can format arbitrarily
truncate_field: Optional[str] = "post"
truncate_text: Optional[str] = "\n"
padding: Optional[str] = " " # empty spaces
padding: Optional[str] = " " # empty spaces
pad_side: Optional[str] = "left"
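
`TaskQueryHParams` drives query construction: the example's fields are formatted into `format_str`, the `truncate_field` (the Reddit post) is cut back at the last `truncate_text` until the query fits, and the token sequence is left-padded with encoded spaces. A simplified sketch of that pipeline; `max_query_length` stands in for the length hyperparameter, and the exact truncation rule in the repo may differ:

```python
def build_query(example: dict, hparams: TaskQueryHParams, tokenizer, max_query_length: int) -> list:
    # Fill the prompt template with the example's fields (subreddit, title, post).
    tokens = tokenizer.encode(hparams.format_str.format(**example))
    # Shorten the `post` field at the last newline until the query fits.
    while len(tokens) > max_query_length:
        post = example[hparams.truncate_field]
        cut = post.rfind(hparams.truncate_text)
        example[hparams.truncate_field] = post[:cut] if cut > 0 else post[:-1]
        tokens = tokenizer.encode(hparams.format_str.format(**example))
    # Pad with the token for a single space so every query has the same length.
    pad_id = tokenizer.encode(hparams.padding)[0]
    pad = [pad_id] * (max_query_length - len(tokens))
    return pad + tokens if hparams.pad_side == "left" else tokens + pad
```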


@@ -138,7 +141,9 @@ def process_query_data(x):
}

sft_ds = sft_ds.map(process_query_data, load_from_cache_file=False, num_proc=multiprocessing.cpu_count())
sft_ds.push_to_hub(f"{args.hf_entity}/summarize_from_feedback_tldr_3_filtered_oai_preprocessing_{args.base_model.split('/')[-1]}_{args.max_sft_response_length}")
sft_ds.push_to_hub(
f"{args.hf_entity}/summarize_from_feedback_tldr_3_filtered_oai_preprocessing_{args.base_model.split('/')[-1]}_{args.max_sft_response_length}"
)

label_ds = load_dataset("openai/summarize_from_feedback", "comparisons")

@@ -168,7 +173,9 @@ def process_response_data(x):
}

label_ds = label_ds.map(process_response_data, load_from_cache_file=False, num_proc=multiprocessing.cpu_count())
label_ds.push_to_hub(f"{args.hf_entity}/summarize_from_feedback_oai_preprocessing_{args.base_model.split('/')[-1]}_{args.max_rm_response_length}")
label_ds.push_to_hub(
f"{args.hf_entity}/summarize_from_feedback_oai_preprocessing_{args.base_model.split('/')[-1]}_{args.max_rm_response_length}"
)

os.makedirs("dataset_visuals", exist_ok=True)
# visualize token length distribution
@@ -183,10 +190,10 @@ def process_response_data(x):
offset = len(sft_ds)
for i, key in enumerate(label_ds.keys()):
df = label_ds[key].to_pandas()
axs[2*i + offset].hist(df["response0_token_len"], bins=100)
axs[2*i + offset].set_title(f"{key} split: response0 token length\nmax_length={max(df['response0_token_len'])}")
axs[2*i + offset + 1].hist(df["response1_token_len"], bins=100)
axs[2*i + offset + 1].set_title(f"{key} split: response1 token length\nmax_length={max(df['response1_token_len'])}")
axs[2 * i + offset].hist(df["response0_token_len"], bins=100)
axs[2 * i + offset].set_title(f"{key} split: response0 token length\nmax_length={max(df['response0_token_len'])}")
axs[2 * i + offset + 1].hist(df["response1_token_len"], bins=100)
axs[2 * i + offset + 1].set_title(f"{key} split: response1 token length\nmax_length={max(df['response1_token_len'])}")
fig.suptitle(f"{args.base_model} Tokenizer: Token length distribution")
fig.tight_layout()
fig.savefig("dataset_visuals/token_len.png")
@@ -244,4 +251,3 @@ def process_response_data(x):
repo_id=f"{args.hf_entity}/summarize_from_feedback_oai_preprocessing_{args.base_model.split('/')[-1]}_{args.max_rm_response_length}",
repo_type="dataset",
)
