From ce4840dc5aa80a5c907bf6a30c3baccce5a5d786 Mon Sep 17 00:00:00 2001
From: takuseno
Date: Sun, 7 Jan 2024 11:04:49 +0900
Subject: [PATCH] Add separator token

---
 d3rlpy/algos/gato/base.py            | 22 ++++++++++++++++++--
 d3rlpy/algos/gato/dataset.py         | 31 +++++++++++++++++++++-------
 d3rlpy/algos/gato/gato.py            | 10 ++++++++-
 d3rlpy/algos/gato/torch/gato_impl.py | 14 +++++++++++--
 d3rlpy/models/torch/embeddings.py    | 28 ++++++++++++++++++++++++-
 reproductions/offline/gato.py        |  2 +-
 6 files changed, 93 insertions(+), 14 deletions(-)

diff --git a/d3rlpy/algos/gato/base.py b/d3rlpy/algos/gato/base.py
index 49705aa8..6eb548d4 100644
--- a/d3rlpy/algos/gato/base.py
+++ b/d3rlpy/algos/gato/base.py
@@ -27,7 +27,7 @@
 )
 from ...metrics import evaluate_gato_with_environment
 from ...models import EmbeddingModuleFactory, TokenEmbeddingFactory
-from ...models.torch import TokenEmbedding
+from ...models.torch import SeparatorTokenEmbedding, TokenEmbedding
 from ...serializable_config import generate_dict_config_field
 from ...torch_utility import eval_api, train_api
 from ...types import GymEnv, NDArray, Observation
@@ -73,6 +73,11 @@ def inner_update(
     def token_embeddings(self) -> Dict[str, TokenEmbedding]:
         pass
 
+    @property
+    @abstractmethod
+    def separator_token_embedding(self) -> SeparatorTokenEmbedding:
+        pass
+
 
 @dataclasses.dataclass()
 class GatoBaseConfig(LearnableConfig):
@@ -154,6 +159,8 @@ def predict(self, x: Observation) -> Union[NDArray, int]:
                 self._append_observation_embedding(embedding[i], position)
             position += 1
 
+        self._append_separator_embedding()
+
         action_token_embedding = token_embeddings[self._action_embedding_key]
         action_values = []
         for i in range(self._action_token_size):
@@ -213,6 +220,13 @@ def _append_action_embedding(self, embedding: torch.Tensor) -> None:
         self._observation_masks.append(0)
         self._action_masks.append(1)
 
+    def _append_separator_embedding(self) -> None:
+        assert self._algo.impl
+        self._embeddings.append(self._algo.impl.separator_token_embedding.data)
+        self._observation_positions.append(0)
+        self._observation_masks.append(0)
+        self._action_masks.append(0)
+
     def reset(self) -> None:
         """Clears stateful information."""
         self._embeddings.clear()
@@ -325,7 +339,11 @@ def fit(
             self._impl.wrap_models_by_ddp()
 
         # create GatoReplayBuffer
-        replay_buffer = GatoReplayBuffer(datasets, self._impl.token_embeddings)
+        replay_buffer = GatoReplayBuffer(
+            replay_buffers=datasets,
+            token_embeddings=self._impl.token_embeddings,
+            separator_token_embedding=self._impl.separator_token_embedding,
+        )
 
         # save hyperparameters
         save_config(self, logger)
diff --git a/d3rlpy/algos/gato/dataset.py b/d3rlpy/algos/gato/dataset.py
index 86184088..50b4437b 100644
--- a/d3rlpy/algos/gato/dataset.py
+++ b/d3rlpy/algos/gato/dataset.py
@@ -11,7 +11,7 @@
     ReplayBufferBase,
     slice_observations,
 )
-from ...models.torch import TokenEmbedding
+from ...models.torch import SeparatorTokenEmbedding, TokenEmbedding
 from ...torch_utility import torch_batch_pad_array
 from ...types import NDArray, Observation
 
@@ -42,7 +42,7 @@ def from_episode(
         token_embeddings: Dict[str, TokenEmbedding],
         task_id: str,
     ) -> "GatoTokenEpisode":
-        block_size = 0
+        block_size = 1  # +1 for separator token
 
         observations: List[NDArray]
         if isinstance(observation_to_embedding_keys, str):
@@ -127,7 +127,7 @@ def create_embedding_metadata(
         [
             observation_positions,
             torch.zeros(
-                (num_steps, action_token_size),
+                (num_steps, action_token_size + 1),  # +1 for separator
                 device=device,
                 dtype=torch.int32,
             ),
@@ -143,7 +143,7 @@ def create_embedding_metadata(
                 dtype=torch.float32,
             ),
             torch.zeros(
-                (num_steps, action_token_size, 1),
+                (num_steps, action_token_size + 1, 1),  # +1 for separator
                 device=device,
                 dtype=torch.float32,
             ),
@@ -154,7 +154,7 @@ def create_embedding_metadata(
     action_masks = torch.cat(
         [
             torch.zeros(
-                (num_steps, observation_token_size, 1),
+                (num_steps, observation_token_size + 1, 1),  # +1 for separator
                 device=device,
                 dtype=torch.float32,
             ),
@@ -199,6 +199,7 @@ def __call__(
         end_step: int,
         token_size: int,
         token_embeddings: Dict[str, TokenEmbedding],
+        separator_token_embedding: SeparatorTokenEmbedding,
     ) -> GatoTrainingInputEmbedding:
         num_steps = token_size // episode.one_step_block_size
         end_step = end_step + 1
@@ -230,11 +231,20 @@ def __call__(
             dtype=torch.int32,
         )
 
+        # create separator token
+        # (T, 1, N)
+        separator_embeddings = separator_token_embedding(action_embedding)
+
         # concat observations and actions
-        # S = total_obs_num_tokens + action_num_tokens
+        # S = total_obs_num_tokens + 1 + action_num_tokens (+1 for separator)
         # (T, S, N)
         concat_embeddings = torch.cat(
-            [concat_observation_embeddings, action_embedding], dim=1
+            [
+                concat_observation_embeddings,
+                separator_embeddings,
+                action_embedding,
+            ],
+            dim=1,
         )
         metadata = create_embedding_metadata(
             observation_embeddings=concat_observation_embeddings,
@@ -244,7 +254,10 @@
         action_tokens = torch.cat(
             [
                 torch.zeros(
-                    (actual_num_steps, observation_token_size),
+                    (
+                        actual_num_steps,
+                        observation_token_size + 1,
+                    ),  # +1 for separator
                     device=device,
                     dtype=torch.int32,
                 ),
@@ -348,14 +361,17 @@ class GatoReplayBuffer:
     _episodes_per_task: DefaultDict[str, List[GatoTokenEpisode]]
     _token_slicer: GatoTokenSlicer
     _token_embeddings: Dict[str, TokenEmbedding]
+    _separator_token_embedding: SeparatorTokenEmbedding
 
     def __init__(
         self,
         replay_buffers: Sequence[ReplayBufferWithEmbeddingKeys],
         token_embeddings: Dict[str, TokenEmbedding],
+        separator_token_embedding: SeparatorTokenEmbedding,
     ):
         self._token_slicer = GatoTokenSlicer()
         self._token_embeddings = token_embeddings
+        self._separator_token_embedding = separator_token_embedding
         self._episodes = []
         self._episodes_per_task = defaultdict(list)
         for replay_buffer in replay_buffers:
@@ -383,6 +399,7 @@ def sample_embedding_sequence(
             end_step=end_step,
             token_size=length,
             token_embeddings=self._token_embeddings,
+            separator_token_embedding=self._separator_token_embedding,
         )
 
     def sample_embedding_mini_batch(
diff --git a/d3rlpy/algos/gato/gato.py b/d3rlpy/algos/gato/gato.py
index 1c42ecee..4503e9b4 100644
--- a/d3rlpy/algos/gato/gato.py
+++ b/d3rlpy/algos/gato/gato.py
@@ -6,6 +6,7 @@
 from ...constants import ActionSpace
 from ...models.builders import create_gato_transformer
 from ...models.optimizers import OptimizerFactory, make_optimizer_field
+from ...models.torch import SeparatorTokenEmbedding
 from ...types import Shape
 from .base import GatoAlgoBase, GatoBaseConfig
 from .torch import GatoImpl, GatoModules
@@ -74,15 +75,22 @@ def inner_create_impl(
             for key, factory in self._config.token_embeddings.items()
         }
 
+        # create separator token embedding
+        separator_token_embedding = SeparatorTokenEmbedding(
+            self._config.layer_width
+        )
+
         optim = self._config.optim_factory.create(
             list(transformer.named_modules())
-            + list(embedding_modules.named_modules()),
+            + list(embedding_modules.named_modules())
+            + list(separator_token_embedding.named_modules()),
             lr=self._config.initial_learning_rate,
         )
 
         modules = GatoModules(
             transformer=transformer,
             embedding_modules=embedding_modules,
+            separator_token_embedding=separator_token_embedding,
             optim=optim,
         )
 
diff --git a/d3rlpy/algos/gato/torch/gato_impl.py b/d3rlpy/algos/gato/torch/gato_impl.py
index 0bb57a0a..62d325d5 100644
--- a/d3rlpy/algos/gato/torch/gato_impl.py
+++ b/d3rlpy/algos/gato/torch/gato_impl.py
@@ -8,7 +8,11 @@
 from torch import nn
 from torch.optim import Optimizer
 
-from ....models.torch import GatoTransformer, TokenEmbedding
+from ....models.torch import (
+    GatoTransformer,
+    SeparatorTokenEmbedding,
+    TokenEmbedding,
+)
 from ....torch_utility import Modules
 from ..base import GatoAlgoImplBase
 from ..dataset import GatoEmbeddingMiniBatch, GatoInputEmbedding
@@ -20,6 +24,7 @@
 class GatoModules(Modules):
     transformer: GatoTransformer
     embedding_modules: nn.ModuleDict
+    separator_token_embedding: SeparatorTokenEmbedding
     optim: Optimizer
 
 
@@ -74,7 +79,8 @@ def inner_update(
 
         torch.nn.utils.clip_grad_norm_(
             list(self._modules.transformer.parameters())
-            + list(self._modules.embedding_modules.parameters()),
+            + list(self._modules.embedding_modules.parameters())
+            + list(self._modules.separator_token_embedding.parameters()),
             self._clip_grad_norm,
         )
 
@@ -118,3 +124,7 @@ def compute_loss(self, batch: GatoEmbeddingMiniBatch) -> torch.Tensor:
     @property
     def token_embeddings(self) -> Dict[str, TokenEmbedding]:
         return self._token_embeddings
+
+    @property
+    def separator_token_embedding(self) -> SeparatorTokenEmbedding:
+        return self._modules.separator_token_embedding
diff --git a/d3rlpy/models/torch/embeddings.py b/d3rlpy/models/torch/embeddings.py
index 131f11bf..d74e4673 100644
--- a/d3rlpy/models/torch/embeddings.py
+++ b/d3rlpy/models/torch/embeddings.py
@@ -4,8 +4,13 @@
 
 from ...tokenizers import Tokenizer
 from ...types import Int32NDArray, NDArray
+from .parameters import Parameter
 
-__all__ = ["TokenEmbedding", "TokenEmbeddingWithTokenizer"]
+__all__ = [
+    "TokenEmbedding",
+    "TokenEmbeddingWithTokenizer",
+    "SeparatorTokenEmbedding",
+]
 
 
 class TokenEmbedding(nn.Module):  # type: ignore
@@ -46,3 +51,24 @@ def get_tokens(self, x: NDArray) -> Int32NDArray:
 
     def decode(self, x: Int32NDArray) -> NDArray:
         return self._tokenizer.decode(x)
+
+
+class SeparatorTokenEmbedding(nn.Module):  # type: ignore
+    _data: Parameter
+
+    def __init__(self, embed_size: int):
+        super().__init__()
+        self._data = Parameter(torch.zeros(embed_size, dtype=torch.float32))
+
+    def __call__(self, x: torch.Tensor) -> torch.Tensor:
+        return super().__call__(x)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        assert x.ndim == 3
+        assert x.shape[-1] == self._data.data.shape[0]
+        data = self._data.data.view(1, 1, -1)
+        return torch.tile(data, [x.shape[0], 1, 1])
+
+    @property
+    def data(self) -> torch.Tensor:
+        return self._data.data
diff --git a/reproductions/offline/gato.py b/reproductions/offline/gato.py
index 0edc3df2..1ee23454 100644
--- a/reproductions/offline/gato.py
+++ b/reproductions/offline/gato.py
@@ -50,7 +50,7 @@ def main() -> None:
         maximum_learning_rate=1e-4,
         warmup_steps=15000,
         final_steps=100000,
-        optim_factory=d3rlpy.models.AdamWFactory(
+        optim_factory=d3rlpy.models.GPTAdamWFactory(
             weight_decay=0.1, betas=(0.9, 0.95)
         ),
         action_vocab_size=1024,
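
Note: for readers skimming the diff, here is a minimal self-contained sketch of what the new SeparatorTokenEmbedding does. The class name with the "Sketch" suffix is hypothetical, and it uses a plain torch.nn.Parameter in place of d3rlpy's internal Parameter wrapper; shapes follow the (T, S, N) convention used in the patch.

import torch
from torch import nn


class SeparatorTokenEmbeddingSketch(nn.Module):
    """Standalone sketch of the patch's SeparatorTokenEmbedding.

    Holds a single learnable embedding vector and, given a batch of token
    embeddings of shape (T, S, N), returns one separator embedding per step
    with shape (T, 1, N) so it can be concatenated between the observation
    tokens and the action tokens.
    """

    def __init__(self, embed_size: int):
        super().__init__()
        # plain nn.Parameter here; the patch uses d3rlpy's Parameter wrapper
        self._data = nn.Parameter(torch.zeros(embed_size, dtype=torch.float32))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        assert x.ndim == 3  # (T, S, N)
        assert x.shape[-1] == self._data.shape[0]
        # broadcast the single vector to (T, 1, N)
        return torch.tile(self._data.view(1, 1, -1), [x.shape[0], 1, 1])


# usage: 4 steps, 5 observation tokens and 3 action tokens per step, width 8
sep = SeparatorTokenEmbeddingSketch(embed_size=8)
action_embedding = torch.randn(4, 3, 8)   # (T, S_action, N)
separator = sep(action_embedding)          # (4, 1, 8)
obs_embedding = torch.randn(4, 5, 8)       # (T, S_obs, N)
# per-step layout after the patch: [obs tokens | separator | action tokens]
block = torch.cat([obs_embedding, separator, action_embedding], dim=1)
assert block.shape == (4, 5 + 1 + 3, 8)

Note also that _append_separator_embedding appends 0 to both mask lists, so the separator is treated as neither an observation nor an action token.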
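
A small numeric sketch of why each torch.zeros pad in create_embedding_metadata grows by one: with the separator inserted, every step's block is obs tokens + 1 + action tokens, and the separator belongs to neither mask. The sizes below are hypothetical; the real ones come from the episode.

import torch

# hypothetical sizes; the real values come from GatoTokenEpisode
num_steps, obs_tokens, action_tokens = 2, 5, 3
one_step_block_size = obs_tokens + 1 + action_tokens  # +1 for separator

# observation mask: 1 over observation tokens, 0 over separator + actions
observation_masks = torch.cat(
    [
        torch.ones((num_steps, obs_tokens, 1)),
        torch.zeros((num_steps, action_tokens + 1, 1)),  # +1 for separator
    ],
    dim=1,
)
# action mask: 0 over observations + separator, 1 over action tokens
action_masks = torch.cat(
    [
        torch.zeros((num_steps, obs_tokens + 1, 1)),  # +1 for separator
        torch.ones((num_steps, action_tokens, 1)),
    ],
    dim=1,
)
assert observation_masks.shape == (num_steps, one_step_block_size, 1)
assert action_masks.shape == (num_steps, one_step_block_size, 1)
# the separator position is masked out of both, so no loss targets it
assert not torch.any(observation_masks.bool() & action_masks.bool())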
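
Finally, the block_size = 0 -> 1 change propagates into the slicing arithmetic: GatoTokenSlicer sizes its window in whole per-step blocks via num_steps = token_size // episode.one_step_block_size (see __call__ above), so the separator must be counted or num_steps would overshoot. A toy check, with hypothetical token counts:

# hypothetical sizes: 5 observation tokens and 3 action tokens per step
obs_tokens, action_tokens = 5, 3
token_size = 72  # context window requested from the slicer

# before the patch: separator not counted in the per-step block
old_block_size = obs_tokens + action_tokens      # 8
# after the patch: block_size starts at 1 for the separator
new_block_size = 1 + obs_tokens + action_tokens  # 9

assert token_size // old_block_size == 9  # would slice 9 steps...
assert token_size // new_block_size == 8  # ...but only 8 fit with separators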