Commit 89ea1c5

quick change

vwxyzjn committed Dec 19, 2023
1 parent 7e1336f commit 89ea1c5
Showing 2 changed files with 81 additions and 72 deletions.
125 changes: 63 additions & 62 deletions lm_human_preference_details/summarize/reward.py
@@ -375,74 +375,75 @@ def evaluate(args, accelerator, tokenizer, model, dataloader):
             accelerator.print(f"{train_accuracy=}, {scheduler.get_last_lr()=}, {update=}")
     # if args.print_sample_output_freq > 0 and global_step % args.print_sample_output_freq == 0:

-    # evaluate_df = evaluate(args, accelerator, tokenizer, model, validation_dataloader)
-    # for split, row in evaluate_df[["split", "accuracy"]].groupby(["split"]).mean().iterrows():
-    # writer.add_scalar(f"eval/accuracy/{split}", row["accuracy"], global_step)
-    # accelerator.print(f"{split} accuracy: {row['accuracy']}")
-    # for batch, row in evaluate_df[["batch", "accuracy"]].groupby(["batch"]).mean().iterrows():
-    # writer.add_scalar(f"eval/accuracy/{batch}", row["accuracy"], global_step)
-    # accelerator.print(f"{batch} accuracy: {row['accuracy']}")
-    # for confi, row in evaluate_df[["confidence", "accuracy"]].groupby(["confidence"]).mean().iterrows():
-    # writer.add_scalar(f"eval/confidence/{confi}", row["accuracy"], global_step)
-    # accelerator.print(f"{confi} confidence: {row['accuracy']}")
-    # writer.add_scalar("eval/accuracy", evaluate_df["accuracy"].mean(), global_step)
-    # accelerator.print(f"eval accuracy: {evaluate_df['accuracy'].mean()}")
-    # if accelerator.is_main_process:
-    # os.makedirs(f"eval_tables/{run_name}", exist_ok=True)
-    # evaluate_df.to_csv(f"eval_tables/{run_name}/eval_{update}.csv")
-    # if args.track:
-    # wandb.log({"samples/query_responses": wandb.Table(dataframe=evaluate_df)}, step=update)
-    # torch.cuda.empty_cache()
-
-    # norm_dataset = load_dataset(args.task.query_dataset, split="train")
-    # norm_dataset = norm_dataset.with_format("torch", columns=["query_token", "reference_response_token"])
-    # norm_dataset = norm_dataset.shuffle(seed=local_seed)
-    # norm_dataloader = DataLoader(norm_dataset, batch_size=args.local_eval_batch_size)
-    # items = defaultdict(list)
-    # norm_dataloader = accelerator.prepare(norm_dataloader)
-    # with torch.no_grad():
-    # for data in tqdm(norm_dataloader):
-    # reference_responses = data["reference_response_token"].to(device, non_blocking=True)
-    # queries = data["query_token"].to(device, non_blocking=True)
-    # query_responses = torch.cat((queries, reference_responses), dim=1)
-    # predicted_reward = get_reward(model, query_responses, tokenizer)
-    # predicted_reward = accelerator.gather(predicted_reward)
-    # queries = accelerator.gather(queries)
-    # reference_responses = accelerator.gather(reference_responses)
-    # accelerator.print(predicted_reward.shape)
-    # for i in range(len(predicted_reward)):
-    # items["query"].append(tokenizer.decode(queries[i], skip_special_tokens=True))
-    # items["reference_response"].append(tokenizer.decode(reference_responses[i]))
-    # items["predicted_reward"].append(predicted_reward[i].item())
+    if args.run_eval:
+        evaluate_df = evaluate(args, accelerator, tokenizer, model, validation_dataloader)
+        for split, row in evaluate_df[["split", "accuracy"]].groupby(["split"]).mean().iterrows():
+            writer.add_scalar(f"eval/accuracy/{split}", row["accuracy"], global_step)
+            accelerator.print(f"{split} accuracy: {row['accuracy']}")
+        for batch, row in evaluate_df[["batch", "accuracy"]].groupby(["batch"]).mean().iterrows():
+            writer.add_scalar(f"eval/accuracy/{batch}", row["accuracy"], global_step)
+            accelerator.print(f"{batch} accuracy: {row['accuracy']}")
+        for confi, row in evaluate_df[["confidence", "accuracy"]].groupby(["confidence"]).mean().iterrows():
+            writer.add_scalar(f"eval/confidence/{confi}", row["accuracy"], global_step)
+            accelerator.print(f"{confi} confidence: {row['accuracy']}")
+        writer.add_scalar("eval/accuracy", evaluate_df["accuracy"].mean(), global_step)
+        accelerator.print(f"eval accuracy: {evaluate_df['accuracy'].mean()}")
+        if accelerator.is_main_process:
+            os.makedirs(f"eval_tables/{run_name}", exist_ok=True)
+            evaluate_df.to_csv(f"eval_tables/{run_name}/eval_{update}.csv")
+            if args.track:
+                wandb.log({"samples/query_responses": wandb.Table(dataframe=evaluate_df)}, step=update)
+        torch.cuda.empty_cache()
+
+        norm_dataset = load_dataset(args.task.query_dataset, split="train")
+        norm_dataset = norm_dataset.with_format("torch", columns=["query_token", "reference_response_token"])
+        norm_dataset = norm_dataset.shuffle(seed=local_seed)
+        norm_dataloader = DataLoader(norm_dataset, batch_size=args.local_eval_batch_size)
+        items = defaultdict(list)
+        norm_dataloader = accelerator.prepare(norm_dataloader)
+        with torch.no_grad():
+            for data in tqdm(norm_dataloader):
+                reference_responses = data["reference_response_token"].to(device, non_blocking=True)
+                queries = data["query_token"].to(device, non_blocking=True)
+                query_responses = torch.cat((queries, reference_responses), dim=1)
+                predicted_reward = get_reward(model, query_responses, tokenizer)
+                predicted_reward = accelerator.gather(predicted_reward)
+                queries = accelerator.gather(queries)
+                reference_responses = accelerator.gather(reference_responses)
+                accelerator.print(predicted_reward.shape)
+                for i in range(len(predicted_reward)):
+                    items["query"].append(tokenizer.decode(queries[i], skip_special_tokens=True))
+                    items["reference_response"].append(tokenizer.decode(reference_responses[i]))
+                    items["predicted_reward"].append(predicted_reward[i].item())

-    # if accelerator.is_main_process:
-    # norm_df = pd.DataFrame(items)
-    # os.makedirs(f"eval_tables/{run_name}", exist_ok=True)
-    # norm_df.to_csv(f"eval_tables/{run_name}/eval_{update}_normalized.csv")
-    # if args.track:
-    # wandb.log({"samples/normalized": wandb.Table(dataframe=norm_df)}, step=update)
-    # stats = {
-    # "mean": norm_df["predicted_reward"].mean(),
-    # "std": norm_df["predicted_reward"].std(),
-    # "max": norm_df["predicted_reward"].max(),
-    # "min": norm_df["predicted_reward"].min()
-    # }
-    # for stat_name, stat_value in stats.items():
-    # writer.add_scalar(f"eval/normalized_{stat_name}", stat_value, global_step)
-    # accelerator.print(f"Normalized Reward {stat_name.capitalize()}: {stat_value}")
+        if accelerator.is_main_process:
+            norm_df = pd.DataFrame(items)
+            os.makedirs(f"eval_tables/{run_name}", exist_ok=True)
+            norm_df.to_csv(f"eval_tables/{run_name}/eval_{update}_normalized.csv")
+            if args.track:
+                wandb.log({"samples/normalized": wandb.Table(dataframe=norm_df)}, step=update)
+            stats = {
+                "mean": norm_df["predicted_reward"].mean(),
+                "std": norm_df["predicted_reward"].std(),
+                "max": norm_df["predicted_reward"].max(),
+                "min": norm_df["predicted_reward"].min()
+            }
+            for stat_name, stat_value in stats.items():
+                writer.add_scalar(f"eval/normalized_{stat_name}", stat_value, global_step)
+                accelerator.print(f"Normalized Reward {stat_name.capitalize()}: {stat_value}")

     # save model
     if args.output_dir and args.num_train_epochs > 0:
         os.makedirs(os.path.dirname(args.output_dir), exist_ok=True)
-        time_tensor = torch.tensor(int(time.time()), device=device)
-        time_int = accelerator.gather(time_tensor).item() # avoid different timestamps across processes
-        repo_name = f"{args.base_model.replace('/', '_')}__{args.exp_name}__tldr__seed{args.seed}"
+        time_tensor = torch.tensor([int(time.time())], device=device)
+        time_int = accelerator.gather(time_tensor)[0].item() # avoid different timestamps across processes
+        repo_name = f"{args.base_model.replace('/', '_')}__{args.exp_name}__tldr"
         repo_id = f"{args.hf_entity}/{repo_name}" if args.hf_entity else repo_name

         if accelerator.is_main_process:
             tokenizer.save_pretrained(args.output_dir, repo_id=repo_id)
             if args.push_to_hub:
-                tokenizer.push_to_hub(repo_id, revision=str(time_int))
+                tokenizer.push_to_hub(repo_id, revision=f"seed{args.seed}_{str(time_int)}")

         unwrapped: PreTrainedModel = accelerator.unwrap_model(model)
         accelerator.wait_for_everyone()
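Note on the `time_tensor` change above: gathering a 0-dim scalar and calling `.item()` on the result is fragile once `world_size > 1`, because the gathered tensor carries one element per process. Shaping the timestamp as a 1-element tensor and indexing `[0]` makes every rank adopt rank 0's value, so all processes build identical revision strings. A minimal sketch of the pattern, assuming an initialized `Accelerator` (not the full training script):

    import time

    import torch
    from accelerate import Accelerator

    accelerator = Accelerator()
    device = accelerator.device

    # Each rank reads its own wall clock; ranks can disagree by a second or more.
    time_tensor = torch.tensor([int(time.time())], device=device)  # shape (1,)

    # gather() concatenates across ranks -> shape (world_size,). Indexing [0]
    # makes every process adopt rank 0's timestamp before calling .item().
    time_int = accelerator.gather(time_tensor)[0].item()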
@@ -456,8 +457,8 @@ def evaluate(args, accelerator, tokenizer, model, dataloader):
                 repo_id=repo_id,
             )
             if args.push_to_hub:
-                unwrapped.push_to_hub(repo_id, revision=str(time_int), safe_serialization=False)
+                unwrapped.push_to_hub(repo_id, revision=f"seed{args.seed}_{str(time_int)}", safe_serialization=False)

-if __name__ == "__main__":
-    args = tyro.cli(Args)
-    # train(args)
+# if __name__ == "__main__":
+# args = tyro.cli(Args)
+# train(args)
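Both files now keep one Hub repo per experiment and distinguish seeds by revision (`f"seed{args.seed}_{str(time_int)}"`) instead of baking the seed into the repo name. A hedged sketch of pulling one such run back down; the repo id and revision string below are illustrative placeholders, not names this commit creates:

    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained(
        "my-entity/EleutherAI_pythia-160m__sft__tldr",  # hypothetical repo id
        revision="seed1_1703001600",                    # f"seed{seed}_{time_int}"
    )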
28 changes: 18 additions & 10 deletions lm_human_preference_details/summarize/sft.py
@@ -111,15 +111,14 @@ class Args:
     """The batch size per GPU (HF's `per_device_train_batch_size` * `gradient_accumulation_steps`)"""
     batch_size: Optional[int] = None
     """The batch size across devices (HF's `per_device_train_batch_size` * `world_size` * `gradient_accumulation_steps`)"""
-    local_eval_batch_size: int = 8
+    local_eval_batch_size: int = 4
     """per rank eval batch size"""
     world_size: Optional[int] = None
     """The number of processes (GPUs) to use"""
     num_train_epochs: int = 1
     """Number of epochs to train"""
     num_updates: Optional[int] = None
     """The number of updates to train"""
-
     # other args
     base_model: str = "EleutherAI/pythia-160m"
     """the name of the pretrained model to use"""
@@ -238,7 +237,11 @@ def forward(model, query_responses, tokenizer):
     configure_dropout(model_config, args.dropout_layer_keys, 0.0) # disable dropout
     if accelerator.is_main_process:
         pprint(model_config)
-    model: PreTrainedModel = AutoModelForCausalLM.from_pretrained(args.base_model, config=model_config, trust_remote_code=True)
+    model: PreTrainedModel = AutoModelForCausalLM.from_pretrained(
+        args.base_model,
+        config=model_config,
+        trust_remote_code=True,
+    )
     model.generation_config.eos_token_id = None # disable `pad_token_id` and `eos_token_id` because we just want to
     model.generation_config.pad_token_id = None # generate tokens without truncation / padding
     if args.optimizer == "adam":
@@ -249,7 +252,7 @@
         args.scheduler,
         optimizer=optimizer,
         num_warmup_steps=args.warm_up_steps,
-        num_training_steps=args.num_updates,
+        num_training_steps=args.num_updates * args.num_train_epochs,
     )

     model, optimizer, dataloader, scheduler = accelerator.prepare(
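The `num_training_steps` fix matters whenever `num_train_epochs > 1`: the scheduler horizon must cover every optimizer step in the run, otherwise the learning rate reaches its floor at the end of the first epoch and stays there. A minimal sketch with transformers' `get_scheduler` (illustrative values, not the script's actual config):

    import torch
    from transformers import get_scheduler

    optimizer = torch.optim.AdamW([torch.nn.Parameter(torch.zeros(1))], lr=3e-4)
    num_updates, num_train_epochs = 100, 3
    scheduler = get_scheduler(
        "cosine",
        optimizer=optimizer,
        num_warmup_steps=10,
        # Span the whole run; `num_updates` alone would decay the LR to its
        # final value by the end of the first epoch.
        num_training_steps=num_updates * num_train_epochs,
    )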
Expand Down Expand Up @@ -303,6 +306,7 @@ def forward(model, query_responses, tokenizer):
writer.add_scalar("lr", scheduler.get_last_lr()[0], update)
accelerator.print(f"{loss.item()=}, {scheduler.get_last_lr()=}, {update=}")
break

if args.run_eval:
model.eval()
rouge_scores = collections.defaultdict(list)
Expand Down Expand Up @@ -345,9 +349,13 @@ def forward(model, query_responses, tokenizer):
decode_validation_queries = tokenizer.batch_decode(accelerator.gather(validation_queries))
decode_validation_query_responses = tokenizer.batch_decode(accelerator.gather(generated_responses))
decode_validation_reference_responses = tokenizer.batch_decode(
accelerator.gather(validation_reference_responses)
accelerator.gather(validation_reference_responses),
skip_special_tokens=True,
)
decode_validation_responses = tokenizer.batch_decode(
accelerator.gather(generated_responses[:, -args.task.response_length:]),
skip_special_tokens=True,
)
decode_validation_responses = tokenizer.batch_decode(accelerator.gather(generated_responses[:, -args.task.response_length:]))
rouge_score = rouge.compute(
predictions=decode_validation_responses, references=decode_validation_reference_responses
)
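The `skip_special_tokens=True` additions keep padding/EOS tokens out of the strings fed to `rouge.compute`, which would otherwise pollute the n-gram overlap. A small sketch of the difference, assuming the Pythia tokenizer used elsewhere in these scripts:

    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("EleutherAI/pythia-160m")
    tok.pad_token = tok.eos_token  # Pythia ships without a dedicated pad token

    ids = tok(["a short summary"], padding="max_length", max_length=8, return_tensors="pt")["input_ids"]
    print(tok.batch_decode(ids)[0])                            # 'a short summary<|endoftext|>...'
    print(tok.batch_decode(ids, skip_special_tokens=True)[0])  # 'a short summary'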
Expand Down Expand Up @@ -384,15 +392,15 @@ def forward(model, query_responses, tokenizer):
# save model
if args.output_dir:
os.makedirs(os.path.dirname(args.output_dir), exist_ok=True)
time_tensor = torch.tensor(int(time.time()), device=device)
time_tensor = torch.tensor([int(time.time())], device=device)
time_int = accelerator.gather(time_tensor)[0].item() # avoid different timestamps across processes
repo_name = f"{args.base_model.replace('/', '_')}__{args.exp_name}__tldr__seed{args.seed}"
repo_name = f"{args.base_model.replace('/', '_')}__{args.exp_name}__tldr"
repo_id = f"{args.hf_entity}/{repo_name}" if args.hf_entity else repo_name

if accelerator.is_main_process:
tokenizer.save_pretrained(args.output_dir, repo_id=repo_id)
if args.push_to_hub:
tokenizer.push_to_hub(repo_id, revision=str(time_int))
tokenizer.push_to_hub(repo_id, revision=f"seed{args.seed}_{str(time_int)}")

unwrapped: PreTrainedModel = accelerator.unwrap_model(model)
accelerator.wait_for_everyone()
@@ -406,7 +414,7 @@ def forward(model, query_responses, tokenizer):
                 repo_id=repo_id,
             )
             if args.push_to_hub:
-                unwrapped.push_to_hub(repo_id, revision=str(time_int), safe_serialization=False)
+                unwrapped.push_to_hub(repo_id, revision=f"seed{args.seed}_{str(time_int)}", safe_serialization=False)

 # if __name__ == "__main__":
 # args = tyro.cli(Args)