Skip to content

Commit

Permalink
support parallel reward function
Browse files Browse the repository at this point in the history
  • Loading branch information
Jingru committed Dec 6, 2023
1 parent aa85988 commit 6838580
Showing 1 changed file with 8 additions and 6 deletions.
14 changes: 8 additions & 6 deletions trlx/trainer/accelerate_base_trainer.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -419,9 +419,11 @@ def evaluate(self): # noqa: C901
if not self.config.train.reward_only_in_main_process or self.accelerator.is_main_process:
str_samples, str_prompts, str_outputs = self.decode(all_prompts, all_samples, all_prompt_sizes)

- columns_data = [str_prompts, str_outputs]
- if not self.config.train.reward_only_in_main_process:
- columns_data = self.accelerator.gather_for_metrics(columns_data)
+ if self.accelerator.is_main_process:
+ columns = ["prompt", "output"]
+ columns_data = [str_prompts, str_outputs]
+ if not self.config.train.reward_only_in_main_process:
+ columns_data = self.accelerator.gather_for_metrics(columns_data)

metadata, *xs = all_metadata
for k in metadata:
Expand All @@ -445,12 +447,12 @@ def evaluate(self): # noqa: C901
else:
rewards = torch.tensor(rewards, dtype=float)

- if not self.config.train.reward_only_in_main_process:
- rewards = self.accelerator.gather(rewards)
  if self.accelerator.is_main_process:
+ if not self.config.train.reward_only_in_main_process:
+ rewards = self.accelerator.gather(rewards)
  mean_reward = rewards.mean().item()

- columns = ["prompt", "output", "reward"]
+ columns.append("reward")
if not isinstance(rewards, list):
rewards = rewards.tolist()
columns_data.append(rewards)
Expand Down

0 comments on commit 6838580

Please sign in to comment.