diff --git a/configs/ilql_config.yml b/configs/ilql_config.yml
index 9bfe83563..207090300 100644
--- a/configs/ilql_config.yml
+++ b/configs/ilql_config.yml
@@ -1,7 +1,7 @@
 model:
   model_path: "gpt2"
   tokenizer_path: "gpt2"
-  model_type: "ILQLModel"
+  model_type: "AccelerateILQLModel"
   num_layers_unfrozen: -1

 train:
@@ -19,7 +19,7 @@ train:
   checkpoint_interval: 1000
   eval_interval: 128

-  pipeline: "OfflinePipeline"
+  pipeline: "PromptPipeline"
   orchestrator: "OfflineOrchestrator"
   seed: 1000

diff --git a/configs/ppo_config.yml b/configs/ppo_config.yml
index c526f564e..ea7d5c581 100644
--- a/configs/ppo_config.yml
+++ b/configs/ppo_config.yml
@@ -19,8 +19,9 @@ train:
   checkpoint_interval: 10000 # checkpoint interval
   eval_interval: 16 # eval interval

-  pipeline: "PPOPipeline" # prompt pipeline to load
+  pipeline: "PromptPipeline" # prompt pipeline to load
   orchestrator: "PPOOrchestrator" # orchestrator to load
+  entity_name: "jon-tow"

 method:
   name: 'ppoconfig' # Name of RL method config
diff --git a/configs/ppo_gptj.yml b/configs/ppo_gptj.yml
index bd822f61d..ebd92b3a8 100644
--- a/configs/ppo_gptj.yml
+++ b/configs/ppo_gptj.yml
@@ -19,7 +19,7 @@ train:
   checkpoint_interval: 1000000 # checkpoint interval
   eval_interval: 16 # eval interval

-  pipeline: "PPOPipeline" # prompt pipeline to load
+  pipeline: "PromptPipeline" # prompt pipeline to load
   orchestrator: "PPOOrchestrator" # orchestrator to load

 method:
diff --git a/configs/ray_tune_configs/ppo_config.yml b/configs/ray_tune_configs/ppo_config.yml
index 8cc3d0fb9..0afabeee5 100644
--- a/configs/ray_tune_configs/ppo_config.yml
+++ b/configs/ray_tune_configs/ppo_config.yml
@@ -35,7 +35,7 @@ train:
   checkpoint_interval: 10000 # checkpoint interval
   eval_interval: 4 # eval interval

-  pipeline: "PPOPipeline" # prompt pipeline to load
+  pipeline: "PromptPipeline" # prompt pipeline to load
   orchestrator: "PPOOrchestrator" # orchestrator to load

   project_name: "trlx-hyperopt-bohb"
diff --git a/configs/test_config.yml b/configs/test_config.yml
index abf9e6df1..05e9f9576 100644
--- a/configs/test_config.yml
+++ b/configs/test_config.yml
@@ -19,7 +19,7 @@ train:
   checkpoint_interval: 10000 # checkpoint interval
   eval_interval: 128 # eval interval

-  pipeline: "PPOPipeline" # prompt pipeline to load
+  pipeline: "PromptPipeline" # prompt pipeline to load
   orchestrator: "PPOOrchestrator" # orchestrator to load

 method:
diff --git a/examples/experiments/grounded_program_synthesis/config/trlx_ppo_config.yml b/examples/experiments/grounded_program_synthesis/config/trlx_ppo_config.yml
index c76aacef0..f9e45cd7c 100644
--- a/examples/experiments/grounded_program_synthesis/config/trlx_ppo_config.yml
+++ b/examples/experiments/grounded_program_synthesis/config/trlx_ppo_config.yml
@@ -20,7 +20,7 @@ train:
   checkpoint_interval: 1000000 # checkpoint interval
   eval_interval: 16 # eval interval

-  pipeline: "PPOPipeline" # prompt pipeline to load
+  pipeline: "PromptPipeline" # prompt pipeline to load
   orchestrator: "PPOOrchestrator" # orchestrator to load

 method:
diff --git a/trlx/trlx.py b/trlx/trlx.py
index 9fa6bcdb9..c98a527ec 100644
--- a/trlx/trlx.py
+++ b/trlx/trlx.py
@@ -1,20 +1,8 @@
 import os
 from typing import Callable, Iterable, List, Optional, Tuple

-from accelerate import Accelerator
-
 from trlx.data.configs import TRLConfig
-
-from trlx.model.accelerate_ilql_model import AccelerateILQLModel
-from trlx.model.accelerate_ppo_model import AcceleratePPOModel
-
-from trlx.orchestrator.offline_orchestrator import OfflineOrchestrator
-from trlx.orchestrator.ppo_orchestrator import PPOOrchestrator
-
-from trlx.pipeline.offline_pipeline import PromptPipeline
-from trlx.utils.loading import get_model, get_orchestrator
-
-import ray
+from trlx.utils.loading import get_model, get_orchestrator, get_pipeline


 def train(
@@ -38,6 +26,7 @@ def train(
         prompts (List[str]): Prompts to sample off from during online training
         eval_prompts (List[str]): Prompts to periodically validate training on
         metric_fn (Optional[Callable[List[str], List[float]]]): Function to compute statistics on validation samples
+        config (Optional[TRLConfig]): TRL configuration object to override default settings
         split_token (Optional[str]): Split samples in the dataset on prompts and continuations
         logit_mask (Optional[List]): Bigram masking matrix
    """
@@ -49,7 +38,7 @@ def train(
         if model_path:
             config.model.model_path = model_path

-        model: AcceleratePPOModel = get_model(config.model.model_type)(config)
+        model = get_model(config.model.model_type)(config)

         batch_size = config.train.batch_size * int(os.environ.get("WORLD_SIZE", 1))
         prompts = prompts or [model.tokenizer.bos_token] * batch_size
@@ -57,12 +46,14 @@ def train(
         if eval_prompts is None:
             eval_prompts = prompts[:batch_size]

-        pipeline = PromptPipeline(prompts, model.tokenizer)
-        orch: PPOOrchestrator = get_orchestrator(config.train.orchestrator)(
+        pipeline = get_pipeline(config.train.pipeline)(prompts, model.tokenizer)
+        orch = get_orchestrator(config.train.orchestrator)(
             model, pipeline, reward_fn=reward_fn, chunk_size=config.method.chunk_size
         )
         orch.make_experience(config.method.num_rollouts)
-        eval_pipeline = PromptPipeline(eval_prompts, model.tokenizer)
+        eval_pipeline = get_pipeline(config.train.pipeline)(
+            eval_prompts, model.tokenizer
+        )
         model.add_eval_pipeline(eval_pipeline)

     elif dataset is not None:
@@ -79,7 +70,7 @@ def train(
         if model_path:
             config.model.model_path = model_path

-        model = AccelerateILQLModel(
+        model = get_model(config.model.model_type)(
             config=config,
             logit_mask=logit_mask,
             metric_fn=metric_fn,
@@ -88,10 +79,13 @@ def train(
         batch_size = config.train.batch_size * int(os.environ.get("WORLD_SIZE", 1))
         if eval_prompts is None:
             eval_prompts = [model.tokenizer.bos_token] * batch_size
+        eval_pipeline = get_pipeline(config.train.pipeline)(
+            eval_prompts, model.tokenizer
+        )

-        eval_pipeline = PromptPipeline(eval_prompts, model.tokenizer)
-
-        orch = OfflineOrchestrator(model, split_token=split_token)
+        orch = get_orchestrator(config.train.orchestrator)(
+            model, split_token=split_token
+        )
         orch.make_experience(samples, rewards)
         model.add_eval_pipeline(eval_pipeline)

diff --git a/trlx/utils/loading.py b/trlx/utils/loading.py
index 9fe7aa64c..4b603c80f 100644
--- a/trlx/utils/loading.py
+++ b/trlx/utils/loading.py
@@ -1,8 +1,18 @@
 from typing import Callable

+# Register load models via module import
 from trlx.model import _MODELS
+from trlx.model.accelerate_ilql_model import AccelerateILQLModel
+from trlx.model.accelerate_ppo_model import AcceleratePPOModel
+
+# Register load orchestrators via module import
 from trlx.orchestrator import _ORCH
+from trlx.orchestrator.offline_orchestrator import OfflineOrchestrator
+from trlx.orchestrator.ppo_orchestrator import PPOOrchestrator
+
+# Register load pipelines via module import
 from trlx.pipeline import _DATAPIPELINE
+from trlx.pipeline.offline_pipeline import PromptPipeline


 def get_model(name: str) -> Callable:
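Note on the loading changes above: get_model, get_orchestrator, and the new get_pipeline resolve the strings in the config (e.g. pipeline: "PromptPipeline") to classes that register themselves in the _MODELS, _ORCH, and _DATAPIPELINE dicts when their modules are imported, which is why trlx/utils/loading.py now imports the concrete classes as a side effect. The sketch below illustrates that registry pattern in isolation; it is not trlx's exact implementation, and the register_pipeline decorator, the _PIPELINES dict, and the toy PromptPipeline class are assumptions made for this example only.

from typing import Callable, Dict, Type

# Illustrative registry dict; trlx keeps analogous dicts (_MODELS, _ORCH,
# _DATAPIPELINE) that are filled in as a side effect of importing modules.
_PIPELINES: Dict[str, Type] = {}


def register_pipeline(cls: Type) -> Type:
    """Record a pipeline class under its lowercased name and return it unchanged."""
    _PIPELINES[cls.__name__.lower()] = cls
    return cls


@register_pipeline
class PromptPipeline:
    # Toy stand-in for a prompt pipeline: just stores prompts and a tokenizer.
    def __init__(self, prompts, tokenizer=None):
        self.prompts = prompts
        self.tokenizer = tokenizer


def get_pipeline(name: str) -> Callable:
    """Resolve a config string such as "PromptPipeline" to the registered class."""
    try:
        return _PIPELINES[name.lower()]
    except KeyError:
        raise ValueError(f"Unknown pipeline: {name}") from None


# Mirrors the call pattern in trlx.py: look the class up by the config value,
# then instantiate it with prompts and a tokenizer.
if __name__ == "__main__":
    pipeline_cls = get_pipeline("PromptPipeline")
    pipeline = pipeline_cls(["Hello world"], tokenizer=None)
    print(type(pipeline).__name__)  # -> PromptPipeline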