diff --git a/lm_human_preference_details/summarize/reward.py b/lm_human_preference_details/summarize/reward.py index 2f02ec5..30e3893 100644 --- a/lm_human_preference_details/summarize/reward.py +++ b/lm_human_preference_details/summarize/reward.py @@ -1,7 +1,7 @@ -from collections import defaultdict import os import random import time +from collections import defaultdict from dataclasses import asdict, dataclass, field from types import SimpleNamespace from typing import List, Literal, Optional @@ -15,16 +15,23 @@ import tyro from accelerate import Accelerator from accelerate.state import AcceleratorState -from accelerate.utils import DistributedDataParallelKwargs, gather_object +from accelerate.utils import gather_object from datasets import load_dataset from rich.console import Console from rich.pretty import pprint from rich.table import Table from torch import optim -from torch.utils.tensorboard import SummaryWriter from torch.utils.data import DataLoader +from torch.utils.tensorboard import SummaryWriter from tqdm import tqdm -from transformers import AutoConfig, AutoModel, AutoTokenizer, get_scheduler, PreTrainedModel, PretrainedConfig +from transformers import ( + AutoConfig, + AutoModel, + AutoTokenizer, + PretrainedConfig, + PreTrainedModel, + get_scheduler, +) @dataclass @@ -126,7 +133,9 @@ class Args: # other args base_model: str = "EleutherAI/pythia-160m" """the name of the pretrained model to use""" - dropout_layer_keys: List[str] = field(default_factory=lambda: ["attn_pdrop", "embd_pdrop", "resid_pdrop", "summary_first_dropout"]) + dropout_layer_keys: List[str] = field( + default_factory=lambda: ["attn_pdrop", "embd_pdrop", "resid_pdrop", "summary_first_dropout"] + ) """Which layers to apply dropout to""" output_dir: str = "models/reward_policy" """Where to save the model""" @@ -164,13 +173,16 @@ def layer_init(layer, std=np.sqrt(2), bias_const=0.0): class ScalarModelConfig(PretrainedConfig): - model_type = 'scalar_model' + model_type = "scalar_model" + def __init__(self, base_model: str = "gpt2", **kwargs): super().__init__(**kwargs) self.base_model = base_model + class ScalarModel(PreTrainedModel): config_class = ScalarModelConfig + def __init__(self, config: ScalarModelConfig): super().__init__(config) self.config = config @@ -203,10 +215,7 @@ def get_reward(model, query_responses, tokenizer): return_dict=True, output_hidden_states=True, ) - sequence_lengths = ( - torch.eq(query_responses, tokenizer.pad_token_id).long().argmax(-1) - 1).to( - query_responses.device - ) + sequence_lengths = (torch.eq(query_responses, tokenizer.pad_token_id).long().argmax(-1) - 1).to(query_responses.device) # https://github.com/huggingface/transformers/blob/dc68a39c8111217683bf49a4912d0c9018bab33d/src/transformers/models/gpt2/modeling_gpt2.py#L1454 return reward_logits[torch.arange(reward_logits.size(0), device=reward_logits.device), sequence_lengths] @@ -258,10 +267,26 @@ def evaluate(args, accelerator, tokenizer, model, dataloader): dataset = load_dataset(args.label_dataset, "comparisons", split="train") dataset = dataset.shuffle(seed=local_seed) dataset = dataset.select(range(args.labels.num_train)) - dataset = dataset.with_format("torch", columns=["query_token", "choice", "response0_token", "response1_token", "batch", "split"]) + dataset = dataset.with_format( + "torch", columns=["query_token", "choice", "response0_token", "response1_token", "batch", "split"] + ) dataloader = DataLoader(dataset, batch_size=args.local_micro_batch_size) validation_dataset = load_dataset(args.label_dataset, 
"comparisons", split="validation").flatten() - validation_dataset = validation_dataset.with_format("torch", columns=["query_token", "choice", "response0_token", "response1_token", "batch", "split", "extra.confidence", "response0_policy", "response1_policy", "policies"]) + validation_dataset = validation_dataset.with_format( + "torch", + columns=[ + "query_token", + "choice", + "response0_token", + "response1_token", + "batch", + "split", + "extra.confidence", + "response0_policy", + "response1_policy", + "policies", + ], + ) validation_dataloader = DataLoader(validation_dataset, batch_size=args.local_eval_batch_size) accelerator.print("The number of samples in dataset", len(dataset)) accelerator.print("The number of samples in validation_dataset", len(validation_dataset)) @@ -426,7 +451,7 @@ def evaluate(args, accelerator, tokenizer, model, dataloader): "mean": norm_df["predicted_reward"].mean(), "std": norm_df["predicted_reward"].std(), "max": norm_df["predicted_reward"].max(), - "min": norm_df["predicted_reward"].min() + "min": norm_df["predicted_reward"].min(), } for stat_name, stat_value in stats.items(): writer.add_scalar(f"eval/normalized_{stat_name}", stat_value, global_step) @@ -436,7 +461,7 @@ def evaluate(args, accelerator, tokenizer, model, dataloader): if args.output_dir and args.num_train_epochs > 0: os.makedirs(os.path.dirname(args.output_dir), exist_ok=True) time_tensor = torch.tensor([int(time.time())], device=device) - time_int = accelerator.gather(time_tensor)[0].item() # avoid different timestamps across processes + time_int = accelerator.gather(time_tensor)[0].item() # avoid different timestamps across processes repo_name = f"{args.base_model.replace('/', '_')}__{args.exp_name}__tldr" repo_id = f"{args.hf_entity}/{repo_name}" if args.hf_entity else repo_name diff --git a/lm_human_preference_details/summarize/sft.py b/lm_human_preference_details/summarize/sft.py index 32b1645..8f84bc1 100644 --- a/lm_human_preference_details/summarize/sft.py +++ b/lm_human_preference_details/summarize/sft.py @@ -1,5 +1,4 @@ import collections -import functools import os import random import time @@ -13,7 +12,6 @@ import torch import torch.optim as optim import tyro -from tqdm import tqdm from accelerate import Accelerator from datasets import load_dataset from rich.console import Console @@ -23,6 +21,7 @@ from torch.nn import functional as F from torch.utils.data import DataLoader from torch.utils.tensorboard import SummaryWriter +from tqdm import tqdm from transformers import ( AutoConfig, AutoModelForCausalLM, @@ -122,7 +121,9 @@ class Args: # other args base_model: str = "EleutherAI/pythia-160m" """the name of the pretrained model to use""" - dropout_layer_keys: List[str] = field(default_factory=lambda: ["attn_pdrop", "embd_pdrop", "resid_pdrop", "summary_first_dropout"]) + dropout_layer_keys: List[str] = field( + default_factory=lambda: ["attn_pdrop", "embd_pdrop", "resid_pdrop", "summary_first_dropout"] + ) """Which layers to apply dropout to""" output_dir: str = "models/sft_model" """Where to save the model""" @@ -255,9 +256,7 @@ def forward(model, query_responses, tokenizer): num_training_steps=args.num_updates * args.num_train_epochs, ) - model, optimizer, dataloader, scheduler = accelerator.prepare( - model, optimizer, dataloader, scheduler - ) + model, optimizer, dataloader, scheduler = accelerator.prepare(model, optimizer, dataloader, scheduler) validation_dataloader = accelerator.prepare(validation_dataloader) # WARNING: even with `max_new_tokens` and `min_new_tokens` set to 
the same value, the number of tokens generated # may not be the same. TODO: investigate further, we just want to generate a fixed number of tokens @@ -305,7 +304,6 @@ def forward(model, query_responses, tokenizer): writer.add_scalar("loss", accelerator.gather(loss_stats).mean().item(), update) writer.add_scalar("lr", scheduler.get_last_lr()[0], update) accelerator.print(f"{loss.item()=}, {scheduler.get_last_lr()=}, {update=}") - break if args.run_eval: model.eval() @@ -319,9 +317,7 @@ def forward(model, query_responses, tokenizer): with torch.no_grad(): validation_reference_responses = validation_data["reference_response_token"].to(device, non_blocking=True) validation_queries = validation_data["query_token"].to(device, non_blocking=True) - validation_query_reference_responses = torch.cat( - (validation_queries, validation_reference_responses), dim=1 - ) + validation_query_reference_responses = torch.cat((validation_queries, validation_reference_responses), dim=1) validation_output = forward(model, validation_query_reference_responses, tokenizer) validation_labels = validation_query_reference_responses.masked_fill( @@ -353,7 +349,7 @@ def forward(model, query_responses, tokenizer): skip_special_tokens=True, ) decode_validation_responses = tokenizer.batch_decode( - accelerator.gather(generated_responses[:, -args.task.response_length:]), + accelerator.gather(generated_responses[:, -args.task.response_length :]), skip_special_tokens=True, ) rouge_score = rouge.compute( @@ -393,7 +389,7 @@ def forward(model, query_responses, tokenizer): if args.output_dir: os.makedirs(os.path.dirname(args.output_dir), exist_ok=True) time_tensor = torch.tensor([int(time.time())], device=device) - time_int = accelerator.gather(time_tensor)[0].item() # avoid different timestamps across processes + time_int = accelerator.gather(time_tensor)[0].item() # avoid different timestamps across processes repo_name = f"{args.base_model.replace('/', '_')}__{args.exp_name}__tldr" repo_id = f"{args.hf_entity}/{repo_name}" if args.hf_entity else repo_name diff --git a/lm_human_preference_details/tldr_dataset.py b/lm_human_preference_details/tldr_dataset.py index bac3363..d945428 100644 --- a/lm_human_preference_details/tldr_dataset.py +++ b/lm_human_preference_details/tldr_dataset.py @@ -1,15 +1,16 @@ -from dataclasses import dataclass +import multiprocessing import os +from dataclasses import dataclass from typing import Dict, Optional -from datasets import load_dataset -from rich.pretty import pprint -from transformers import AutoTokenizer -import tyro -import multiprocessing import matplotlib.pyplot as plt import pandas as pd +import tyro +from datasets import load_dataset from huggingface_hub import HfApi +from rich.pretty import pprint +from transformers import AutoTokenizer + api = HfApi() @@ -20,11 +21,13 @@ --max-sft-response-length=53 \ --max-rm-response-length=169 """ + + @dataclass class Args: - base_model: str = "gpt2" # EleutherAI/pythia-160m - max_sft_response_length: int = 48 # 53 - max_rm_response_length: int = 153 # 169 + base_model: str = "gpt2" # EleutherAI/pythia-160m + max_sft_response_length: int = 48 # 53 + max_rm_response_length: int = 153 # 169 hf_entity: str = None @@ -36,7 +39,7 @@ class TaskQueryHParams: ] = "SUBREDDIT: r/{subreddit}\n\nTITLE: {title}\n\nPOST: {post}\n\nTL;DR:" # if underlying dataset yields dicts, can format arbitrarily truncate_field: Optional[str] = "post" truncate_text: Optional[str] = "\n" - padding: Optional[str] = " " # empty spaces + padding: Optional[str] = " " # empty 
spaces pad_side: Optional[str] = "left" @@ -138,7 +141,9 @@ def process_query_data(x): } sft_ds = sft_ds.map(process_query_data, load_from_cache_file=False, num_proc=multiprocessing.cpu_count()) - sft_ds.push_to_hub(f"{args.hf_entity}/summarize_from_feedback_tldr_3_filtered_oai_preprocessing_{args.base_model.split('/')[-1]}_{args.max_sft_response_length}") + sft_ds.push_to_hub( + f"{args.hf_entity}/summarize_from_feedback_tldr_3_filtered_oai_preprocessing_{args.base_model.split('/')[-1]}_{args.max_sft_response_length}" + ) label_ds = load_dataset("openai/summarize_from_feedback", "comparisons") @@ -168,7 +173,9 @@ def process_response_data(x): } label_ds = label_ds.map(process_response_data, load_from_cache_file=False, num_proc=multiprocessing.cpu_count()) - label_ds.push_to_hub(f"{args.hf_entity}/summarize_from_feedback_oai_preprocessing_{args.base_model.split('/')[-1]}_{args.max_rm_response_length}") + label_ds.push_to_hub( + f"{args.hf_entity}/summarize_from_feedback_oai_preprocessing_{args.base_model.split('/')[-1]}_{args.max_rm_response_length}" + ) os.makedirs("dataset_visuals", exist_ok=True) # visualize token length distribution @@ -183,10 +190,10 @@ def process_response_data(x): offset = len(sft_ds) for i, key in enumerate(label_ds.keys()): df = label_ds[key].to_pandas() - axs[2*i + offset].hist(df["response0_token_len"], bins=100) - axs[2*i + offset].set_title(f"{key} split: response0 token length\nmax_length={max(df['response0_token_len'])}") - axs[2*i + offset + 1].hist(df["response1_token_len"], bins=100) - axs[2*i + offset + 1].set_title(f"{key} split: response1 token length\nmax_length={max(df['response1_token_len'])}") + axs[2 * i + offset].hist(df["response0_token_len"], bins=100) + axs[2 * i + offset].set_title(f"{key} split: response0 token length\nmax_length={max(df['response0_token_len'])}") + axs[2 * i + offset + 1].hist(df["response1_token_len"], bins=100) + axs[2 * i + offset + 1].set_title(f"{key} split: response1 token length\nmax_length={max(df['response1_token_len'])}") fig.suptitle(f"{args.base_model} Tokenizer: Token length distribution") fig.tight_layout() fig.savefig("dataset_visuals/token_len.png") @@ -244,4 +251,3 @@ def process_response_data(x): repo_id=f"{args.hf_entity}/summarize_from_feedback_oai_preprocessing_{args.base_model.split('/')[-1]}_{args.max_rm_response_length}", repo_type="dataset", ) - diff --git a/lm_human_preference_details/train_policy_accelerate_summarize_separate.py b/lm_human_preference_details/train_policy_accelerate_summarize_separate.py index 0e28c49..da288a9 100644 --- a/lm_human_preference_details/train_policy_accelerate_summarize_separate.py +++ b/lm_human_preference_details/train_policy_accelerate_summarize_separate.py @@ -33,7 +33,6 @@ GenerationConfig, ) - INVALID_LOGPROB = 1.0 @@ -142,7 +141,9 @@ class Args: base_model: str = "EleutherAI/pythia-160m" """the name of the pretrained model to use""" - dropout_layer_keys: List[str] = field(default_factory=lambda: ["attn_pdrop", "embd_pdrop", "resid_pdrop", "summary_first_dropout"]) + dropout_layer_keys: List[str] = field( + default_factory=lambda: ["attn_pdrop", "embd_pdrop", "resid_pdrop", "summary_first_dropout"] + ) """Which layers to apply dropout to""" deepspeed: bool = False """Whether to use deepspeed to train the model""" @@ -356,8 +357,9 @@ def whiten(values, shift_mean=True): def masked_mean(x, mask): return (x.sum(-1) / (~mask).sum(-1)).mean() + def masked_var(x, mask): - return (x**2).sum(-1) / (~mask).sum(-1) - masked_mean(x, mask)**2 + return (x**2).sum(-1) / 
(~mask).sum(-1) - masked_mean(x, mask) ** 2 def masked_whiten(values, mask, shift_mean=True): @@ -367,7 +369,7 @@ def masked_whiten(values, mask, shift_mean=True): if not shift_mean: whitened += mean return whitened - + def masked_mean(values, mask, axis=None): """Compute mean of tensor with a masked values.""" @@ -376,6 +378,7 @@ def masked_mean(values, mask, axis=None): else: return (values * mask).sum() / mask.sum() + def masked_var(values, mask, unbiased=True): """Compute variance of tensor with masked values.""" mean = masked_mean(values, mask) @@ -497,7 +500,11 @@ def get_reward(reward_model, query_responses, tokenizer): # ) # print(f"======={sequence_lengths1=} {sequence_lengths=}") # https://github.com/huggingface/transformers/blob/dc68a39c8111217683bf49a4912d0c9018bab33d/src/transformers/models/gpt2/modeling_gpt2.py#L1454 - return reward_logits, reward_logits[torch.arange(reward_logits.size(0), device=reward_logits.device), sequence_lengths].squeeze(-1), sequence_lengths + return ( + reward_logits, + reward_logits[torch.arange(reward_logits.size(0), device=reward_logits.device), sequence_lengths].squeeze(-1), + sequence_lengths, + ) def forward(policy, query_responses, tokenizer): @@ -640,9 +647,7 @@ def forward(policy, query_responses, tokenizer): eval_ds_config["zero_optimization"] = { "stage": 3, "stage3_param_persistence_threshold": 1e4, - "offload_param": { - "device": "cpu" - } + "offload_param": {"device": "cpu"}, } accelerator.print(f"{eval_ds_config=}") reward_model, *_ = deepspeed.initialize(model=reward_model, config=eval_ds_config) @@ -826,7 +831,9 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well if accelerator.is_main_process: all_sample_validation_df.to_json(f"runs/{run_name}/table.json") if args.track: - wandb.log({"samples/query_responses": wandb.Table(dataframe=all_sample_validation_df)}, step=update) + wandb.log( + {"samples/query_responses": wandb.Table(dataframe=all_sample_validation_df)}, step=update + ) print_rich_table("stuff", all_sample_validation_df[:4], console) except Exception as e: @@ -903,8 +910,7 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well vf_losses1 = torch.square(vpred - mb_return) vf_losses2 = torch.square(vpredclipped - mb_return) vf_loss_max = torch.max(vf_losses1, vf_losses2) - - + # vf_loss = 0.5 * vf_loss_max.mean() vf_loss = 0.5 * masked_mean(vf_loss_max, ~padding_mask[micro_batch_inds]) vf_clipfrac = masked_mean((vf_losses2 > vf_losses1).float(), ~padding_mask[micro_batch_inds]) @@ -927,7 +933,7 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well approxkl = 0.5 * masked_mean((logprobs_diff**2), ~padding_mask[micro_batch_inds]) # if ppo_epoch_idx == 0 and micro_batch_start == 0: # torch.testing.assert_close(ratio, torch.zeros_like(ratio) + 1, atol=1e-4, rtol=1e-4) - # if ppo_epoch_idx == 0: + # if ppo_epoch_idx == 0: # pprint({ # # "responses": responses, # # "values": values, @@ -951,7 +957,9 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well pg_loss_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = pg_loss vf_loss_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = vf_loss vf_clipfrac_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = vf_clipfrac - entropy_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = masked_mean(entropy, padding_mask[micro_batch_inds]) + entropy_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = masked_mean( + entropy, 
padding_mask[micro_batch_inds] + ) ratio_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = ratio.mean() gradient_accumulation_idx += 1 @@ -961,13 +969,13 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well f"ppo_epoch_idx", ppo_epoch_idx, "approxkl", - approxkl_stats[:ppo_epoch_idx+1].mean().item(), + approxkl_stats[: ppo_epoch_idx + 1].mean().item(), "pg_loss", - pg_loss_stats[:ppo_epoch_idx+1].mean().item(), + pg_loss_stats[: ppo_epoch_idx + 1].mean().item(), "pg_clipfrac", - pg_clipfrac_stats[:ppo_epoch_idx+1].mean().item(), + pg_clipfrac_stats[: ppo_epoch_idx + 1].mean().item(), "ratio", - ratio_stats[:ppo_epoch_idx+1].mean().item(), + ratio_stats[: ppo_epoch_idx + 1].mean().item(), ) # breakpoint() with torch.no_grad(): diff --git a/lm_human_preference_details/train_policy_accelerate_summarize_separate1.py b/lm_human_preference_details/train_policy_accelerate_summarize_separate1.py index 927c6bc..a8d03c4 100644 --- a/lm_human_preference_details/train_policy_accelerate_summarize_separate1.py +++ b/lm_human_preference_details/train_policy_accelerate_summarize_separate1.py @@ -33,7 +33,6 @@ GenerationConfig, ) - INVALID_LOGPROB = 1.0 @@ -142,7 +141,9 @@ class Args: base_model: str = "EleutherAI/pythia-160m" """the name of the pretrained model to use""" - dropout_layer_keys: List[str] = field(default_factory=lambda: ["attn_pdrop", "embd_pdrop", "resid_pdrop", "summary_first_dropout"]) + dropout_layer_keys: List[str] = field( + default_factory=lambda: ["attn_pdrop", "embd_pdrop", "resid_pdrop", "summary_first_dropout"] + ) """Which layers to apply dropout to""" deepspeed: bool = False """Whether to use deepspeed to train the model""" @@ -419,12 +420,13 @@ def get_reward(reward_model, query_responses, tokenizer): return_dict=True, output_hidden_states=True, ) - sequence_lengths = ( - torch.eq(query_responses, tokenizer.pad_token_id).long().argmax(-1) - 1).to( - query_responses.device - ) + sequence_lengths = (torch.eq(query_responses, tokenizer.pad_token_id).long().argmax(-1) - 1).to(query_responses.device) # https://github.com/huggingface/transformers/blob/dc68a39c8111217683bf49a4912d0c9018bab33d/src/transformers/models/gpt2/modeling_gpt2.py#L1454 - return reward_logits, reward_logits[torch.arange(reward_logits.size(0), device=reward_logits.device), sequence_lengths].squeeze(-1), sequence_lengths + return ( + reward_logits, + reward_logits[torch.arange(reward_logits.size(0), device=reward_logits.device), sequence_lengths].squeeze(-1), + sequence_lengths, + ) def forward(policy, query_responses, tokenizer): @@ -463,6 +465,7 @@ def truncate_response(args, tokenizer, responses): def masked_mean(x, mask): return (x.sum(-1) / (~mask).sum(-1)).mean() + # def train(args: Args): if __name__ == "__main__": args = tyro.cli(Args) @@ -640,7 +643,7 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well vf_clipfrac_stats = torch.zeros(stats_shape, device=device) entropy_stats = torch.zeros(stats_shape, device=device) ratio_stats = torch.zeros(stats_shape, device=device) - + model.train() for update in range(1, args.ppo.num_updates + 1): global_step += 1 * args.ppo.batch_size @@ -661,60 +664,550 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well generation_config, ) if args.task.response_length != 53: - query_responses = torch.tensor([[ 209, 209, 209, 209, 209, 209, 209, 209, 209, 209, - 209, 209, 209, 209, 209, 209, 209, 209, 209, 209, - 209, 209, 209, 209, 209, 209, 209, 209, 209, 209, - 209, 209, 209, 209, 209, 
209, 209, 209, 209, 209, - 209, 209, 209, 209, 209, 209, 209, 209, 209, 209, - 209, 209, 209, 209, 209, 209, 209, 209, 209, 209, - 209, 209, 209, 209, 209, 209, 209, 209, 209, 209, - 209, 209, 209, 209, 209, 209, 209, 209, 209, 209, - 209, 209, 209, 209, 209, 209, 209, 209, 209, 209, - 209, 209, 209, 209, 209, 209, 209, 209, 209, 209, - 209, 209, 209, 209, 209, 209, 209, 209, 209, 209, - 209, 209, 209, 209, 209, 209, 209, 209, 209, 209, - 209, 209, 209, 209, 209, 209, 209, 209, 209, 209, - 209, 209, 209, 209, 209, 209, 209, 209, 209, 209, - 209, 209, 209, 209, 209, 209, 6971, 7941, 1703, 37, - 1433, 27, 391, 16, 22842, 16458, 187, 187, 53, 43561, - 27, 3189, 544, 1348, 278, 62, 5816, 619, 806, 385, - 544, 1797, 269, 62, 846, 608, 2607, 273, 2740, 598, - 15, 187, 187, 15743, 27, 24387, 39714, 187, 6300, 15950, - 436, 1501, 562, 627, 816, 281, 1339, 352, 562, 273, - 619, 985, 15, 187, 2598, 309, 452, 644, 13597, 436, - 3226, 313, 2577, 806, 19609, 15, 309, 369, 617, 806, - 10, 323, 495, 1107, 15, 844, 574, 271, 13103, 673, - 285, 4536, 7227, 35267, 285, 37616, 15, 496, 253, 990, - 13, 352, 1904, 626, 789, 562, 15, 187, 42, 3260, - 309, 7636, 617, 285, 703, 7636, 479, 533, 1841, 816, - 1904, 626, 789, 562, 1955, 281, 1097, 4858, 4606, 15, - 187, 187, 2598, 352, 556, 644, 2761, 608, 2607, 15, - 309, 1694, 689, 253, 31056, 673, 273, 619, 1495, 534, - 369, 1501, 2740, 598, 273, 806, 374, 2607, 15, 209, - 187, 4125, 846, 608, 2607, 13, 309, 816, 2985, 617, - 15, 187, 42, 5476, 627, 11210, 626, 644, 247, 2014, - 835, 309, 6468, 626, 1869, 670, 617, 2568, 15, 23385, - 50276, 187, 42, 871, 309, 10095, 626, 3057, 617, 285, - 309, 1353, 3965, 2119, 703, 1912, 626, 3057, 479, 2057, - 534, 310, 323, 253, 1805, 15, 187, 1231, 6468, 626, - 13452, 323, 5046, 374, 2607, 32, 1633, 751, 326, 15, - 187, 43688, 13, 309, 816, 4571, 626, 6016, 352, 10542, - 285, 3261, 387, 776, 7963, 327, 619, 17899, 7963, 534, - 309, 1620, 755, 327, 15, 187, 1147, 369, 5322, 281, - 923, 617, 2454, 969, 285, 30774, 336, 253, 1711, 1897, - 15, 187, 1147, 369, 5322, 281, 923, 253, 9097, 359, - 1097, 2389, 1024, 3811, 342, 617, 2021, 15, 187, 34937, - 512, 608, 2607, 13, 619, 5249, 5055, 598, 15, 309, - 1694, 247, 14892, 209, 187, 36421, 598, 247, 2257, 273, - 2583, 285, 858, 1841, 1475, 253, 2419, 309, 6468, 626, - 644, 2104, 281, 3966, 3966, 15, 187, 1989, 309, 816, - 2985, 617, 15, 187, 42, 871, 703, 434, 2509, 973, - 13, 3164, 1805, 685, 1078, 15, 187, 2513, 352, 816, - 479, 32, 209, 187, 25954, 6701, 323, 634, 673, 4361, - 436, 285, 11435, 634, 5701, 15, 187, 187, 14135, 28, - 4976, 27, 6365, 619, 806, 19609, 13, 9377, 598, 13, - 309, 2985, 617, 533, 1053, 626, 3057, 617, 285, 12371, - 604, 352, 434, 816, 479, 15, 0,]], device=device) + query_responses = torch.tensor( + [ + [ + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 
209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 209, + 6971, + 7941, + 1703, + 37, + 1433, + 27, + 391, + 16, + 22842, + 16458, + 187, + 187, + 53, + 43561, + 27, + 3189, + 544, + 1348, + 278, + 62, + 5816, + 619, + 806, + 385, + 544, + 1797, + 269, + 62, + 846, + 608, + 2607, + 273, + 2740, + 598, + 15, + 187, + 187, + 15743, + 27, + 24387, + 39714, + 187, + 6300, + 15950, + 436, + 1501, + 562, + 627, + 816, + 281, + 1339, + 352, + 562, + 273, + 619, + 985, + 15, + 187, + 2598, + 309, + 452, + 644, + 13597, + 436, + 3226, + 313, + 2577, + 806, + 19609, + 15, + 309, + 369, + 617, + 806, + 10, + 323, + 495, + 1107, + 15, + 844, + 574, + 271, + 13103, + 673, + 285, + 4536, + 7227, + 35267, + 285, + 37616, + 15, + 496, + 253, + 990, + 13, + 352, + 1904, + 626, + 789, + 562, + 15, + 187, + 42, + 3260, + 309, + 7636, + 617, + 285, + 703, + 7636, + 479, + 533, + 1841, + 816, + 1904, + 626, + 789, + 562, + 1955, + 281, + 1097, + 4858, + 4606, + 15, + 187, + 187, + 2598, + 352, + 556, + 644, + 2761, + 608, + 2607, + 15, + 309, + 1694, + 689, + 253, + 31056, + 673, + 273, + 619, + 1495, + 534, + 369, + 1501, + 2740, + 598, + 273, + 806, + 374, + 2607, + 15, + 209, + 187, + 4125, + 846, + 608, + 2607, + 13, + 309, + 816, + 2985, + 617, + 15, + 187, + 42, + 5476, + 627, + 11210, + 626, + 644, + 247, + 2014, + 835, + 309, + 6468, + 626, + 1869, + 670, + 617, + 2568, + 15, + 23385, + 50276, + 187, + 42, + 871, + 309, + 10095, + 626, + 3057, + 617, + 285, + 309, + 1353, + 3965, + 2119, + 703, + 1912, + 626, + 3057, + 479, + 2057, + 534, + 310, + 323, + 253, + 1805, + 15, + 187, + 1231, + 6468, + 626, + 13452, + 323, + 5046, + 374, + 2607, + 32, + 1633, + 751, + 326, + 15, + 187, + 43688, + 13, + 309, + 816, + 4571, + 626, + 6016, + 352, + 10542, + 285, + 3261, + 387, + 776, + 7963, + 327, + 619, + 17899, + 7963, + 534, + 309, + 1620, + 755, + 327, + 15, + 187, + 1147, + 369, + 5322, + 281, + 923, + 617, + 2454, + 969, + 285, + 30774, + 336, + 253, + 1711, + 1897, + 15, + 187, + 1147, + 369, + 5322, + 281, + 923, + 253, + 9097, + 359, + 1097, + 2389, + 1024, + 3811, + 342, + 617, + 2021, + 15, + 187, + 34937, + 512, + 608, + 2607, + 13, + 619, + 5249, + 5055, + 598, + 15, + 309, + 1694, + 247, + 14892, + 209, + 187, + 36421, + 598, + 247, + 2257, + 273, + 2583, + 285, + 858, + 1841, + 1475, + 253, + 2419, + 309, + 6468, + 626, + 644, + 2104, + 281, + 3966, + 3966, + 15, + 187, + 1989, + 309, + 816, + 2985, + 617, + 15, + 187, + 42, + 871, + 703, + 434, + 2509, + 973, + 13, + 3164, + 1805, + 685, + 1078, + 15, + 187, + 2513, + 352, + 816, + 479, + 32, + 209, + 187, + 25954, + 6701, + 323, + 634, + 673, + 4361, + 436, + 285, + 11435, + 634, + 5701, + 15, + 187, + 187, + 14135, + 28, + 4976, + 27, + 6365, + 619, + 806, + 19609, + 13, + 9377, + 598, + 13, + 309, + 2985, + 617, + 533, + 1053, + 626, + 3057, + 617, + 285, + 12371, + 604, + 352, + 434, + 816, + 479, + 15, + 0, + ] + ], + device=device, + ) context_length = queries.shape[1] responses = query_responses[:, context_length:] @@ -778,15 +1271,10 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well # responses not passing that filter will receive a low (fixed) score # only query humans on responses that pass that filter contain_pad_token = torch.any(postprocessed_responses == tokenizer.pad_token_id, dim=-1) - - - + # TODO: 
reverse it back # scores = torch.where(contain_pad_token, scores, torch.full_like(scores, args.task.penalty_reward_value)) - - - - + torch.cuda.empty_cache() # 4. compute rewards @@ -848,16 +1336,9 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well advantages = torch.stack(advantages_reversed[::-1], axis=1) returns = advantages + values - - - # TODO: reverse it back # advantages = whiten(advantages) - - - - return_mean, return_var = returns.mean(), returns.var() value_mean, value_var = values.mean(), values.var() @@ -896,8 +1377,7 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well vf_losses1 = torch.square(vpred - mb_return) vf_losses2 = torch.square(vpredclipped - mb_return) vf_loss_max = torch.max(vf_losses1, vf_losses2) - - + vf_loss = 0.5 * vf_loss_max.mean() # vf_loss = 0.5 * masked_mean(vf_loss_max, padding_mask[micro_batch_inds]) @@ -920,22 +1400,24 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well # approxkl = 0.5 * masked_mean((logprobs_diff**2), padding_mask[micro_batch_inds]) # if ppo_epoch_idx == 0 and micro_batch_start == 0: # torch.testing.assert_close(ratio, torch.zeros_like(ratio) + 1, atol=1e-4, rtol=1e-4) - pprint({ - "responses": responses, - "values": values, - "rewards": rewards, - "scores": scores, - "advantages": advantages, - "ratio": ratio, - "pg_losses": pg_losses, - "approxkl": approxkl, - "pg_loss": pg_loss, - "pg_clipfrac": pg_clipfrac, - "ratio": ratio.mean(), - "vf_loss": vf_loss, - "vf_clipfrac": vf_clipfrac, - "entropy": entropy.mean(), - }) + pprint( + { + "responses": responses, + "values": values, + "rewards": rewards, + "scores": scores, + "advantages": advantages, + "ratio": ratio, + "pg_losses": pg_losses, + "approxkl": approxkl, + "pg_loss": pg_loss, + "pg_clipfrac": pg_clipfrac, + "ratio": ratio.mean(), + "vf_loss": vf_loss, + "vf_clipfrac": vf_clipfrac, + "entropy": entropy.mean(), + } + ) with torch.no_grad(): approxkl_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = approxkl pg_clipfrac_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = pg_clipfrac @@ -946,20 +1428,20 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well ratio_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = ratio.mean() gradient_accumulation_idx += 1 raise - # minibatch_idx += 1 - # if accelerator.is_main_process: - # console.print( - # f"ppo_epoch_idx", - # ppo_epoch_idx, - # "approxkl", - # approxkl_stats[:ppo_epoch_idx+1].mean().item(), - # "pg_loss", - # pg_loss_stats[:ppo_epoch_idx+1].mean().item(), - # "pg_clipfrac", - # pg_clipfrac_stats[:ppo_epoch_idx+1].mean().item(), - # "ratio", - # ratio_stats[:ppo_epoch_idx+1].mean().item(), - # ) + # minibatch_idx += 1 + # if accelerator.is_main_process: + # console.print( + # f"ppo_epoch_idx", + # ppo_epoch_idx, + # "approxkl", + # approxkl_stats[:ppo_epoch_idx+1].mean().item(), + # "pg_loss", + # pg_loss_stats[:ppo_epoch_idx+1].mean().item(), + # "pg_clipfrac", + # pg_clipfrac_stats[:ppo_epoch_idx+1].mean().item(), + # "ratio", + # ratio_stats[:ppo_epoch_idx+1].mean().item(), + # ) with torch.no_grad(): if not args.deepspeed: # for some reason there is a OOM with the `writer.add_histogram` diff --git a/lm_human_preference_details/train_policy_accelerate_summarize_separate3.py b/lm_human_preference_details/train_policy_accelerate_summarize_separate3.py index b5b6eef..0bf5f2e 100644 --- a/lm_human_preference_details/train_policy_accelerate_summarize_separate3.py +++ 
b/lm_human_preference_details/train_policy_accelerate_summarize_separate3.py @@ -33,7 +33,6 @@ GenerationConfig, ) - INVALID_LOGPROB = 1.0 @@ -49,8 +48,8 @@ class RewardHParams: adaptive_kl: Optional[AdaptiveKLParams] = field(default_factory=AdaptiveKLParams) trained_model: Optional[str] = "" label_dataset: tyro.conf.Suppress[Optional[str]] = None - dataset_mean: float = 0. - dataset_std: float = 1. + dataset_mean: float = 0.0 + dataset_std: float = 1.0 kl_coef: float = 0.15 @@ -144,7 +143,9 @@ class Args: base_model: str = "EleutherAI/pythia-160m" """the name of the pretrained model to use""" - dropout_layer_keys: List[str] = field(default_factory=lambda: ["attn_pdrop", "embd_pdrop", "resid_pdrop", "summary_first_dropout"]) + dropout_layer_keys: List[str] = field( + default_factory=lambda: ["attn_pdrop", "embd_pdrop", "resid_pdrop", "summary_first_dropout"] + ) """Which layers to apply dropout to""" deepspeed: bool = False """Whether to use deepspeed to train the model""" @@ -449,7 +450,11 @@ def get_reward(reward_model, query_responses, tokenizer): # ) # print(f"======={sequence_lengths1=} {sequence_lengths=}") # https://github.com/huggingface/transformers/blob/dc68a39c8111217683bf49a4912d0c9018bab33d/src/transformers/models/gpt2/modeling_gpt2.py#L1454 - return reward_logits, reward_logits[torch.arange(reward_logits.size(0), device=reward_logits.device), sequence_lengths].squeeze(-1), sequence_lengths + return ( + reward_logits, + reward_logits[torch.arange(reward_logits.size(0), device=reward_logits.device), sequence_lengths].squeeze(-1), + sequence_lengths, + ) def forward(policy, query_responses, tokenizer): @@ -595,9 +600,7 @@ def forward(policy, query_responses, tokenizer): eval_ds_config["zero_optimization"] = { "stage": 3, "stage3_param_persistence_threshold": 1e4, - "offload_param": { - "device": "cpu" - } + "offload_param": {"device": "cpu"}, } accelerator.print(f"{eval_ds_config=}") reward_model, *_ = deepspeed.initialize(model=reward_model, config=eval_ds_config) @@ -687,7 +690,7 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well # TODO: do I do this with query response or post-processed query response? output = forward(accelerator.unwrap_model(model).policy, query_responses, tokenizer) logits = output.logits[:, context_length - 1 : -1] - logits /= (args.task.temperature + 1e-7) + logits /= args.task.temperature + 1e-7 all_logprobs = F.log_softmax(logits, dim=-1) logprobs = torch.gather(all_logprobs, 2, responses.unsqueeze(-1)).squeeze(-1) del output, logits, all_logprobs @@ -695,7 +698,7 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well ref_output = forward(ref_policy, query_responses, tokenizer) ref_logits = ref_output.logits[:, context_length - 1 : -1] - ref_logits /= (args.task.temperature + 1e-7) + ref_logits /= args.task.temperature + 1e-7 ref_all_logprobs = F.log_softmax(ref_logits, dim=-1) ref_logprobs = torch.gather(ref_all_logprobs, 2, responses.unsqueeze(-1)).squeeze(-1) del ref_output, ref_logits, ref_all_logprobs @@ -729,7 +732,7 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well # only query humans on responses that pass that filter contain_pad_token = torch.any(postprocessed_responses == tokenizer.pad_token_id, dim=-1) scores = torch.where(contain_pad_token, scores, torch.full_like(scores, args.task.penalty_reward_value)) - + # TODO: do we need to deal with penalty values? 
# penalty_values = torch.full_like(values, 0) # penalty_values[:,-1] += args.task.penalty_reward_value @@ -773,7 +776,9 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well if accelerator.is_main_process: all_sample_validation_df.to_json(f"runs/{run_name}/table.json") if args.track: - wandb.log({"samples/query_responses": wandb.Table(dataframe=all_sample_validation_df)}, step=update) + wandb.log( + {"samples/query_responses": wandb.Table(dataframe=all_sample_validation_df)}, step=update + ) print_rich_table("stuff", all_sample_validation_df[:4], console) except Exception as e: @@ -833,7 +838,7 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well output, vpred_temp = forward(model, mb_query_responses, tokenizer) logits = output.logits[:, context_length - 1 : -1] - logits /= (args.task.temperature + 1e-7) + logits /= args.task.temperature + 1e-7 new_all_logprobs = F.log_softmax(logits, dim=-1) new_logprobs = torch.gather(new_all_logprobs, 2, mb_responses.unsqueeze(-1)).squeeze(-1) # new_logprobs = torch.masked_fill(new_logprobs, padding_mask[micro_batch_inds], INVALID_LOGPROB) @@ -864,7 +869,7 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well approxkl = 0.5 * (logprobs_diff**2).mean() # if ppo_epoch_idx == 0 and micro_batch_start == 0: # torch.testing.assert_close(ratio, torch.zeros_like(ratio) + 1, atol=1e-4, rtol=1e-4) - # if ppo_epoch_idx == 0: + # if ppo_epoch_idx == 0: # pprint({ # # "responses": responses, # # "values": values, @@ -897,13 +902,13 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well f"ppo_epoch_idx", ppo_epoch_idx, "approxkl", - approxkl_stats[:ppo_epoch_idx+1].mean().item(), + approxkl_stats[: ppo_epoch_idx + 1].mean().item(), "pg_loss", - pg_loss_stats[:ppo_epoch_idx+1].mean().item(), + pg_loss_stats[: ppo_epoch_idx + 1].mean().item(), "pg_clipfrac", - pg_clipfrac_stats[:ppo_epoch_idx+1].mean().item(), + pg_clipfrac_stats[: ppo_epoch_idx + 1].mean().item(), "ratio", - ratio_stats[:ppo_epoch_idx+1].mean().item(), + ratio_stats[: ppo_epoch_idx + 1].mean().item(), ) # raise # breakpoint() diff --git a/lm_human_preference_details/train_policy_accelerate_summarize_separate4.py b/lm_human_preference_details/train_policy_accelerate_summarize_separate4.py index 49d023a..ee18d2f 100644 --- a/lm_human_preference_details/train_policy_accelerate_summarize_separate4.py +++ b/lm_human_preference_details/train_policy_accelerate_summarize_separate4.py @@ -33,7 +33,6 @@ GenerationConfig, ) - INVALID_LOGPROB = 1.0 @@ -49,8 +48,8 @@ class RewardHParams: adaptive_kl: Optional[AdaptiveKLParams] = field(default_factory=AdaptiveKLParams) trained_model: Optional[str] = "" label_dataset: tyro.conf.Suppress[Optional[str]] = None - dataset_mean: float = 0. - dataset_std: float = 1. 
+ dataset_mean: float = 0.0 + dataset_std: float = 1.0 kl_coef: float = 0.15 @@ -144,7 +143,9 @@ class Args: base_model: str = "EleutherAI/pythia-160m" """the name of the pretrained model to use""" - dropout_layer_keys: List[str] = field(default_factory=lambda: ["attn_pdrop", "embd_pdrop", "resid_pdrop", "summary_first_dropout"]) + dropout_layer_keys: List[str] = field( + default_factory=lambda: ["attn_pdrop", "embd_pdrop", "resid_pdrop", "summary_first_dropout"] + ) """Which layers to apply dropout to""" deepspeed: bool = False """Whether to use deepspeed to train the model""" @@ -449,7 +450,11 @@ def get_reward(reward_model, query_responses, tokenizer): # ) # print(f"======={sequence_lengths1=} {sequence_lengths=}") # https://github.com/huggingface/transformers/blob/dc68a39c8111217683bf49a4912d0c9018bab33d/src/transformers/models/gpt2/modeling_gpt2.py#L1454 - return reward_logits, reward_logits[torch.arange(reward_logits.size(0), device=reward_logits.device), sequence_lengths].squeeze(-1), sequence_lengths + return ( + reward_logits, + reward_logits[torch.arange(reward_logits.size(0), device=reward_logits.device), sequence_lengths].squeeze(-1), + sequence_lengths, + ) def forward(policy, query_responses, tokenizer): @@ -595,9 +600,7 @@ def forward(policy, query_responses, tokenizer): eval_ds_config["zero_optimization"] = { "stage": 3, "stage3_param_persistence_threshold": 1e4, - "offload_param": { - "device": "cpu" - } + "offload_param": {"device": "cpu"}, } accelerator.print(f"{eval_ds_config=}") reward_model, *_ = deepspeed.initialize(model=reward_model, config=eval_ds_config) @@ -642,7 +645,7 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well do_sample=True, ) # use the same `0.01` temperature for validation response generation https://github.com/openai/summarize-from-feedback/blob/700967448d10004279f138666442bf1497d0e705/exps/sample.py#L27 - validation_generation_config= GenerationConfig( + validation_generation_config = GenerationConfig( max_new_tokens=args.task.response_length, min_new_tokens=args.task.response_length, temperature=(0.01 + 1e-7), @@ -696,7 +699,7 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well # TODO: do I do this with query response or post-processed query response? output = forward(accelerator.unwrap_model(model).policy, query_responses, tokenizer) logits = output.logits[:, context_length - 1 : -1] - logits /= (args.task.temperature + 1e-7) + logits /= args.task.temperature + 1e-7 all_logprobs = F.log_softmax(logits, dim=-1) logprobs = torch.gather(all_logprobs, 2, responses.unsqueeze(-1)).squeeze(-1) del output, logits, all_logprobs @@ -704,7 +707,7 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well ref_output = forward(ref_policy, query_responses, tokenizer) ref_logits = ref_output.logits[:, context_length - 1 : -1] - ref_logits /= (args.task.temperature + 1e-7) + ref_logits /= args.task.temperature + 1e-7 ref_all_logprobs = F.log_softmax(ref_logits, dim=-1) ref_logprobs = torch.gather(ref_all_logprobs, 2, responses.unsqueeze(-1)).squeeze(-1) del ref_output, ref_logits, ref_all_logprobs @@ -738,7 +741,7 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well # only query humans on responses that pass that filter contain_pad_token = torch.any(postprocessed_responses == tokenizer.pad_token_id, dim=-1) scores = torch.where(contain_pad_token, scores, torch.full_like(scores, args.task.penalty_reward_value)) - + # TODO: do we need to deal with penalty values? 
# penalty_values = torch.full_like(values, 0) # penalty_values[:,-1] += args.task.penalty_reward_value @@ -782,7 +785,9 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well if accelerator.is_main_process: all_sample_validation_df.to_json(f"runs/{run_name}/table.json") if args.track: - wandb.log({"samples/query_responses": wandb.Table(dataframe=all_sample_validation_df)}, step=update) + wandb.log( + {"samples/query_responses": wandb.Table(dataframe=all_sample_validation_df)}, step=update + ) print_rich_table("stuff", all_sample_validation_df[:4], console) except Exception as e: @@ -842,7 +847,7 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well output, vpred_temp = forward(model, mb_query_responses, tokenizer) logits = output.logits[:, context_length - 1 : -1] - logits /= (args.task.temperature + 1e-7) + logits /= args.task.temperature + 1e-7 new_all_logprobs = F.log_softmax(logits, dim=-1) new_logprobs = torch.gather(new_all_logprobs, 2, mb_responses.unsqueeze(-1)).squeeze(-1) # new_logprobs = torch.masked_fill(new_logprobs, padding_mask[micro_batch_inds], INVALID_LOGPROB) @@ -873,7 +878,7 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well approxkl = 0.5 * (logprobs_diff**2).mean() # if ppo_epoch_idx == 0 and micro_batch_start == 0: # torch.testing.assert_close(ratio, torch.zeros_like(ratio) + 1, atol=1e-4, rtol=1e-4) - # if ppo_epoch_idx == 0: + # if ppo_epoch_idx == 0: # pprint({ # # "responses": responses, # # "values": values, @@ -906,13 +911,13 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well f"ppo_epoch_idx", ppo_epoch_idx, "approxkl", - approxkl_stats[:ppo_epoch_idx+1].mean().item(), + approxkl_stats[: ppo_epoch_idx + 1].mean().item(), "pg_loss", - pg_loss_stats[:ppo_epoch_idx+1].mean().item(), + pg_loss_stats[: ppo_epoch_idx + 1].mean().item(), "pg_clipfrac", - pg_clipfrac_stats[:ppo_epoch_idx+1].mean().item(), + pg_clipfrac_stats[: ppo_epoch_idx + 1].mean().item(), "ratio", - ratio_stats[:ppo_epoch_idx+1].mean().item(), + ratio_stats[: ppo_epoch_idx + 1].mean().item(), ) # raise # breakpoint() diff --git a/lm_human_preference_details/train_policy_accelerate_summarize_separate5_load_critic.py b/lm_human_preference_details/train_policy_accelerate_summarize_separate5_load_critic.py index ea30982..ae59f50 100644 --- a/lm_human_preference_details/train_policy_accelerate_summarize_separate5_load_critic.py +++ b/lm_human_preference_details/train_policy_accelerate_summarize_separate5_load_critic.py @@ -33,7 +33,6 @@ GenerationConfig, ) - INVALID_LOGPROB = 1.0 @@ -49,8 +48,8 @@ class RewardHParams: adaptive_kl: Optional[AdaptiveKLParams] = field(default_factory=AdaptiveKLParams) trained_model: Optional[str] = "" label_dataset: tyro.conf.Suppress[Optional[str]] = None - dataset_mean: float = 0. - dataset_std: float = 1. 
+ dataset_mean: float = 0.0 + dataset_std: float = 1.0 kl_coef: float = 0.15 @@ -144,7 +143,9 @@ class Args: base_model: str = "EleutherAI/pythia-160m" """the name of the pretrained model to use""" - dropout_layer_keys: List[str] = field(default_factory=lambda: ["attn_pdrop", "embd_pdrop", "resid_pdrop", "summary_first_dropout"]) + dropout_layer_keys: List[str] = field( + default_factory=lambda: ["attn_pdrop", "embd_pdrop", "resid_pdrop", "summary_first_dropout"] + ) """Which layers to apply dropout to""" deepspeed: bool = False """Whether to use deepspeed to train the model""" @@ -449,7 +450,11 @@ def get_reward(reward_model, query_responses, tokenizer): # ) # print(f"======={sequence_lengths1=} {sequence_lengths=}") # https://github.com/huggingface/transformers/blob/dc68a39c8111217683bf49a4912d0c9018bab33d/src/transformers/models/gpt2/modeling_gpt2.py#L1454 - return reward_logits, reward_logits[torch.arange(reward_logits.size(0), device=reward_logits.device), sequence_lengths].squeeze(-1), sequence_lengths + return ( + reward_logits, + reward_logits[torch.arange(reward_logits.size(0), device=reward_logits.device), sequence_lengths].squeeze(-1), + sequence_lengths, + ) def forward(policy, query_responses, tokenizer): @@ -594,9 +599,7 @@ def forward(policy, query_responses, tokenizer): eval_ds_config["zero_optimization"] = { "stage": 3, "stage3_param_persistence_threshold": 1e4, - "offload_param": { - "device": "cpu" - } + "offload_param": {"device": "cpu"}, } accelerator.print(f"{eval_ds_config=}") reward_model, *_ = deepspeed.initialize(model=reward_model, config=eval_ds_config) @@ -641,7 +644,7 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well do_sample=True, ) # use the same `0.01` temperature for validation response generation https://github.com/openai/summarize-from-feedback/blob/700967448d10004279f138666442bf1497d0e705/exps/sample.py#L27 - validation_generation_config= GenerationConfig( + validation_generation_config = GenerationConfig( max_new_tokens=args.task.response_length, min_new_tokens=args.task.response_length, temperature=(0.01 + 1e-7), @@ -695,7 +698,7 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well # TODO: do I do this with query response or post-processed query response? output = forward(accelerator.unwrap_model(model).policy, query_responses, tokenizer) logits = output.logits[:, context_length - 1 : -1] - logits /= (args.task.temperature + 1e-7) + logits /= args.task.temperature + 1e-7 all_logprobs = F.log_softmax(logits, dim=-1) logprobs = torch.gather(all_logprobs, 2, responses.unsqueeze(-1)).squeeze(-1) del output, logits, all_logprobs @@ -703,7 +706,7 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well ref_output = forward(ref_policy, query_responses, tokenizer) ref_logits = ref_output.logits[:, context_length - 1 : -1] - ref_logits /= (args.task.temperature + 1e-7) + ref_logits /= args.task.temperature + 1e-7 ref_all_logprobs = F.log_softmax(ref_logits, dim=-1) ref_logprobs = torch.gather(ref_all_logprobs, 2, responses.unsqueeze(-1)).squeeze(-1) del ref_output, ref_logits, ref_all_logprobs @@ -737,7 +740,7 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well # only query humans on responses that pass that filter contain_pad_token = torch.any(postprocessed_responses == tokenizer.pad_token_id, dim=-1) scores = torch.where(contain_pad_token, scores, torch.full_like(scores, args.task.penalty_reward_value)) - + # TODO: do we need to deal with penalty values? 
# penalty_values = torch.full_like(values, 0) # penalty_values[:,-1] += args.task.penalty_reward_value @@ -781,7 +784,9 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well if accelerator.is_main_process: all_sample_validation_df.to_json(f"runs/{run_name}/table.json") if args.track: - wandb.log({"samples/query_responses": wandb.Table(dataframe=all_sample_validation_df)}, step=update) + wandb.log( + {"samples/query_responses": wandb.Table(dataframe=all_sample_validation_df)}, step=update + ) print_rich_table("stuff", all_sample_validation_df[:4], console) except Exception as e: @@ -841,7 +846,7 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well output, vpred_temp = forward(model, mb_query_responses, tokenizer) logits = output.logits[:, context_length - 1 : -1] - logits /= (args.task.temperature + 1e-7) + logits /= args.task.temperature + 1e-7 new_all_logprobs = F.log_softmax(logits, dim=-1) new_logprobs = torch.gather(new_all_logprobs, 2, mb_responses.unsqueeze(-1)).squeeze(-1) # new_logprobs = torch.masked_fill(new_logprobs, padding_mask[micro_batch_inds], INVALID_LOGPROB) @@ -872,7 +877,7 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well approxkl = 0.5 * (logprobs_diff**2).mean() # if ppo_epoch_idx == 0 and micro_batch_start == 0: # torch.testing.assert_close(ratio, torch.zeros_like(ratio) + 1, atol=1e-4, rtol=1e-4) - # if ppo_epoch_idx == 0: + # if ppo_epoch_idx == 0: # pprint({ # # "responses": responses, # # "values": values, @@ -905,13 +910,13 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well f"ppo_epoch_idx", ppo_epoch_idx, "approxkl", - approxkl_stats[:ppo_epoch_idx+1].mean().item(), + approxkl_stats[: ppo_epoch_idx + 1].mean().item(), "pg_loss", - pg_loss_stats[:ppo_epoch_idx+1].mean().item(), + pg_loss_stats[: ppo_epoch_idx + 1].mean().item(), "pg_clipfrac", - pg_clipfrac_stats[:ppo_epoch_idx+1].mean().item(), + pg_clipfrac_stats[: ppo_epoch_idx + 1].mean().item(), "ratio", - ratio_stats[:ppo_epoch_idx+1].mean().item(), + ratio_stats[: ppo_epoch_idx + 1].mean().item(), ) # raise # breakpoint() diff --git a/lm_human_preference_details/train_policy_accelerate_summarize_separate6_correct_reward_index.py b/lm_human_preference_details/train_policy_accelerate_summarize_separate6_correct_reward_index.py index 7331a66..cc7fef7 100644 --- a/lm_human_preference_details/train_policy_accelerate_summarize_separate6_correct_reward_index.py +++ b/lm_human_preference_details/train_policy_accelerate_summarize_separate6_correct_reward_index.py @@ -33,7 +33,6 @@ GenerationConfig, ) - INVALID_LOGPROB = 1.0 @@ -49,8 +48,8 @@ class RewardHParams: adaptive_kl: Optional[AdaptiveKLParams] = field(default_factory=AdaptiveKLParams) trained_model: Optional[str] = "" label_dataset: tyro.conf.Suppress[Optional[str]] = None - dataset_mean: float = 0. - dataset_std: float = 1. 
+ dataset_mean: float = 0.0 + dataset_std: float = 1.0 kl_coef: float = 0.15 @@ -144,7 +143,9 @@ class Args: base_model: str = "EleutherAI/pythia-160m" """the name of the pretrained model to use""" - dropout_layer_keys: List[str] = field(default_factory=lambda: ["attn_pdrop", "embd_pdrop", "resid_pdrop", "summary_first_dropout"]) + dropout_layer_keys: List[str] = field( + default_factory=lambda: ["attn_pdrop", "embd_pdrop", "resid_pdrop", "summary_first_dropout"] + ) """Which layers to apply dropout to""" deepspeed: bool = False """Whether to use deepspeed to train the model""" @@ -449,7 +450,11 @@ def get_reward(reward_model, query_responses, tokenizer): # ) # print(f"======={sequence_lengths1=} {sequence_lengths=}") # https://github.com/huggingface/transformers/blob/dc68a39c8111217683bf49a4912d0c9018bab33d/src/transformers/models/gpt2/modeling_gpt2.py#L1454 - return reward_logits, reward_logits[torch.arange(reward_logits.size(0), device=reward_logits.device), sequence_lengths].squeeze(-1), sequence_lengths + return ( + reward_logits, + reward_logits[torch.arange(reward_logits.size(0), device=reward_logits.device), sequence_lengths].squeeze(-1), + sequence_lengths, + ) def forward(policy, query_responses, tokenizer): @@ -594,9 +599,7 @@ def forward(policy, query_responses, tokenizer): eval_ds_config["zero_optimization"] = { "stage": 3, "stage3_param_persistence_threshold": 1e4, - "offload_param": { - "device": "cpu" - } + "offload_param": {"device": "cpu"}, } accelerator.print(f"{eval_ds_config=}") reward_model, *_ = deepspeed.initialize(model=reward_model, config=eval_ds_config) @@ -641,7 +644,7 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well do_sample=True, ) # use the same `0.01` temperature for validation response generation https://github.com/openai/summarize-from-feedback/blob/700967448d10004279f138666442bf1497d0e705/exps/sample.py#L27 - validation_generation_config= GenerationConfig( + validation_generation_config = GenerationConfig( max_new_tokens=args.task.response_length, min_new_tokens=args.task.response_length, temperature=(0.01 + 1e-7), @@ -695,7 +698,7 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well # TODO: do I do this with query response or post-processed query response? output = forward(accelerator.unwrap_model(model).policy, query_responses, tokenizer) logits = output.logits[:, context_length - 1 : -1] - logits /= (args.task.temperature + 1e-7) + logits /= args.task.temperature + 1e-7 all_logprobs = F.log_softmax(logits, dim=-1) logprobs = torch.gather(all_logprobs, 2, responses.unsqueeze(-1)).squeeze(-1) del output, logits, all_logprobs @@ -703,7 +706,7 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well ref_output = forward(ref_policy, query_responses, tokenizer) ref_logits = ref_output.logits[:, context_length - 1 : -1] - ref_logits /= (args.task.temperature + 1e-7) + ref_logits /= args.task.temperature + 1e-7 ref_all_logprobs = F.log_softmax(ref_logits, dim=-1) ref_logprobs = torch.gather(ref_all_logprobs, 2, responses.unsqueeze(-1)).squeeze(-1) del ref_output, ref_logits, ref_all_logprobs @@ -738,7 +741,7 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well # only query humans on responses that pass that filter contain_pad_token = torch.any(postprocessed_responses == tokenizer.pad_token_id, dim=-1) scores = torch.where(contain_pad_token, scores, torch.full_like(scores, args.task.penalty_reward_value)) - + # TODO: do we need to deal with penalty values? 
# penalty_values = torch.full_like(values, 0) # penalty_values[:,-1] += args.task.penalty_reward_value @@ -784,7 +787,9 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well if accelerator.is_main_process: all_sample_validation_df.to_json(f"runs/{run_name}/table.json") if args.track: - wandb.log({"samples/query_responses": wandb.Table(dataframe=all_sample_validation_df)}, step=update) + wandb.log( + {"samples/query_responses": wandb.Table(dataframe=all_sample_validation_df)}, step=update + ) print_rich_table("stuff", all_sample_validation_df[:4], console) except Exception as e: @@ -844,7 +849,7 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well output, vpred_temp = forward(model, mb_query_responses, tokenizer) logits = output.logits[:, context_length - 1 : -1] - logits /= (args.task.temperature + 1e-7) + logits /= args.task.temperature + 1e-7 new_all_logprobs = F.log_softmax(logits, dim=-1) new_logprobs = torch.gather(new_all_logprobs, 2, mb_responses.unsqueeze(-1)).squeeze(-1) # new_logprobs = torch.masked_fill(new_logprobs, padding_mask[micro_batch_inds], INVALID_LOGPROB) @@ -875,7 +880,7 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well approxkl = 0.5 * (logprobs_diff**2).mean() # if ppo_epoch_idx == 0 and micro_batch_start == 0: # torch.testing.assert_close(ratio, torch.zeros_like(ratio) + 1, atol=1e-4, rtol=1e-4) - # if ppo_epoch_idx == 0: + # if ppo_epoch_idx == 0: # pprint({ # # "responses": responses, # # "values": values, @@ -908,13 +913,13 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well f"ppo_epoch_idx", ppo_epoch_idx, "approxkl", - approxkl_stats[:ppo_epoch_idx+1].mean().item(), + approxkl_stats[: ppo_epoch_idx + 1].mean().item(), "pg_loss", - pg_loss_stats[:ppo_epoch_idx+1].mean().item(), + pg_loss_stats[: ppo_epoch_idx + 1].mean().item(), "pg_clipfrac", - pg_clipfrac_stats[:ppo_epoch_idx+1].mean().item(), + pg_clipfrac_stats[: ppo_epoch_idx + 1].mean().item(), "ratio", - ratio_stats[:ppo_epoch_idx+1].mean().item(), + ratio_stats[: ppo_epoch_idx + 1].mean().item(), ) # raise # breakpoint() diff --git a/lm_human_preference_details/train_policy_accelerate_summarize_separate7_correct_reward_index_no_load_critic.py b/lm_human_preference_details/train_policy_accelerate_summarize_separate7_correct_reward_index_no_load_critic.py index a316656..41622b5 100644 --- a/lm_human_preference_details/train_policy_accelerate_summarize_separate7_correct_reward_index_no_load_critic.py +++ b/lm_human_preference_details/train_policy_accelerate_summarize_separate7_correct_reward_index_no_load_critic.py @@ -33,7 +33,6 @@ GenerationConfig, ) - INVALID_LOGPROB = 1.0 @@ -49,8 +48,8 @@ class RewardHParams: adaptive_kl: Optional[AdaptiveKLParams] = field(default_factory=AdaptiveKLParams) trained_model: Optional[str] = "" label_dataset: tyro.conf.Suppress[Optional[str]] = None - dataset_mean: float = 0. - dataset_std: float = 1. 
+ dataset_mean: float = 0.0 + dataset_std: float = 1.0 kl_coef: float = 0.15 @@ -144,7 +143,9 @@ class Args: base_model: str = "EleutherAI/pythia-160m" """the name of the pretrained model to use""" - dropout_layer_keys: List[str] = field(default_factory=lambda: ["attn_pdrop", "embd_pdrop", "resid_pdrop", "summary_first_dropout"]) + dropout_layer_keys: List[str] = field( + default_factory=lambda: ["attn_pdrop", "embd_pdrop", "resid_pdrop", "summary_first_dropout"] + ) """Which layers to apply dropout to""" deepspeed: bool = False """Whether to use deepspeed to train the model""" @@ -449,7 +450,11 @@ def get_reward(reward_model, query_responses, tokenizer): # ) # print(f"======={sequence_lengths1=} {sequence_lengths=}") # https://github.com/huggingface/transformers/blob/dc68a39c8111217683bf49a4912d0c9018bab33d/src/transformers/models/gpt2/modeling_gpt2.py#L1454 - return reward_logits, reward_logits[torch.arange(reward_logits.size(0), device=reward_logits.device), sequence_lengths].squeeze(-1), sequence_lengths + return ( + reward_logits, + reward_logits[torch.arange(reward_logits.size(0), device=reward_logits.device), sequence_lengths].squeeze(-1), + sequence_lengths, + ) def forward(policy, query_responses, tokenizer): @@ -594,9 +599,7 @@ def forward(policy, query_responses, tokenizer): eval_ds_config["zero_optimization"] = { "stage": 3, "stage3_param_persistence_threshold": 1e4, - "offload_param": { - "device": "cpu" - } + "offload_param": {"device": "cpu"}, } accelerator.print(f"{eval_ds_config=}") reward_model, *_ = deepspeed.initialize(model=reward_model, config=eval_ds_config) @@ -641,7 +644,7 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well do_sample=True, ) # use the same `0.01` temperature for validation response generation https://github.com/openai/summarize-from-feedback/blob/700967448d10004279f138666442bf1497d0e705/exps/sample.py#L27 - validation_generation_config= GenerationConfig( + validation_generation_config = GenerationConfig( max_new_tokens=args.task.response_length, min_new_tokens=args.task.response_length, temperature=(0.01 + 1e-7), @@ -695,7 +698,7 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well # TODO: do I do this with query response or post-processed query response? output = forward(accelerator.unwrap_model(model).policy, query_responses, tokenizer) logits = output.logits[:, context_length - 1 : -1] - logits /= (args.task.temperature + 1e-7) + logits /= args.task.temperature + 1e-7 all_logprobs = F.log_softmax(logits, dim=-1) logprobs = torch.gather(all_logprobs, 2, responses.unsqueeze(-1)).squeeze(-1) del output, logits, all_logprobs @@ -703,7 +706,7 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well ref_output = forward(ref_policy, query_responses, tokenizer) ref_logits = ref_output.logits[:, context_length - 1 : -1] - ref_logits /= (args.task.temperature + 1e-7) + ref_logits /= args.task.temperature + 1e-7 ref_all_logprobs = F.log_softmax(ref_logits, dim=-1) ref_logprobs = torch.gather(ref_all_logprobs, 2, responses.unsqueeze(-1)).squeeze(-1) del ref_output, ref_logits, ref_all_logprobs @@ -738,7 +741,7 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well # only query humans on responses that pass that filter contain_pad_token = torch.any(postprocessed_responses == tokenizer.pad_token_id, dim=-1) scores = torch.where(contain_pad_token, scores, torch.full_like(scores, args.task.penalty_reward_value)) - + # TODO: do we need to deal with penalty values? 
# penalty_values = torch.full_like(values, 0) # penalty_values[:,-1] += args.task.penalty_reward_value @@ -784,7 +787,9 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well if accelerator.is_main_process: all_sample_validation_df.to_json(f"runs/{run_name}/table.json") if args.track: - wandb.log({"samples/query_responses": wandb.Table(dataframe=all_sample_validation_df)}, step=update) + wandb.log( + {"samples/query_responses": wandb.Table(dataframe=all_sample_validation_df)}, step=update + ) print_rich_table("stuff", all_sample_validation_df[:4], console) except Exception as e: @@ -844,7 +849,7 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well output, vpred_temp = forward(model, mb_query_responses, tokenizer) logits = output.logits[:, context_length - 1 : -1] - logits /= (args.task.temperature + 1e-7) + logits /= args.task.temperature + 1e-7 new_all_logprobs = F.log_softmax(logits, dim=-1) new_logprobs = torch.gather(new_all_logprobs, 2, mb_responses.unsqueeze(-1)).squeeze(-1) # new_logprobs = torch.masked_fill(new_logprobs, padding_mask[micro_batch_inds], INVALID_LOGPROB) @@ -875,7 +880,7 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well approxkl = 0.5 * (logprobs_diff**2).mean() # if ppo_epoch_idx == 0 and micro_batch_start == 0: # torch.testing.assert_close(ratio, torch.zeros_like(ratio) + 1, atol=1e-4, rtol=1e-4) - # if ppo_epoch_idx == 0: + # if ppo_epoch_idx == 0: # pprint({ # # "responses": responses, # # "values": values, @@ -908,13 +913,13 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well f"ppo_epoch_idx", ppo_epoch_idx, "approxkl", - approxkl_stats[:ppo_epoch_idx+1].mean().item(), + approxkl_stats[: ppo_epoch_idx + 1].mean().item(), "pg_loss", - pg_loss_stats[:ppo_epoch_idx+1].mean().item(), + pg_loss_stats[: ppo_epoch_idx + 1].mean().item(), "pg_clipfrac", - pg_clipfrac_stats[:ppo_epoch_idx+1].mean().item(), + pg_clipfrac_stats[: ppo_epoch_idx + 1].mean().item(), "ratio", - ratio_stats[:ppo_epoch_idx+1].mean().item(), + ratio_stats[: ppo_epoch_idx + 1].mean().item(), ) # raise # breakpoint() diff --git a/lm_human_preference_details/train_policy_accelerate_summarize_separate8_correct_reward_index_deepspeed3.py b/lm_human_preference_details/train_policy_accelerate_summarize_separate8_correct_reward_index_deepspeed3.py index a1e860a..bb84138 100644 --- a/lm_human_preference_details/train_policy_accelerate_summarize_separate8_correct_reward_index_deepspeed3.py +++ b/lm_human_preference_details/train_policy_accelerate_summarize_separate8_correct_reward_index_deepspeed3.py @@ -33,7 +33,6 @@ GenerationConfig, ) - INVALID_LOGPROB = 1.0 @@ -49,8 +48,8 @@ class RewardHParams: adaptive_kl: Optional[AdaptiveKLParams] = field(default_factory=AdaptiveKLParams) trained_model: Optional[str] = "" label_dataset: tyro.conf.Suppress[Optional[str]] = None - dataset_mean: float = 0. - dataset_std: float = 1. 
+ dataset_mean: float = 0.0 + dataset_std: float = 1.0 kl_coef: float = 0.15 @@ -144,7 +143,9 @@ class Args: base_model: str = "EleutherAI/pythia-160m" """the name of the pretrained model to use""" - dropout_layer_keys: List[str] = field(default_factory=lambda: ["attn_pdrop", "embd_pdrop", "resid_pdrop", "summary_first_dropout"]) + dropout_layer_keys: List[str] = field( + default_factory=lambda: ["attn_pdrop", "embd_pdrop", "resid_pdrop", "summary_first_dropout"] + ) """Which layers to apply dropout to""" deepspeed: bool = False """Whether to use deepspeed to train the model""" @@ -449,7 +450,11 @@ def get_reward(reward_model, query_responses, tokenizer): # ) # print(f"======={sequence_lengths1=} {sequence_lengths=}") # https://github.com/huggingface/transformers/blob/dc68a39c8111217683bf49a4912d0c9018bab33d/src/transformers/models/gpt2/modeling_gpt2.py#L1454 - return reward_logits, reward_logits[torch.arange(reward_logits.size(0), device=reward_logits.device), sequence_lengths].squeeze(-1), sequence_lengths + return ( + reward_logits, + reward_logits[torch.arange(reward_logits.size(0), device=reward_logits.device), sequence_lengths].squeeze(-1), + sequence_lengths, + ) def forward(policy, query_responses, tokenizer): @@ -569,6 +574,7 @@ def forward(policy, query_responses, tokenizer): if args.deepspeed: deepspeed_states = AcceleratorState().deepspeed_plugin from deepspeed.ops.adam import DeepSpeedCPUAdam + # if deepspeed_states.deepspeed_config['zero_optimization']['offload_optimizer']['device'] in ('none', None): # return optim.AdamW(params, eps=self.opt.eps, betas=(self.opt.beta1, self.opt.beta2)) optimizer = DeepSpeedCPUAdam(model.parameters(), lr=args.ppo.lr, eps=args.ppo.eps) @@ -590,7 +596,6 @@ def forward(policy, query_responses, tokenizer): # deepspeed_states = AcceleratorState().deepspeed_plugin # # deepspeed_states.deepspeed_config["train_micro_batch_size_per_gpu"] = args.ppo.local_micro_batch_size # # deepspeed_states.deepspeed_config["checkpoint"] = {"use_node_local_storage": True} - # offload = False # eval_ds_config = { # "train_micro_batch_size_per_gpu": deepspeed_states.deepspeed_config["train_micro_batch_size_per_gpu"], @@ -644,6 +649,7 @@ def forward(policy, query_responses, tokenizer): def repeat_generator(): # TODO: ideally we shuffle the dataloader as well while True: yield from dataloader + iter_dataloader = iter(repeat_generator()) sample_validation_inds = np.arange(args.ppo.batch_size) @@ -662,12 +668,13 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well data = next(iter_dataloader) queries = data["query_token"].to(device) accelerator.print(f"==={queries.shape=}, {queries.dtype}") - accelerator.print(f"==={sample_validation_query_reference_responses.shape=}, {sample_validation_query_reference_responses.dtype}") + accelerator.print( + f"==={sample_validation_query_reference_responses.shape=}, {sample_validation_query_reference_responses.dtype}" + ) _, sample_validation_reference_scores, _ = get_reward( reward_model, sample_validation_query_reference_responses, tokenizer ) - kl_ctl = AdaptiveKLController(args.rewards.kl_coef, hparams=args.rewards.adaptive_kl) # WARNING: even with `max_new_tokens` and `min_new_tokens` set to the same value, the number of tokens generated # may not be the same. 
    # TODO: investigate further, we just want to generate a fixed number of tokens
@@ -680,7 +687,7 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well
         do_sample=True,
     )
     # use the same `0.01` temperature for validation response generation https://github.com/openai/summarize-from-feedback/blob/700967448d10004279f138666442bf1497d0e705/exps/sample.py#L27
-    validation_generation_config= GenerationConfig(
+    validation_generation_config = GenerationConfig(
         max_new_tokens=args.task.response_length,
         min_new_tokens=args.task.response_length,
         temperature=(0.01 + 1e-7),
@@ -734,7 +741,7 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well
             # TODO: do I do this with query response or post-processed query response?
             output = forward(accelerator.unwrap_model(model).policy, query_responses, tokenizer)
             logits = output.logits[:, context_length - 1 : -1]
-            logits /= (args.task.temperature + 1e-7)
+            logits /= args.task.temperature + 1e-7
             all_logprobs = F.log_softmax(logits, dim=-1)
             logprobs = torch.gather(all_logprobs, 2, responses.unsqueeze(-1)).squeeze(-1)
             del output, logits, all_logprobs
@@ -742,7 +749,7 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well
             ref_output = forward(ref_policy, query_responses, tokenizer)
             ref_logits = ref_output.logits[:, context_length - 1 : -1]
-            ref_logits /= (args.task.temperature + 1e-7)
+            ref_logits /= args.task.temperature + 1e-7
             ref_all_logprobs = F.log_softmax(ref_logits, dim=-1)
             ref_logprobs = torch.gather(ref_all_logprobs, 2, responses.unsqueeze(-1)).squeeze(-1)
             del ref_output, ref_logits, ref_all_logprobs
@@ -777,7 +784,7 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well
             # only query humans on responses that pass that filter
             contain_pad_token = torch.any(postprocessed_responses == tokenizer.pad_token_id, dim=-1)
             scores = torch.where(contain_pad_token, scores, torch.full_like(scores, args.task.penalty_reward_value))
- 
+
             # TODO: do we need to deal with penalty values?
# penalty_values = torch.full_like(values, 0) # penalty_values[:,-1] += args.task.penalty_reward_value @@ -823,7 +830,9 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well if accelerator.is_main_process: all_sample_validation_df.to_json(f"runs/{run_name}/table.json") if args.track: - wandb.log({"samples/query_responses": wandb.Table(dataframe=all_sample_validation_df)}, step=update) + wandb.log( + {"samples/query_responses": wandb.Table(dataframe=all_sample_validation_df)}, step=update + ) # print_rich_table("stuff", all_sample_validation_df[:4], console) except Exception as e: @@ -883,7 +892,7 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well output, vpred_temp = forward(model, mb_query_responses, tokenizer) logits = output.logits[:, context_length - 1 : -1] - logits /= (args.task.temperature + 1e-7) + logits /= args.task.temperature + 1e-7 new_all_logprobs = F.log_softmax(logits, dim=-1) new_logprobs = torch.gather(new_all_logprobs, 2, mb_responses.unsqueeze(-1)).squeeze(-1) # new_logprobs = torch.masked_fill(new_logprobs, padding_mask[micro_batch_inds], INVALID_LOGPROB) @@ -914,7 +923,7 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well approxkl = 0.5 * (logprobs_diff**2).mean() # if ppo_epoch_idx == 0 and micro_batch_start == 0: # torch.testing.assert_close(ratio, torch.zeros_like(ratio) + 1, atol=1e-4, rtol=1e-4) - # if ppo_epoch_idx == 0: + # if ppo_epoch_idx == 0: # pprint({ # # "responses": responses, # # "values": values, @@ -947,13 +956,13 @@ def repeat_generator(): # TODO: ideally we shuffle the dataloader as well f"ppo_epoch_idx", ppo_epoch_idx, "approxkl", - approxkl_stats[:ppo_epoch_idx+1].mean().item(), + approxkl_stats[: ppo_epoch_idx + 1].mean().item(), "pg_loss", - pg_loss_stats[:ppo_epoch_idx+1].mean().item(), + pg_loss_stats[: ppo_epoch_idx + 1].mean().item(), "pg_clipfrac", - pg_clipfrac_stats[:ppo_epoch_idx+1].mean().item(), + pg_clipfrac_stats[: ppo_epoch_idx + 1].mean().item(), "ratio", - ratio_stats[:ppo_epoch_idx+1].mean().item(), + ratio_stats[: ppo_epoch_idx + 1].mean().item(), ) # raise # breakpoint()
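The `get_reward` hunks above only reformat the return statement, but the indexing they wrap is the point of the `correct_reward_index` scripts: the scalar reward is read at `sequence_lengths`, the position just before each row's first pad token. A minimal sketch of that indexing on toy tensors (the pad id, shapes, and values below are illustrative assumptions, not taken from the scripts):

import torch

pad_token_id = 0  # stand-in for tokenizer.pad_token_id

# two right-padded query+response rows (batch=2, seq_len=6)
query_responses = torch.tensor([
    [5, 8, 3, 9, 0, 0],  # first pad at index 4 -> last real token at index 3
    [7, 2, 6, 1, 4, 0],  # first pad at index 5 -> last real token at index 4
])

# per-position scalar outputs of the reward head, flattened to (batch, seq_len)
reward_logits = torch.arange(12, dtype=torch.float32).reshape(2, 6)

# index of the first pad token, minus one = last non-pad position
sequence_lengths = torch.eq(query_responses, pad_token_id).long().argmax(-1) - 1
rewards = reward_logits[torch.arange(reward_logits.size(0)), sequence_lengths]

print(sequence_lengths)  # tensor([3, 4])
print(rewards)           # tensor([ 3., 10.])

If a row contains no pad token at all, the `argmax` trick lands on index 0 and the reward is read at position -1, i.e. the last token; the policy scripts handle that case separately by penalizing such responses, which is what the `contain_pad_token` hunks do.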
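Relatedly, the `scores = torch.where(contain_pad_token, ...)` context lines keep the reward model's score only when a truncated response actually contains a pad/EOS token, and substitute a fixed penalty otherwise. A toy sketch with made-up values (the real code reads the penalty from `args.task.penalty_reward_value`):

import torch

pad_token_id = 0             # stand-in for tokenizer.pad_token_id
penalty_reward_value = -1.0  # stand-in for args.task.penalty_reward_value

postprocessed_responses = torch.tensor([
    [4, 7, 2, 0],  # terminated: contains a pad token, keeps its score
    [4, 7, 2, 9],  # never terminated: no pad token, gets the penalty
])
scores = torch.tensor([0.8, 1.3])

contain_pad_token = torch.any(postprocessed_responses == pad_token_id, dim=-1)
scores = torch.where(contain_pad_token, scores, torch.full_like(scores, penalty_reward_value))

print(scores)  # tensor([ 0.8000, -1.0000])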
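Finally, the reformatted logging slices such as `approxkl_stats[: ppo_epoch_idx + 1].mean()` average the PPO diagnostics over the epochs finished so far; the quantity itself, `0.5 * (logprobs_diff**2).mean()`, is the usual quadratic approximation of the KL between the rollout policy and the current policy. A hedged sketch of that bookkeeping (the buffer shape and the loop below are assumptions for illustration, not copied from the scripts):

import torch

noptepochs = 4  # assumed number of PPO epochs over a single rollout
approxkl_stats = torch.zeros(noptepochs)

logprobs = torch.randn(8)  # per-token logprobs recorded at rollout time
for ppo_epoch_idx in range(noptepochs):
    # the current policy drifts a little each epoch
    new_logprobs = logprobs + 0.05 * torch.randn(8)
    logprobs_diff = new_logprobs - logprobs
    approxkl = 0.5 * (logprobs_diff**2).mean()  # quadratic KL approximation
    approxkl_stats[ppo_epoch_idx] = approxkl
    # mirror the diff's logging: mean over epochs completed so far
    print(ppo_epoch_idx, approxkl_stats[: ppo_epoch_idx + 1].mean().item())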