quick change
vwxyzjn committed Jan 11, 2024
1 parent 6f6490f commit 2166b4f
Showing 3 changed files with 42 additions and 24 deletions.
52 changes: 32 additions & 20 deletions lm_human_preference_details/summarize/ppo.py
@@ -168,6 +168,8 @@ class Args:
"""the mini batch size across GPUs"""
local_eval_batch_size: int = 2
"""per rank eval batch size"""
local_rollout_forward_batch_size: int = 64
"""per rank no grad forward pass in the rollout phase"""

# other args
base_model: str = "EleutherAI/pythia-160m"
@@ -462,7 +464,8 @@ def evaluate(args: Args, reward_model, policy, tokenizer, dataloader, generation
name=run_name,
save_code=True,
)
wandb.run.log_code(".")
file_extensions = [".toml", ".lock", ".py", ".sh", ".yaml"]
wandb.run.log_code(".", include_fn=lambda path: any([path.endswith(ext) for ext in file_extensions]))
writer = SummaryWriter(f"runs/{run_name}")
writer.add_text(
"hyperparameters",
@@ -582,6 +585,7 @@ def repeat_generator():

accelerator.print("===training policy===")
global_step = 0
start_time = time.time()
stats_shape = (args.ppo.noptepochs, args.nminibatches, args.gradient_accumulation_steps)
approxkl_stats = torch.zeros(stats_shape, device=device)
pg_clipfrac_stats = torch.zeros(stats_shape, device=device)
@@ -630,8 +634,8 @@ def repeat_generator():
values = []
scores = []
sequence_lengths = []
for i in range(0, queries.shape[0], args.local_eval_batch_size):
query = queries[i : i + args.local_eval_batch_size]
for i in range(0, queries.shape[0], args.local_rollout_forward_batch_size):
query = queries[i : i + args.local_rollout_forward_batch_size]
query_response = generate(
accelerator.unwrap_model(model).policy,
query,
@@ -645,12 +649,16 @@ def repeat_generator():
logits /= args.task.temperature + 1e-7
all_logprob = F.log_softmax(logits, dim=-1)
logprob = torch.gather(all_logprob, 2, response.unsqueeze(-1)).squeeze(-1)
del output, logits, all_logprob
torch.cuda.empty_cache()

ref_output = forward(ref_policy, query_response, tokenizer)
ref_logits = ref_output.logits[:, context_length - 1 : -1]
ref_logits /= args.task.temperature + 1e-7
ref_all_logprob = F.log_softmax(ref_logits, dim=-1)
ref_logprob = torch.gather(ref_all_logprob, 2, response.unsqueeze(-1)).squeeze(-1)
del ref_output, ref_logits, ref_all_logprob
torch.cuda.empty_cache()

# Response Processing 1. truncate response after the first occurrence of `truncate_token_id`
postprocessed_response = truncate_response(args, tokenizer, response)
@@ -678,8 +686,7 @@ def repeat_generator():
values = torch.cat(values, 0)
sequence_lengths = torch.cat(sequence_lengths, 0)
scores = torch.cat(scores, 0)
del (output, logits, all_logprob, logprob, ref_output)
del (ref_logits, ref_all_logprob, ref_logprob, full_value, value, score)
del (logprob, ref_logprob, full_value, value, score)
torch.cuda.empty_cache()

# Response Processing 3. filter response. Ensure that the sample contains truncate_token_id
@@ -760,14 +767,22 @@ def repeat_generator():
pg_losses = -mb_advantage * ratio
pg_losses2 = -mb_advantage * torch.clamp(ratio, 1.0 - args.ppo.cliprange, 1.0 + args.ppo.cliprange)
pg_loss = torch.max(pg_losses, pg_losses2).mean()
pg_clipfrac = (pg_losses2 > pg_losses).float().mean()
loss = pg_loss + args.ppo.vf_coef * vf_loss
accelerator.backward(loss)
optimizer.step()
optimizer.zero_grad()
prob_dist = torch.nn.functional.softmax(logits, dim=-1)
entropy = torch.logsumexp(logits, dim=-1) - torch.sum(prob_dist * logits, dim=-1)
approxkl = 0.5 * (logprobs_diff**2).mean()
with torch.no_grad():
pg_clipfrac = (pg_losses2 > pg_losses).float().mean()
prob_dist = torch.nn.functional.softmax(logits, dim=-1)
entropy = torch.logsumexp(logits, dim=-1) - torch.sum(prob_dist * logits, dim=-1)
approxkl = 0.5 * (logprobs_diff**2).mean()
approxkl_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = approxkl
pg_clipfrac_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = pg_clipfrac
pg_loss_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = pg_loss
vf_loss_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = vf_loss
vf_clipfrac_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = vf_clipfrac
entropy_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = entropy.mean()
ratio_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = ratio.mean()
# if ppo_epoch_idx == 0 and micro_batch_start == 0:
# torch.testing.assert_close(ratio, torch.zeros_like(ratio) + 1, atol=1e-4, rtol=1e-4)
# if ppo_epoch_idx == 0:
@@ -788,14 +803,6 @@ def repeat_generator():
# # "entropy": masked_mean(entropy, ~padding_mask[micro_batch_inds]),
# })
# breakpoint()
with torch.no_grad():
approxkl_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = approxkl
pg_clipfrac_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = pg_clipfrac
pg_loss_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = pg_loss
vf_loss_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = vf_loss
vf_clipfrac_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = vf_clipfrac
entropy_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = entropy.mean()
ratio_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = ratio.mean()
gradient_accumulation_idx += 1
minibatch_idx += 1
if accelerator.is_main_process:
@@ -852,9 +859,13 @@ def repeat_generator():
writer.add_scalar("ppo/val/num_eos_tokens", (responses == tokenizer.eos_token_id).sum().item(), update)
writer.add_scalar("ppo/lr", lrnow, update)
writer.add_scalar("ppo/episode", global_step, update)
eps = int(global_step / (time.time() - start_time))
writer.add_scalar("ppo/eps", eps, update)
accelerator.print("ppo/eps", eps, update)
if args.reward.use_adaptive_kl:
kl_ctl.update(mean_kl.item(), args.batch_size)
del kl, mean_kl, mean_entropy, mean_non_score_reward, scores
torch.cuda.empty_cache()

if args.run_eval:
eval_storage, eval_df = evaluate(
@@ -866,9 +877,10 @@ def repeat_generator():
validation_generation_config,
sampling=False,
)
eval_df.to_csv(f"runs/{run_name}/table.csv")
if accelerator.is_main_process and args.track:
wandb.log({"eval/query_responses": wandb.Table(dataframe=eval_df)}, step=update)
if accelerator.is_main_process:
eval_df.to_csv(f"runs/{run_name}/table.csv")
if args.track:
wandb.log({"eval/query_responses": wandb.Table(dataframe=eval_df)}, step=update)

# save model
if args.output_dir and args.num_train_epochs > 0:
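The main change in ppo.py is that the rollout-phase forward passes are now chunked by the new `local_rollout_forward_batch_size` argument instead of `local_eval_batch_size`, with large intermediates explicitly freed between chunks. Below is a minimal sketch of that pattern, not the script itself; `policy`, `queries`, and `responses` are hypothetical stand-ins for the script's objects, and the model is assumed to return HF-style outputs with a `.logits` field.

```python
import torch
import torch.nn.functional as F


def chunked_response_logprobs(policy, queries, responses, context_length,
                              temperature=0.7, forward_batch_size=64):
    """Per-token log-probs of `responses` under `policy`, computed in small
    no-grad chunks so peak activation memory stays bounded."""
    logprobs = []
    with torch.no_grad():
        for i in range(0, queries.shape[0], forward_batch_size):
            query = queries[i : i + forward_batch_size]
            response = responses[i : i + forward_batch_size]
            query_response = torch.cat([query, response], dim=1)

            output = policy(query_response)
            # Keep only the logits that predict the response tokens.
            logits = output.logits[:, context_length - 1 : -1]
            logits /= temperature + 1e-7
            all_logprob = F.log_softmax(logits, dim=-1)
            logprob = torch.gather(all_logprob, 2, response.unsqueeze(-1)).squeeze(-1)
            logprobs.append(logprob)

            # Drop the big intermediates before the next chunk, mirroring the diff.
            del output, logits, all_logprob
            torch.cuda.empty_cache()
    return torch.cat(logprobs, 0)
```

The actual script also runs the reference-policy, value-head, and reward-model forwards inside the same per-chunk loop; the point of the sketch is only the chunk-then-free pattern.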
11 changes: 8 additions & 3 deletions lm_human_preference_details/summarize/ppo_left_padding.py
@@ -589,6 +589,7 @@ def repeat_generator():

accelerator.print("===training policy===")
global_step = 0
start_time = time.time()
stats_shape = (args.ppo.noptepochs, args.nminibatches, args.gradient_accumulation_steps)
approxkl_stats = torch.zeros(stats_shape, device=device)
pg_clipfrac_stats = torch.zeros(stats_shape, device=device)
@@ -864,6 +865,9 @@ def repeat_generator():
writer.add_scalar("ppo/val/num_eos_tokens", (responses == tokenizer.eos_token_id).sum().item(), update)
writer.add_scalar("ppo/lr", lrnow, update)
writer.add_scalar("ppo/episode", global_step, update)
eps = int(global_step / (time.time() - start_time))
writer.add_scalar("ppo/eps", eps, update)
accelerator.print("ppo/eps", eps, update)
if args.reward.use_adaptive_kl:
kl_ctl.update(mean_kl.item(), args.batch_size)
del kl, mean_kl, mean_entropy, mean_non_score_reward, scores
@@ -879,9 +883,10 @@ def repeat_generator():
validation_generation_config,
sampling=False,
)
eval_df.to_csv(f"runs/{run_name}/table.csv")
if accelerator.is_main_process and args.track:
wandb.log({"eval/query_responses": wandb.Table(dataframe=eval_df)}, step=update)
if accelerator.is_main_process:
eval_df.to_csv(f"runs/{run_name}/table.csv")
if args.track:
wandb.log({"eval/query_responses": wandb.Table(dataframe=eval_df)}, step=update)

# save model
if args.output_dir and args.num_train_epochs > 0:
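The ppo_left_padding.py changes mirror the tail of ppo.py: a wall-clock throughput counter logged as `ppo/eps`, and a guard so only the main process writes the eval table and sends it to wandb. A rough sketch of both, assuming a `SummaryWriter`, an accelerate-style `accelerator`, and a pandas `eval_df`; the function name and signature are illustrative, not from the source.

```python
import time

import wandb

start_time = time.time()  # taken once, right before the training loop


def log_update(writer, accelerator, eval_df, run_name, global_step, update, track=False):
    # Episodes per second since training started (the new "ppo/eps" scalar).
    eps = int(global_step / (time.time() - start_time))
    writer.add_scalar("ppo/eps", eps, update)
    accelerator.print("ppo/eps", eps, update)

    # Only rank 0 touches the filesystem or wandb; the other ranks skip logging.
    if accelerator.is_main_process:
        eval_df.to_csv(f"runs/{run_name}/table.csv")
        if track:
            wandb.log({"eval/query_responses": wandb.Table(dataframe=eval_df)}, step=update)
```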
3 changes: 2 additions & 1 deletion lm_human_preference_details/summarize/reward.py
@@ -127,7 +127,7 @@ class Args:
"""The batch size per GPU (HF's `per_device_train_batch_size` * `gradient_accumulation_steps`)"""
batch_size: Optional[int] = None
"""The batch size across devices (HF's `per_device_train_batch_size` * `world_size` * `gradient_accumulation_steps`)"""
local_eval_batch_size: int = 8
local_eval_batch_size: int = 1
"""per rank eval batch size"""

# other args
@@ -459,6 +459,7 @@ def evaluate(args: Args, accelerator, tokenizer, model, dataloader):
evaluate_df.to_csv(f"eval_tables/{run_name}/eval_{eval_split}_{update}.csv")
if args.track:
wandb.log({f"samples/{eval_split}/query_responses": wandb.Table(dataframe=evaluate_df)}, step=update)
del evaluate_df
torch.cuda.empty_cache()

norm_dataset = load_dataset(args.task.query_dataset, split="train")
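In reward.py the only substantive changes are the smaller default `local_eval_batch_size` and freeing the eval DataFrame afterwards. For orientation, the batch-size docstrings around that argument spell out how the local and global batch sizes are derived; a toy calculation under assumed values (none of these numbers come from the source):

```python
# Hypothetical values, only to illustrate the relationships stated in the Args docstrings.
per_device_train_batch_size = 8
gradient_accumulation_steps = 4
world_size = 2

local_batch_size = per_device_train_batch_size * gradient_accumulation_steps            # per GPU
batch_size = per_device_train_batch_size * world_size * gradient_accumulation_steps     # across devices

assert (local_batch_size, batch_size) == (32, 64)
```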
