Skip to content

Commit

Permalink
Force class registry via imports (#100)
Browse files Browse the repository at this point in the history
  • Loading branch information
jon-tow authored Nov 20, 2022
1 parent 12598e9 commit aafcae9
Show file tree
Hide file tree
Showing 8 changed files with 33 additions and 28 deletions.
4 changes: 2 additions & 2 deletions configs/ilql_config.yml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
model:
model_path: "gpt2"
tokenizer_path: "gpt2"
model_type: "ILQLModel"
model_type: "AccelerateILQLModel"
num_layers_unfrozen: -1

train:
Expand All @@ -19,7 +19,7 @@ train:
checkpoint_interval: 1000
eval_interval: 128

pipeline: "OfflinePipeline"
pipeline: "PromptPipeline"
orchestrator: "OfflineOrchestrator"
seed: 1000

Expand Down
3 changes: 2 additions & 1 deletion configs/ppo_config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,9 @@ train:
checkpoint_interval: 10000 # checkpoint interval
eval_interval: 16 # eval interval

pipeline: "PPOPipeline" # prompt pipeline to load
pipeline: "PromptPipeline" # prompt pipeline to load
orchestrator: "PPOOrchestrator" # orchestrator to load
entity_name: "jon-tow"

method:
name: 'ppoconfig' # Name of RL method config
Expand Down
2 changes: 1 addition & 1 deletion configs/ppo_gptj.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ train:
checkpoint_interval: 1000000 # checkpoint interval
eval_interval: 16 # eval interval

pipeline: "PPOPipeline" # prompt pipeline to load
pipeline: "PromptPipeline" # prompt pipeline to load
orchestrator: "PPOOrchestrator" # orchestrator to load

method:
Expand Down
2 changes: 1 addition & 1 deletion configs/ray_tune_configs/ppo_config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ train:
checkpoint_interval: 10000 # checkpoint interval
eval_interval: 4 # eval interval

pipeline: "PPOPipeline" # prompt pipeline to load
pipeline: "PromptPipeline" # prompt pipeline to load
orchestrator: "PPOOrchestrator" # orchestrator to load
project_name: "trlx-hyperopt-bohb"

Expand Down
2 changes: 1 addition & 1 deletion configs/test_config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ train:
checkpoint_interval: 10000 # checkpoint interval
eval_interval: 128 # eval interval

pipeline: "PPOPipeline" # prompt pipeline to load
pipeline: "PromptPipeline" # prompt pipeline to load
orchestrator: "PPOOrchestrator" # orchestrator to load

method:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ train:
checkpoint_interval: 1000000 # checkpoint interval
eval_interval: 16 # eval interval

pipeline: "PPOPipeline" # prompt pipeline to load
pipeline: "PromptPipeline" # prompt pipeline to load
orchestrator: "PPOOrchestrator" # orchestrator to load

method:
Expand Down
36 changes: 15 additions & 21 deletions trlx/trlx.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,8 @@
import os
from typing import Callable, Iterable, List, Optional, Tuple

from accelerate import Accelerator

from trlx.data.configs import TRLConfig

from trlx.model.accelerate_ilql_model import AccelerateILQLModel
from trlx.model.accelerate_ppo_model import AcceleratePPOModel

from trlx.orchestrator.offline_orchestrator import OfflineOrchestrator
from trlx.orchestrator.ppo_orchestrator import PPOOrchestrator

from trlx.pipeline.offline_pipeline import PromptPipeline
from trlx.utils.loading import get_model, get_orchestrator

import ray
from trlx.utils.loading import get_model, get_orchestrator, get_pipeline


def train(
Expand All @@ -38,6 +26,7 @@ def train(
prompts (List[str]): Prompts to sample from during online training
eval_prompts (List[str]): Prompts to periodically validate training on
metric_fn (Optional[Callable[[List[str]], List[float]]]): Function to compute statistics on validation samples
config (Optional[TRLConfig]): TRL configuration object to override default settings
split_token (Optional[str]): Split samples in the dataset on prompts and continuations
logit_mask (Optional[List]): Bigram masking matrix
"""
Expand All @@ -49,20 +38,22 @@ def train(
if model_path:
config.model.model_path = model_path

model: AcceleratePPOModel = get_model(config.model.model_type)(config)
model = get_model(config.model.model_type)(config)

batch_size = config.train.batch_size * int(os.environ.get("WORLD_SIZE", 1))
prompts = prompts or [model.tokenizer.bos_token] * batch_size

if eval_prompts is None:
eval_prompts = prompts[:batch_size]

pipeline = PromptPipeline(prompts, model.tokenizer)
orch: PPOOrchestrator = get_orchestrator(config.train.orchestrator)(
pipeline = get_pipeline(config.train.pipeline)(prompts, model.tokenizer)
orch = get_orchestrator(config.train.orchestrator)(
model, pipeline, reward_fn=reward_fn, chunk_size=config.method.chunk_size
)
orch.make_experience(config.method.num_rollouts)
eval_pipeline = PromptPipeline(eval_prompts, model.tokenizer)
eval_pipeline = get_pipeline(config.train.pipeline)(
eval_prompts, model.tokenizer
)
model.add_eval_pipeline(eval_pipeline)

elif dataset is not None:
Expand All @@ -79,7 +70,7 @@ def train(
if model_path:
config.model.model_path = model_path

model = AccelerateILQLModel(
model = get_model(config.model.model_type)(
config=config,
logit_mask=logit_mask,
metric_fn=metric_fn,
Expand All @@ -88,10 +79,13 @@ def train(
batch_size = config.train.batch_size * int(os.environ.get("WORLD_SIZE", 1))
if eval_prompts is None:
eval_prompts = [model.tokenizer.bos_token] * batch_size
eval_pipeline = get_pipeline(config.train.pipeline)(
eval_prompts, model.tokenizer
)

eval_pipeline = PromptPipeline(eval_prompts, model.tokenizer)

orch = OfflineOrchestrator(model, split_token=split_token)
orch = get_orchestrator(config.train.orchestrator)(
model, split_token=split_token
)
orch.make_experience(samples, rewards)
model.add_eval_pipeline(eval_pipeline)

Expand Down
10 changes: 10 additions & 0 deletions trlx/utils/loading.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,18 @@
from typing import Callable

# Register loadable models via module import
from trlx.model import _MODELS
from trlx.model.accelerate_ilql_model import AccelerateILQLModel
from trlx.model.accelerate_ppo_model import AcceleratePPOModel

# Register loadable orchestrators via module import
from trlx.orchestrator import _ORCH
from trlx.orchestrator.offline_orchestrator import OfflineOrchestrator
from trlx.orchestrator.ppo_orchestrator import PPOOrchestrator

# Register loadable pipelines via module import
from trlx.pipeline import _DATAPIPELINE
from trlx.pipeline.offline_pipeline import PromptPipeline


def get_model(name: str) -> Callable:
Expand Down

0 comments on commit aafcae9

Please sign in to comment.