push changes
vwxyzjn committed Jan 6, 2024
1 parent 3f6d045 commit d12e443
Showing 3 changed files with 74 additions and 60 deletions.
41 changes: 24 additions & 17 deletions lm_human_preference_details/summarize/ppo_left_padding.py
@@ -168,6 +168,8 @@ class Args:
"""the mini batch size across GPUs"""
local_eval_batch_size: int = 2
"""per rank eval batch size"""
local_rollout_forward_batch_size: int = 64
"""per rank no grad forward pass in the rollout phase"""

# other args
base_model: str = "EleutherAI/pythia-160m"
@@ -466,7 +468,8 @@ def evaluate(args: Args, reward_model, policy, tokenizer, dataloader, generation
name=run_name,
save_code=True,
)
wandb.run.log_code(".")
file_extensions = [".toml", ".lock", ".py", ".sh", ".yaml"]
wandb.run.log_code(".", include_fn=lambda path: any([path.endswith(ext) for ext in file_extensions]))
writer = SummaryWriter(f"runs/{run_name}")
writer.add_text(
"hyperparameters",
@@ -634,8 +637,8 @@ def repeat_generator():
values = []
scores = []
sequence_lengths = []
for i in range(0, queries.shape[0], args.local_eval_batch_size):
query = queries[i : i + args.local_eval_batch_size]
for i in range(0, queries.shape[0], args.local_rollout_forward_batch_size):
query = queries[i : i + args.local_rollout_forward_batch_size]
query_response = generate(
accelerator.unwrap_model(model).policy,
query,
@@ -649,12 +652,16 @@ def repeat_generator():
logits /= args.task.temperature + 1e-7
all_logprob = F.log_softmax(logits, dim=-1)
logprob = torch.gather(all_logprob, 2, response.unsqueeze(-1)).squeeze(-1)
del output, logits, all_logprob
torch.cuda.empty_cache()

ref_output = forward(ref_policy, query_response, tokenizer)
ref_logits = ref_output.logits[:, context_length - 1 : -1]
ref_logits /= args.task.temperature + 1e-7
ref_all_logprob = F.log_softmax(ref_logits, dim=-1)
ref_logprob = torch.gather(ref_all_logprob, 2, response.unsqueeze(-1)).squeeze(-1)
del ref_output, ref_logits, ref_all_logprob
torch.cuda.empty_cache()

# Response Processing 1. truncate response after the first occurrence of `truncate_token_id`
postprocessed_response = truncate_response(args, tokenizer, response)
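
The hunk above chunks the rollout-phase forward passes by `local_rollout_forward_batch_size` and frees intermediates with `del` plus `torch.cuda.empty_cache()` between chunks, trading a little speed for a lower peak memory footprint. A minimal, self-contained sketch of that pattern with a toy model (shapes and names are illustrative, not the repo's actual policy):

```python
import torch
import torch.nn as nn

model = nn.Linear(16, 50)                 # stand-in for the policy / reference policy
queries = torch.randn(256, 16)
rollout_forward_batch_size = 64           # plays the role of args.local_rollout_forward_batch_size

logprobs = []
with torch.no_grad():                     # rollout forwards need no gradients
    for i in range(0, queries.shape[0], rollout_forward_batch_size):
        logits = model(queries[i : i + rollout_forward_batch_size])
        logprobs.append(torch.log_softmax(logits, dim=-1))
        del logits                        # drop per-chunk activations before the next slice
        torch.cuda.empty_cache()          # return cached blocks to the allocator (no-op without CUDA)
logprobs = torch.cat(logprobs, 0)         # (256, 50)
```
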
@@ -684,8 +691,7 @@ def repeat_generator():
values = torch.cat(values, 0)
sequence_lengths = torch.cat(sequence_lengths, 0)
scores = torch.cat(scores, 0)
del (output, logits, all_logprob, logprob, ref_output)
del (ref_logits, ref_all_logprob, ref_logprob, full_value, value, score)
del (logprob, ref_logprob, full_value, value, score)
torch.cuda.empty_cache()

# Response Processing 3. filter response. Ensure that the sample contains truncate_token_id
@@ -766,14 +772,22 @@ def repeat_generator():
pg_losses = -mb_advantage * ratio
pg_losses2 = -mb_advantage * torch.clamp(ratio, 1.0 - args.ppo.cliprange, 1.0 + args.ppo.cliprange)
pg_loss = torch.max(pg_losses, pg_losses2).mean()
pg_clipfrac = (pg_losses2 > pg_losses).float().mean()
loss = pg_loss + args.ppo.vf_coef * vf_loss
accelerator.backward(loss)
optimizer.step()
optimizer.zero_grad()
prob_dist = torch.nn.functional.softmax(logits, dim=-1)
entropy = torch.logsumexp(logits, dim=-1) - torch.sum(prob_dist * logits, dim=-1)
approxkl = 0.5 * (logprobs_diff**2).mean()
with torch.no_grad():
pg_clipfrac = (pg_losses2 > pg_losses).float().mean()
prob_dist = torch.nn.functional.softmax(logits, dim=-1)
entropy = torch.logsumexp(logits, dim=-1) - torch.sum(prob_dist * logits, dim=-1)
approxkl = 0.5 * (logprobs_diff**2).mean()
approxkl_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = approxkl
pg_clipfrac_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = pg_clipfrac
pg_loss_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = pg_loss
vf_loss_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = vf_loss
vf_clipfrac_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = vf_clipfrac
entropy_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = entropy.mean()
ratio_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = ratio.mean()
# if ppo_epoch_idx == 0 and micro_batch_start == 0:
# torch.testing.assert_close(ratio, torch.zeros_like(ratio) + 1, atol=1e-4, rtol=1e-4)
# if ppo_epoch_idx == 0:
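
The statistics moved under `torch.no_grad()` above are standard PPO diagnostics (clip fraction, entropy, approximate KL) that do not need gradients; keeping them off the autograd graph avoids holding extra activations. A self-contained sketch of the same entropy / approx-KL math, with random tensors standing in for the real logits and log-prob differences:

```python
import torch
import torch.nn.functional as F

logits = torch.randn(4, 8, 50)             # (batch, response_len, vocab), illustrative shapes
logprobs_diff = 0.01 * torch.randn(4, 8)   # new logprobs minus old logprobs
ratio = torch.exp(logprobs_diff)

with torch.no_grad():                      # diagnostics only; no gradients required
    prob_dist = F.softmax(logits, dim=-1)
    entropy = torch.logsumexp(logits, dim=-1) - torch.sum(prob_dist * logits, dim=-1)
    approxkl = 0.5 * (logprobs_diff**2).mean()

print(entropy.mean().item(), approxkl.item(), ratio.mean().item())
```
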
@@ -794,14 +808,6 @@ def repeat_generator():
# # "entropy": masked_mean(entropy, ~padding_mask[micro_batch_inds]),
# })
# breakpoint()
with torch.no_grad():
approxkl_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = approxkl
pg_clipfrac_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = pg_clipfrac
pg_loss_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = pg_loss
vf_loss_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = vf_loss
vf_clipfrac_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = vf_clipfrac
entropy_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = entropy.mean()
ratio_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = ratio.mean()
gradient_accumulation_idx += 1
minibatch_idx += 1
if accelerator.is_main_process:
@@ -861,6 +867,7 @@ def repeat_generator():
if args.reward.use_adaptive_kl:
kl_ctl.update(mean_kl.item(), args.batch_size)
del kl, mean_kl, mean_entropy, mean_non_score_reward, scores
torch.cuda.empty_cache()

if args.run_eval:
eval_storage, eval_df = evaluate(
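For reference, the `kl_ctl.update(mean_kl.item(), args.batch_size)` call guarded by `args.reward.use_adaptive_kl` in the hunk above typically follows the proportional KL-coefficient controller from Ziegler et al. (2019). A hedged sketch of that controller; attribute names and hyperparameter values here are illustrative, not necessarily this repo's exact `kl_ctl`:

```python
import numpy as np

class AdaptiveKLController:
    """Proportional KL-coefficient controller in the style of Ziegler et al. (2019)."""

    def __init__(self, init_kl_coef: float, target: float, horizon: int):
        self.value = init_kl_coef     # current KL coefficient
        self.target = target          # desired KL per update
        self.horizon = horizon        # smoothing horizon in episodes

    def update(self, current_kl: float, n_steps: int) -> None:
        proportional_error = float(np.clip(current_kl / self.target - 1, -0.2, 0.2))
        self.value *= 1 + proportional_error * n_steps / self.horizon

kl_ctl = AdaptiveKLController(init_kl_coef=0.15, target=6.0, horizon=10000)
kl_ctl.update(current_kl=7.5, n_steps=512)   # KL above target -> coefficient increases
print(kl_ctl.value)
```
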
90 changes: 48 additions & 42 deletions lm_human_preference_details/summarize/reward.py
@@ -140,7 +140,7 @@ class Args:
"""Which layers to apply dropout to"""
output_dir: str = "models/reward_model"
"""Where to save the model"""
label_dataset: str = "vwxyzjn/summarize_from_feedback_oai_preprocessing_pythia-160m_169"
label_dataset: str = "cleanrl/summarize_from_feedback_oai_preprocessing_1704563162"
"""the name of the dataset to use for labels in `https://huggingface.co/datasets/vwxyzjn/lm-human-preferences`"""
logsigmoid: bool = True
"""Whether to use log-sigmoid loss instead of cross-entropy loss"""
@@ -271,7 +271,7 @@ def evaluate(args: Args, accelerator, tokenizer, model, dataloader):
args.batch_size = int(args.local_batch_size * args.world_size)

# load dataset
dataset = load_dataset(args.label_dataset, "comparisons", split="train")
dataset = load_dataset(args.label_dataset, split="train")
dataset = dataset.shuffle(seed=local_seed)
dataset = dataset.select(range(args.label.num_train))
dataset = dataset.with_format(
@@ -288,27 +288,31 @@ def evaluate(args: Args, accelerator, tokenizer, model, dataloader):
],
)
dataloader = DataLoader(dataset, batch_size=args.local_micro_batch_size)
validation_dataset = load_dataset(args.label_dataset, "comparisons", split="validation").flatten()
validation_dataset = validation_dataset.with_format(
"torch",
columns=[
"query_token",
"choice",
"response0_token",
"query_response0_token",
"response1_token",
"query_response1_token",
"batch",
"split",
"extra.confidence",
"response0_policy",
"response1_policy",
"policies",
],
)
validation_dataloader = DataLoader(validation_dataset, batch_size=args.local_eval_batch_size)
eval_datasets = []
eval_dataloaders = {}
for split in ["validation", "validation_cnndm"]:
validation_dataset = load_dataset(args.label_dataset, split=split).flatten()
validation_dataset = validation_dataset.with_format(
"torch",
columns=[
"query_token",
"choice",
"response0_token",
"query_response0_token",
"response1_token",
"query_response1_token",
"batch",
"split",
"extra.confidence",
"response0_policy",
"response1_policy",
"policies",
],
)
eval_datasets.append(validation_dataset)
eval_dataloaders[split] = DataLoader(validation_dataset, batch_size=args.local_eval_batch_size)
accelerator.print("The number of samples in validation_dataset", len(validation_dataset))
accelerator.print("The number of samples in dataset", len(dataset))
accelerator.print("The number of samples in validation_dataset", len(validation_dataset))
args.total_episodes = len(dataset)
args.num_updates = args.total_episodes // args.batch_size
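
The hunk above replaces the single validation dataloader with a dict keyed by split so both `validation` and `validation_cnndm` are evaluated. Stripped of the real dataset, the structure reduces to the following sketch (dummy tensors stand in for `load_dataset(args.label_dataset, split=split)`):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

splits = ["validation", "validation_cnndm"]
eval_dataloaders = {
    split: DataLoader(TensorDataset(torch.arange(10)), batch_size=4)
    for split in splits
}

# Downstream code iterates the dict so every metric and table is tagged with its split.
for split, dataloader in eval_dataloaders.items():
    n_examples = sum(batch[0].shape[0] for batch in dataloader)
    print(f"eval/rm/{split}: {n_examples} examples")
```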

@@ -328,7 +332,8 @@ def evaluate(args: Args, accelerator, tokenizer, model, dataloader):
name=run_name,
save_code=True,
)
wandb.run.log_code(".")
file_extensions = [".toml", ".lock", ".py", ".sh", ".yaml"]
wandb.run.log_code(".", include_fn=lambda path: any([path.endswith(ext) for ext in file_extensions]))
writer = SummaryWriter(f"runs/{run_name}")
writer.add_text(
"hyperparameters",
@@ -379,7 +384,7 @@ def evaluate(args: Args, accelerator, tokenizer, model, dataloader):
deepspeed_states.deepspeed_config["train_micro_batch_size_per_gpu"] = args.local_micro_batch_size

model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)
validation_dataloader = accelerator.prepare(validation_dataloader)
eval_dataloaders = {split: accelerator.prepare(eval_dataloader) for split, eval_dataloader in eval_dataloaders.items()}

accelerator.print("===training model===")
losses = torch.zeros((args.gradient_accumulation_steps,), device=device)
@@ -436,24 +441,25 @@ def evaluate(args: Args, accelerator, tokenizer, model, dataloader):
)

if args.run_eval:
evaluate_df = evaluate(args, accelerator, tokenizer, model, validation_dataloader)
for split, row in evaluate_df[["split", "accuracy"]].groupby(["split"]).mean().iterrows():
writer.add_scalar(f"eval/rm/accuracy/split/{split}", row["accuracy"], global_step)
accelerator.print(f"eval/rm/accuracy/split/{split}: {row['accuracy']}")
for batch, row in evaluate_df[["batch", "accuracy"]].groupby(["batch"]).mean().iterrows():
writer.add_scalar(f"eval/rm/accuracy/batch/{batch}", row["accuracy"], global_step)
accelerator.print(f"eval/rm/accuracy/batch/{batch}: {row['accuracy']}")
for confi, row in evaluate_df[["confidence", "accuracy"]].groupby(["confidence"]).mean().iterrows():
writer.add_scalar(f"eval/rm/accuracy/confidence/{confi}", row["accuracy"], global_step)
accelerator.print(f"eval/rm/accuracy/confidence/{confi}: {row['accuracy']}")
writer.add_scalar("eval/rm/accuracy", evaluate_df["accuracy"].mean(), global_step)
accelerator.print(f"eval/rm/accuracy: {evaluate_df['accuracy'].mean()}")
if accelerator.is_main_process:
os.makedirs(f"eval_tables/{run_name}", exist_ok=True)
evaluate_df.to_csv(f"eval_tables/{run_name}/eval_{update}.csv")
if args.track:
wandb.log({"samples/query_responses": wandb.Table(dataframe=evaluate_df)}, step=update)
torch.cuda.empty_cache()
for eval_split in eval_dataloaders:
evaluate_df = evaluate(args, accelerator, tokenizer, model, eval_dataloaders[eval_split])
for split, row in evaluate_df[["split", "accuracy"]].groupby(["split"]).mean().iterrows():
writer.add_scalar(f"eval/rm/{eval_split}/accuracy/split/{split}", row["accuracy"], global_step)
accelerator.print(f"eval/rm/{eval_split}/accuracy/split/{split}: {row['accuracy']}")
for batch, row in evaluate_df[["batch", "accuracy"]].groupby(["batch"]).mean().iterrows():
writer.add_scalar(f"eval/rm/{eval_split}/accuracy/batch/{batch}", row["accuracy"], global_step)
accelerator.print(f"eval/rm/{eval_split}/accuracy/batch/{batch}: {row['accuracy']}")
for confi, row in evaluate_df[["confidence", "accuracy"]].groupby(["confidence"]).mean().iterrows():
writer.add_scalar(f"eval/rm/{eval_split}/accuracy/confidence/{confi}", row["accuracy"], global_step)
accelerator.print(f"eval/rm/{eval_split}/accuracy/confidence/{confi}: {row['accuracy']}")
writer.add_scalar(f"eval/rm/{eval_split}/accuracy", evaluate_df["accuracy"].mean(), global_step)
accelerator.print(f"eval/rm/{eval_split}/accuracy: {evaluate_df['accuracy'].mean()}")
if accelerator.is_main_process:
os.makedirs(f"eval_tables/{run_name}", exist_ok=True)
evaluate_df.to_csv(f"eval_tables/{run_name}/eval_{eval_split}_{update}.csv")
if args.track:
wandb.log({f"samples/{eval_split}/query_responses": wandb.Table(dataframe=evaluate_df)}, step=update)
torch.cuda.empty_cache()

norm_dataset = load_dataset(args.task.query_dataset, split="train")
norm_dataset = norm_dataset.with_format("torch", columns=["query_token", "reference_response_token"])
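The per-split / per-batch / per-confidence accuracy logging in the hunk above is plain pandas groupby aggregation over the evaluation dataframe. A minimal illustration with made-up rows (the real `evaluate_df` comes from `evaluate(...)`):

```python
import pandas as pd

evaluate_df = pd.DataFrame({
    "split": ["valid1", "valid1", "valid2"],
    "confidence": [8, 9, 8],
    "accuracy": [1.0, 0.0, 1.0],
})

# Mean accuracy per split, mirroring the writer.add_scalar / accelerator.print loop above.
for split, row in evaluate_df[["split", "accuracy"]].groupby(["split"]).mean().iterrows():
    print(f"eval/rm/accuracy/split/{split}: {row['accuracy']}")
print("eval/rm/accuracy:", evaluate_df["accuracy"].mean())
```
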
3 changes: 2 additions & 1 deletion lm_human_preference_details/summarize/sft.py
@@ -318,7 +318,8 @@ def evaluate(args: Args, accelerator, tokenizer, model, dataloader, generation_c
name=run_name,
save_code=True,
)
wandb.run.log_code(".")
file_extensions = [".toml", ".lock", ".py", ".sh", ".yaml"]
wandb.run.log_code(".", include_fn=lambda path: any([path.endswith(ext) for ext in file_extensions]))
writer = SummaryWriter(f"runs/{run_name}")
writer.add_text(
"hyperparameters",
