From bcd237f1e94c84c5c9f5a4086bab34c0946e3fa7 Mon Sep 17 00:00:00 2001 From: Max <56548574+maxreciprocate@users.noreply.github.com> Date: Wed, 11 Oct 2023 19:09:46 +0300 Subject: [PATCH] feat: add rejection finetuning trainer (#554) * feat: add rejection finetuning trainer * style: satisfy flake * fix(rft_trainer): broadcast scores to all ranks * feat(rft_trainer): dedup & clip thresholds for quantized rewards * config(rft_randomwalks): lower `total_steps`, keep 1 improve step * fix(rft_trainer): handle prompt duplicates, due to `drop_last=False` * feat(examples): add `rft_sentiments` example * style: satisfy black --- examples/randomwalks/rft_randomwalks.py | 64 ++++++++ examples/rft_sentiments.py | 96 ++++++++++++ trlx/trainer/accelerate_rft_trainer.py | 197 ++++++++++++++++++++++++ trlx/trlx.py | 2 +- trlx/utils/loading.py | 1 + 5 files changed, 359 insertions(+), 1 deletion(-) create mode 100644 examples/randomwalks/rft_randomwalks.py create mode 100644 examples/rft_sentiments.py create mode 100644 trlx/trainer/accelerate_rft_trainer.py diff --git a/examples/randomwalks/rft_randomwalks.py b/examples/randomwalks/rft_randomwalks.py new file mode 100644 index 000000000..94a6203e3 --- /dev/null +++ b/examples/randomwalks/rft_randomwalks.py @@ -0,0 +1,64 @@ +import trlx +from examples.randomwalks import generate_random_walks +from trlx.data.default_configs import ( + ModelConfig, + OptimizerConfig, + SchedulerConfig, + TokenizerConfig, + TrainConfig, + TRLConfig, +) +from trlx.trainer.accelerate_rft_trainer import RFTConfig + +default_config = TRLConfig( + train=TrainConfig( + seq_length=10, + epochs=100, + total_steps=1000, + batch_size=100, + checkpoint_interval=1000, + eval_interval=100, + pipeline="PromptPipeline", + trainer="AccelerateRFTTrainer", + checkpoint_dir="checkpoints/randomwalks", + ), + model=ModelConfig(model_path="CarperAI/randomwalks", num_layers_unfrozen=-1), + tokenizer=TokenizerConfig(tokenizer_path="CarperAI/randomwalks", truncation_side="right"), + optimizer=OptimizerConfig(name="adamw", kwargs=dict(lr=3.0e-4, betas=(0.9, 0.99), eps=1.0e-8, weight_decay=0)), + scheduler=SchedulerConfig(name="cosine_annealing", kwargs=dict(T_max=10000, eta_min=3.0e-4)), + method=RFTConfig( + name="RFTConfig", + n_generations_per_prompt=100, + start_percentile=0.9, + end_percentile=0.95, + n_improve_steps=1, + gen_kwargs=dict( + max_new_tokens=9, + top_k=0, + top_p=1.0, + temperature=1.0, + do_sample=True, + ), + ), +) + + +def main(hparams={}): + config = TRLConfig.update(default_config, hparams) + metric_fn, prompts, *_ = generate_random_walks(seed=config.train.seed) + + trlx.train( + reward_fn=lambda samples, **kwargs: metric_fn(samples)["optimality"], + prompts=prompts, + eval_prompts=prompts, + metric_fn=lambda samples, **kwargs: metric_fn(samples), + config=config, + ) + + +if __name__ == "__main__": + import json + import sys + + hparams = {} if len(sys.argv) == 1 else json.loads(sys.argv[1]) + main(hparams) diff --git a/examples/rft_sentiments.py b/examples/rft_sentiments.py new file mode 100644 index 000000000..a835055df --- /dev/null +++ b/examples/rft_sentiments.py @@ -0,0 +1,96 @@ +# This script trains a model to output positive reviews +# using rejection finetuning with a sentiment classifier reward function. +import json +import os +import sys +from typing import List + +import torch +from datasets import load_dataset +from transformers import pipeline + +import trlx +from trlx.data.default_configs import ( + ModelConfig, + OptimizerConfig, + SchedulerConfig, + TokenizerConfig, + TrainConfig, + TRLConfig, +) +from trlx.trainer.accelerate_rft_trainer import RFTConfig + + +def get_positive_score(scores): + "Extract value associated with a positive sentiment from pipeline's output" + return dict(map(lambda x: tuple(x.values()), scores))["POSITIVE"] + + +default_config = TRLConfig( + train=TrainConfig( + seq_length=1024, + epochs=100, + total_steps=1000, + batch_size=32, + checkpoint_interval=10000, + eval_interval=100, + pipeline="PromptPipeline", + trainer="AccelerateRFTTrainer", + ), + model=ModelConfig(model_path="lvwerra/gpt2-imdb", num_layers_unfrozen=-1), + tokenizer=TokenizerConfig(tokenizer_path="gpt2", truncation_side="right"), + optimizer=OptimizerConfig(name="adamw", kwargs=dict(lr=3e-5, betas=(0.9, 0.95), eps=1.0e-8, weight_decay=1.0e-6)), + scheduler=SchedulerConfig(name="cosine_annealing", kwargs=dict(T_max=1e12, eta_min=3e-5)), + method=RFTConfig( + name="RFTConfig", + n_generations_per_prompt=4, + start_percentile=0.9, + end_percentile=0.95, + n_improve_steps=1, + gen_kwargs=dict( + max_new_tokens=40, + top_k=0, + top_p=1.0, + temperature=1.0, + do_sample=True, + ), + ), +) + + +def main(hparams={}): + config = TRLConfig.update(default_config, hparams) + + if torch.cuda.is_available(): + device = int(os.environ.get("LOCAL_RANK", 0)) + else: + device = -1 + + sentiment_fn = pipeline( + "sentiment-analysis", + "lvwerra/distilbert-imdb", + top_k=2, + truncation=True, + batch_size=256, + device=device, + ) + + def reward_fn(samples: List[str], **kwargs) -> List[float]: + sentiments = list(map(get_positive_score, sentiment_fn(samples))) + return sentiments + + # Take few words off of movies reviews as prompts + imdb = load_dataset("imdb", split="train[:512]") + prompts = [" ".join(review.split()[:4]) for review in imdb["text"]] + + trlx.train( + reward_fn=reward_fn, + prompts=prompts, + eval_prompts=["I don't know much about Hungarian underground"] * 256, + config=config, + ) + + +if __name__ == "__main__": + hparams = {} if len(sys.argv) == 1 else json.loads(sys.argv[1]) + main(hparams) diff --git a/trlx/trainer/accelerate_rft_trainer.py b/trlx/trainer/accelerate_rft_trainer.py new file mode 100644 index 000000000..6dde427fc --- /dev/null +++ b/trlx/trainer/accelerate_rft_trainer.py @@ -0,0 +1,197 @@ +import itertools +from collections import defaultdict +from dataclasses import dataclass + +import numpy as np +import torch +import wandb +from tqdm import tqdm +from transformers import AutoModelForCausalLM, PretrainedConfig + +from trlx.data.configs import TRLConfig +from trlx.data.method_configs import MethodConfig, register_method +from trlx.pipeline.offline_pipeline import PromptPipeline +from trlx.trainer import register_trainer +from trlx.trainer.accelerate_base_trainer import AccelerateRLTrainer + + +@dataclass +@register_method +class RFTConfig(MethodConfig): + """ + Config for RFT training + + :param gen_kwargs: kwargs for generation + :type gen_kwargs: Dict[str, Any] + + :param start_percentile: percentile for the starting score threshold for each prompt used for the first improvement step + :type start_percentile: float + + :param end_percentile: percentile for the final score threshold for each prompt + :type end_percentile: float + + :param n_improve_steps: the number of improvement steps for each growth step with linearly increasing score threshold + :type n_improve_steps: int + + :param n_generations_per_prompt: number of generations to sample per each prompt per each growth step + :type n_generations_per_prompt: int + """ + + gen_kwargs: dict + start_percentile: float = 0.7 + end_percentile: float = 0.95 + n_improve_steps: int = 4 + n_generations_per_prompt: int = 32 + + +@register_trainer +class AccelerateRFTTrainer(AccelerateRLTrainer): + def __init__(self, config: TRLConfig, **kwargs): + super().__init__(config, **kwargs) + + self.generate_kwargs = dict( + config.method.gen_kwargs, + eos_token_id=self.tokenizer.eos_token_id, + pad_token_id=self.tokenizer.pad_token_id, + ) + + self.generate_experience_kwargs = None + + def get_arch(self, config): + from_fn = AutoModelForCausalLM.from_pretrained + if issubclass(type(config.model.model_path), PretrainedConfig): + from_fn = AutoModelForCausalLM.from_config + + model = from_fn(config.model.model_path) + + if config.model.peft_config is not None: + # Initialize the peft adapter + import peft + + peft_config = config.model.peft_config + if not isinstance(peft_config, peft.PeftConfig): + if isinstance(peft_config, dict): + peft_config = peft.get_peft_config(peft_config) + else: + raise ValueError("`peft_config` should be an instance of `peft.PeftConfig` or a dict.") + model = peft.get_peft_model(model, peft_config) + if self.accelerator.is_main_process: + model.print_trainable_parameters() + + return model + + def loss(self, batch): + labels = batch.input_ids.clone() + loss = self.model(input_ids=batch.input_ids, attention_mask=batch.attention_mask, labels=labels).loss + stats = {"loss": loss.item()} + + return loss, stats + + def create_train_dataloader(self): + return self.accelerator.prepare(self.store.create_loader(self.config.train.batch_size)) + + def prepare_learning(self): + self.epoch_count = 0 + self.iter_count = 0 + self.n_inner_epochs = 1 + # because of variable number of samples per each improvement steps + # there is no way to get the estimate, so here it's just copied from the config + self.total_steps = self.config.train.total_steps + + self.generations_per_prompt = defaultdict(list) + + eval_dataloader = self.eval_pipeline.create_loader(self.config.train.batch_size) + self.model, self.opt, self.eval_dataloader = self.accelerator.prepare(self.model, self.opt, eval_dataloader) + + self.make_experience() + + def add_prompt_pipeline(self, pipeline: PromptPipeline): + """Add a prompt pipeline dataloader to a trainer instance for the `make_experience` stage""" + prompt_dataloader = pipeline.create_loader(self.config.train.batch_size) + self.prompt_dataloader = self.accelerator.prepare_data_loader(prompt_dataloader) + + def post_epoch_callback(self): + self.make_experience() + self.epoch_count += 1 + + def make_experience(self): # noqa: + if self.epoch_count % self.config.method.n_improve_steps == 0: + # generate n samples for each prompt in the prompt_dataloader + generations = [] + for batch in tqdm(self.prompt_dataloader, desc="Generating", disable=not self.accelerator.is_main_process): + for _ in range(self.config.method.n_generations_per_prompt): + samples = self.generate(**batch) + str_samples, str_prompts, str_outputs = self.decode(batch.input_ids, samples, append_eos_token=True) + generations.extend({"prompt": p, "output": o} for p, o in zip(str_prompts, str_outputs)) + + if torch.distributed.is_initialized(): + all_generations = [None for _ in range(torch.distributed.get_world_size())] + torch.distributed.all_gather_object(all_generations, generations) + generations = list(itertools.chain(*all_generations)) + + # score the generations + if self.accelerator.is_main_process: + all_scores = self.reward_fn( + samples=[x["prompt"] + x["output"] for x in generations], + prompts=[x["prompt"] for x in generations], + outputs=[x["output"] for x in generations], + ) + + all_scores = torch.tensor(all_scores, device=self.accelerator.device) + else: + all_scores = torch.zeros(len(generations), device=self.accelerator.device) + if torch.distributed.is_initialized(): + torch.distributed.broadcast(all_scores, src=0) + scores = all_scores + else: + scores = all_scores + + for g, s in zip(generations, scores): + self.generations_per_prompt[g["prompt"]].append({"output": g["output"], "score": s.item()}) + + scores = [[x["score"] for x in self.generations_per_prompt[p]] for p in self.generations_per_prompt] + + percentile_delta = ( + self.config.method.end_percentile - self.config.method.start_percentile + ) / self.config.method.n_improve_steps + percentile = self.config.method.start_percentile + percentile_delta * ( + self.epoch_count % self.config.method.n_improve_steps + ) + thresholds = np.array([np.quantile(np.array(scores), percentile) for scores in scores]) + # corner case for quantized rewards: don't include the min values, but don't exclude the max values + thresholds = np.clip(thresholds, thresholds.min() + 1e-3, thresholds.max() - 1e-3) + + # filter out the generations with a score below the percentile per prompt + samples_selected = [] + for prompt, threshold in zip(self.generations_per_prompt, thresholds): + for x in self.generations_per_prompt[prompt]: + if x["score"] >= threshold: + samples_selected.append([prompt, x["output"]]) + + # deduplicate the samples + samples_selected = list({tuple(x) for x in samples_selected}) + + self.accelerator.log( + { + "scores_per_single_prompt": wandb.Histogram(scores[0]), + "thresholds": wandb.Histogram(thresholds), + "scores_mean": np.mean(np.hstack(scores)), + "scores_dist": wandb.Histogram(np.hstack(scores)), + "len_samples_selected": len(samples_selected), + "samples_per_single_prompt": wandb.Table( + data=list( + zip( + [x[0] for x in samples_selected[:128]], + [x[1] for x in samples_selected[:128]], + ) + ), + columns=["prompt", "output"], + ), + }, + step=self.iter_count, + ) + + if len(samples_selected): + self.store = PromptPipeline( + samples_selected, max_prompt_length=2048, tokenizer=self.tokenizer, add_special_tokens=True + ) diff --git a/trlx/trlx.py b/trlx/trlx.py index 13ee5daaa..d724a9f24 100644 --- a/trlx/trlx.py +++ b/trlx/trlx.py @@ -87,7 +87,7 @@ def train( # noqa: C901 batch_size = config.train.batch_size * int(os.environ.get("WORLD_SIZE", 1)) max_prompt_length = config.train.seq_length - config.method.gen_kwargs["max_new_tokens"] - # Online training against a reward function (e.g. PPO) + # Online training against a reward function (e.g. PPO, RFT) if reward_fn: prompts = prompts or [trainer.tokenizer.bos_token] * batch_size diff --git a/trlx/utils/loading.py b/trlx/utils/loading.py index 9c7dccf76..97f1c0534 100644 --- a/trlx/utils/loading.py +++ b/trlx/utils/loading.py @@ -8,6 +8,7 @@ from trlx.trainer import _TRAINERS, register_trainer from trlx.trainer.accelerate_ilql_trainer import AccelerateILQLTrainer from trlx.trainer.accelerate_ppo_trainer import AcceleratePPOTrainer +from trlx.trainer.accelerate_rft_trainer import AccelerateRFTTrainer from trlx.trainer.accelerate_sft_trainer import AccelerateSFTTrainer try: