From d12e44312542f8207d850f683cd42c01f8ebe41f Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Sat, 6 Jan 2024 21:37:55 +0000 Subject: [PATCH] push changes --- .../summarize/ppo_left_padding.py | 41 +++++---- .../summarize/reward.py | 90 ++++++++++--------- lm_human_preference_details/summarize/sft.py | 3 +- 3 files changed, 74 insertions(+), 60 deletions(-) diff --git a/lm_human_preference_details/summarize/ppo_left_padding.py b/lm_human_preference_details/summarize/ppo_left_padding.py index c621821..9067f2e 100644 --- a/lm_human_preference_details/summarize/ppo_left_padding.py +++ b/lm_human_preference_details/summarize/ppo_left_padding.py @@ -168,6 +168,8 @@ class Args: """the mini batch size across GPUs""" local_eval_batch_size: int = 2 """per rank eval batch size""" + local_rollout_forward_batch_size: int = 64 + """per rank no grad forward pass in the rollout phase""" # other args base_model: str = "EleutherAI/pythia-160m" @@ -466,7 +468,8 @@ def evaluate(args: Args, reward_model, policy, tokenizer, dataloader, generation name=run_name, save_code=True, ) - wandb.run.log_code(".") + file_extensions = [".toml", ".lock", ".py", ".sh", ".yaml"] + wandb.run.log_code(".", include_fn=lambda path: any([path.endswith(ext) for ext in file_extensions])) writer = SummaryWriter(f"runs/{run_name}") writer.add_text( "hyperparameters", @@ -634,8 +637,8 @@ def repeat_generator(): values = [] scores = [] sequence_lengths = [] - for i in range(0, queries.shape[0], args.local_eval_batch_size): - query = queries[i : i + args.local_eval_batch_size] + for i in range(0, queries.shape[0], args.local_rollout_forward_batch_size): + query = queries[i : i + args.local_rollout_forward_batch_size] query_response = generate( accelerator.unwrap_model(model).policy, query, @@ -649,12 +652,16 @@ def repeat_generator(): logits /= args.task.temperature + 1e-7 all_logprob = F.log_softmax(logits, dim=-1) logprob = torch.gather(all_logprob, 2, response.unsqueeze(-1)).squeeze(-1) + del output, logits, all_logprob + torch.cuda.empty_cache() ref_output = forward(ref_policy, query_response, tokenizer) ref_logits = ref_output.logits[:, context_length - 1 : -1] ref_logits /= args.task.temperature + 1e-7 ref_all_logprob = F.log_softmax(ref_logits, dim=-1) ref_logprob = torch.gather(ref_all_logprob, 2, response.unsqueeze(-1)).squeeze(-1) + del ref_output, ref_logits, ref_all_logprob + torch.cuda.empty_cache() # Response Processing 1. truncate response after the first occurrence of `truncate_token_id` postprocessed_response = truncate_response(args, tokenizer, response) @@ -684,8 +691,7 @@ def repeat_generator(): values = torch.cat(values, 0) sequence_lengths = torch.cat(sequence_lengths, 0) scores = torch.cat(scores, 0) - del (output, logits, all_logprob, logprob, ref_output) - del (ref_logits, ref_all_logprob, ref_logprob, full_value, value, score) + del (logprob, ref_logprob, full_value, value, score) torch.cuda.empty_cache() # Response Processing 3. filter response. Ensure that the sample contains truncate_token_id @@ -766,14 +772,22 @@ def repeat_generator(): pg_losses = -mb_advantage * ratio pg_losses2 = -mb_advantage * torch.clamp(ratio, 1.0 - args.ppo.cliprange, 1.0 + args.ppo.cliprange) pg_loss = torch.max(pg_losses, pg_losses2).mean() - pg_clipfrac = (pg_losses2 > pg_losses).float().mean() loss = pg_loss + args.ppo.vf_coef * vf_loss accelerator.backward(loss) optimizer.step() optimizer.zero_grad() - prob_dist = torch.nn.functional.softmax(logits, dim=-1) - entropy = torch.logsumexp(logits, dim=-1) - torch.sum(prob_dist * logits, dim=-1) - approxkl = 0.5 * (logprobs_diff**2).mean() + with torch.no_grad(): + pg_clipfrac = (pg_losses2 > pg_losses).float().mean() + prob_dist = torch.nn.functional.softmax(logits, dim=-1) + entropy = torch.logsumexp(logits, dim=-1) - torch.sum(prob_dist * logits, dim=-1) + approxkl = 0.5 * (logprobs_diff**2).mean() + approxkl_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = approxkl + pg_clipfrac_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = pg_clipfrac + pg_loss_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = pg_loss + vf_loss_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = vf_loss + vf_clipfrac_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = vf_clipfrac + entropy_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = entropy.mean() + ratio_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = ratio.mean() # if ppo_epoch_idx == 0 and micro_batch_start == 0: # torch.testing.assert_close(ratio, torch.zeros_like(ratio) + 1, atol=1e-4, rtol=1e-4) # if ppo_epoch_idx == 0: @@ -794,14 +808,6 @@ def repeat_generator(): # # "entropy": masked_mean(entropy, ~padding_mask[micro_batch_inds]), # }) # breakpoint() - with torch.no_grad(): - approxkl_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = approxkl - pg_clipfrac_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = pg_clipfrac - pg_loss_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = pg_loss - vf_loss_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = vf_loss - vf_clipfrac_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = vf_clipfrac - entropy_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = entropy.mean() - ratio_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = ratio.mean() gradient_accumulation_idx += 1 minibatch_idx += 1 if accelerator.is_main_process: @@ -861,6 +867,7 @@ def repeat_generator(): if args.reward.use_adaptive_kl: kl_ctl.update(mean_kl.item(), args.batch_size) del kl, mean_kl, mean_entropy, mean_non_score_reward, scores + torch.cuda.empty_cache() if args.run_eval: eval_storage, eval_df = evaluate( diff --git a/lm_human_preference_details/summarize/reward.py b/lm_human_preference_details/summarize/reward.py index d99717c..a41bea4 100644 --- a/lm_human_preference_details/summarize/reward.py +++ b/lm_human_preference_details/summarize/reward.py @@ -140,7 +140,7 @@ class Args: """Which layers to apply dropout to""" output_dir: str = "models/reward_model" """Where to save the model""" - label_dataset: str = "vwxyzjn/summarize_from_feedback_oai_preprocessing_pythia-160m_169" + label_dataset: str = "cleanrl/summarize_from_feedback_oai_preprocessing_1704563162" """the name of the dataset to use for labels in `https://huggingface.co/datasets/vwxyzjn/lm-human-preferences`""" logsigmoid: bool = True """Whether to use log-sigmoid loss instead of cross-entropy loss""" @@ -271,7 +271,7 @@ def evaluate(args: Args, accelerator, tokenizer, model, dataloader): args.batch_size = int(args.local_batch_size * args.world_size) # load dataset - dataset = load_dataset(args.label_dataset, "comparisons", split="train") + dataset = load_dataset(args.label_dataset, split="train") dataset = dataset.shuffle(seed=local_seed) dataset = dataset.select(range(args.label.num_train)) dataset = dataset.with_format( @@ -288,27 +288,31 @@ def evaluate(args: Args, accelerator, tokenizer, model, dataloader): ], ) dataloader = DataLoader(dataset, batch_size=args.local_micro_batch_size) - validation_dataset = load_dataset(args.label_dataset, "comparisons", split="validation").flatten() - validation_dataset = validation_dataset.with_format( - "torch", - columns=[ - "query_token", - "choice", - "response0_token", - "query_response0_token", - "response1_token", - "query_response1_token", - "batch", - "split", - "extra.confidence", - "response0_policy", - "response1_policy", - "policies", - ], - ) - validation_dataloader = DataLoader(validation_dataset, batch_size=args.local_eval_batch_size) + eval_datasets = [] + eval_dataloaders = {} + for split in ["validation", "validation_cnndm"]: + validation_dataset = load_dataset(args.label_dataset, split=split).flatten() + validation_dataset = validation_dataset.with_format( + "torch", + columns=[ + "query_token", + "choice", + "response0_token", + "query_response0_token", + "response1_token", + "query_response1_token", + "batch", + "split", + "extra.confidence", + "response0_policy", + "response1_policy", + "policies", + ], + ) + eval_datasets.append(validation_dataset) + eval_dataloaders[split] = DataLoader(validation_dataset, batch_size=args.local_eval_batch_size) + accelerator.print("The number of samples in validation_dataset", len(validation_dataset)) accelerator.print("The number of samples in dataset", len(dataset)) - accelerator.print("The number of samples in validation_dataset", len(validation_dataset)) args.total_episodes = len(dataset) args.num_updates = args.total_episodes // args.batch_size @@ -328,7 +332,8 @@ def evaluate(args: Args, accelerator, tokenizer, model, dataloader): name=run_name, save_code=True, ) - wandb.run.log_code(".") + file_extensions = [".toml", ".lock", ".py", ".sh", ".yaml"] + wandb.run.log_code(".", include_fn=lambda path: any([path.endswith(ext) for ext in file_extensions])) writer = SummaryWriter(f"runs/{run_name}") writer.add_text( "hyperparameters", @@ -379,7 +384,7 @@ def evaluate(args: Args, accelerator, tokenizer, model, dataloader): deepspeed_states.deepspeed_config["train_micro_batch_size_per_gpu"] = args.local_micro_batch_size model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader) - validation_dataloader = accelerator.prepare(validation_dataloader) + eval_dataloaders = {split: accelerator.prepare(eval_dataloader) for split, eval_dataloader in eval_dataloaders.items()} accelerator.print("===training model===") losses = torch.zeros((args.gradient_accumulation_steps,), device=device) @@ -436,24 +441,25 @@ def evaluate(args: Args, accelerator, tokenizer, model, dataloader): ) if args.run_eval: - evaluate_df = evaluate(args, accelerator, tokenizer, model, validation_dataloader) - for split, row in evaluate_df[["split", "accuracy"]].groupby(["split"]).mean().iterrows(): - writer.add_scalar(f"eval/rm/accuracy/split/{split}", row["accuracy"], global_step) - accelerator.print(f"eval/rm/accuracy/split/{split}: {row['accuracy']}") - for batch, row in evaluate_df[["batch", "accuracy"]].groupby(["batch"]).mean().iterrows(): - writer.add_scalar(f"eval/rm/accuracy/batch/{batch}", row["accuracy"], global_step) - accelerator.print(f"eval/rm/accuracy/batch/{batch}: {row['accuracy']}") - for confi, row in evaluate_df[["confidence", "accuracy"]].groupby(["confidence"]).mean().iterrows(): - writer.add_scalar(f"eval/rm/accuracy/confidence/{confi}", row["accuracy"], global_step) - accelerator.print(f"eval/rm/accuracy/confidence/{confi}: {row['accuracy']}") - writer.add_scalar("eval/rm/accuracy", evaluate_df["accuracy"].mean(), global_step) - accelerator.print(f"eval/rm/accuracy: {evaluate_df['accuracy'].mean()}") - if accelerator.is_main_process: - os.makedirs(f"eval_tables/{run_name}", exist_ok=True) - evaluate_df.to_csv(f"eval_tables/{run_name}/eval_{update}.csv") - if args.track: - wandb.log({"samples/query_responses": wandb.Table(dataframe=evaluate_df)}, step=update) - torch.cuda.empty_cache() + for eval_split in eval_dataloaders: + evaluate_df = evaluate(args, accelerator, tokenizer, model, eval_dataloaders[eval_split]) + for split, row in evaluate_df[["split", "accuracy"]].groupby(["split"]).mean().iterrows(): + writer.add_scalar(f"eval/rm/{eval_split}/accuracy/split/{split}", row["accuracy"], global_step) + accelerator.print(f"eval/rm/{eval_split}/accuracy/split/{split}: {row['accuracy']}") + for batch, row in evaluate_df[["batch", "accuracy"]].groupby(["batch"]).mean().iterrows(): + writer.add_scalar(f"eval/rm/{eval_split}/accuracy/batch/{batch}", row["accuracy"], global_step) + accelerator.print(f"eval/rm/{eval_split}/accuracy/batch/{batch}: {row['accuracy']}") + for confi, row in evaluate_df[["confidence", "accuracy"]].groupby(["confidence"]).mean().iterrows(): + writer.add_scalar(f"eval/rm/{eval_split}/accuracy/confidence/{confi}", row["accuracy"], global_step) + accelerator.print(f"eval/rm/{eval_split}/accuracy/confidence/{confi}: {row['accuracy']}") + writer.add_scalar(f"eval/rm/{eval_split}/accuracy", evaluate_df["accuracy"].mean(), global_step) + accelerator.print(f"eval/rm/{eval_split}/accuracy: {evaluate_df['accuracy'].mean()}") + if accelerator.is_main_process: + os.makedirs(f"eval_tables/{run_name}", exist_ok=True) + evaluate_df.to_csv(f"eval_tables/{run_name}/eval_{eval_split}_{update}.csv") + if args.track: + wandb.log({f"samples/{eval_split}/query_responses": wandb.Table(dataframe=evaluate_df)}, step=update) + torch.cuda.empty_cache() norm_dataset = load_dataset(args.task.query_dataset, split="train") norm_dataset = norm_dataset.with_format("torch", columns=["query_token", "reference_response_token"]) diff --git a/lm_human_preference_details/summarize/sft.py b/lm_human_preference_details/summarize/sft.py index 3c20704..a37f387 100644 --- a/lm_human_preference_details/summarize/sft.py +++ b/lm_human_preference_details/summarize/sft.py @@ -318,7 +318,8 @@ def evaluate(args: Args, accelerator, tokenizer, model, dataloader, generation_c name=run_name, save_code=True, ) - wandb.run.log_code(".") + file_extensions = [".toml", ".lock", ".py", ".sh", ".yaml"] + wandb.run.log_code(".", include_fn=lambda path: any([path.endswith(ext) for ext in file_extensions])) writer = SummaryWriter(f"runs/{run_name}") writer.add_text( "hyperparameters",