From d283ee282ab78aa0b7f2ed3be4b2b069769a42a9 Mon Sep 17 00:00:00 2001 From: Jingru Date: Wed, 15 Nov 2023 05:51:25 +0000 Subject: [PATCH] feat: support parallel reward function --- trlx/trainer/accelerate_ppo_trainer.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/trlx/trainer/accelerate_ppo_trainer.py b/trlx/trainer/accelerate_ppo_trainer.py index acf5c256e..f004e0544 100644 --- a/trlx/trainer/accelerate_ppo_trainer.py +++ b/trlx/trainer/accelerate_ppo_trainer.py @@ -351,7 +351,10 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq scores = all_scores scores_mask = scores != -np.inf - str_samples, str_prompts, str_outputs = self.decode(prompt_tensors, samples, append_eos_token=True) + if self.config.train.reward_only_in_main_process: + str_samples, str_prompts, str_outputs = self.decode(prompt_tensors, samples, append_eos_token=True) + else: + str_samples, str_prompts, str_outputs = all_str_samples, all_str_prompts, all_str_outputs # Pad the sample outputs outputs = self.tokenizer(str_outputs).input_ids