diff --git a/trlx/trainer/accelerate_ppo_trainer.py b/trlx/trainer/accelerate_ppo_trainer.py index acf5c256e..f004e0544 100644 --- a/trlx/trainer/accelerate_ppo_trainer.py +++ b/trlx/trainer/accelerate_ppo_trainer.py @@ -351,7 +351,10 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq scores = all_scores scores_mask = scores != -np.inf - str_samples, str_prompts, str_outputs = self.decode(prompt_tensors, samples, append_eos_token=True) + if self.config.train.reward_only_in_main_process: + str_samples, str_prompts, str_outputs = self.decode(prompt_tensors, samples, append_eos_token=True) + else: + str_samples, str_prompts, str_outputs = all_str_samples, all_str_prompts, all_str_outputs # Pad the sample outputs outputs = self.tokenizer(str_outputs).input_ids