Add separator token
takuseno committed Jan 7, 2024
1 parent eec89f1 commit ce4840d
Showing 6 changed files with 93 additions and 14 deletions.
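The change threads one learnable separator embedding through training and inference so that each timestep's token block is laid out as observations first, then the separator, then actions. A sketch of the resulting layout (illustrative names, not the library's API):

    # one timestep's embedding block after this commit (illustrative)
    step_block = obs_embeddings + [separator_embedding] + action_embeddings
    one_step_block_size = observation_token_size + 1 + action_token_size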
22 changes: 20 additions & 2 deletions d3rlpy/algos/gato/base.py
@@ -27,7 +27,7 @@
 )
 from ...metrics import evaluate_gato_with_environment
 from ...models import EmbeddingModuleFactory, TokenEmbeddingFactory
-from ...models.torch import TokenEmbedding
+from ...models.torch import SeparatorTokenEmbedding, TokenEmbedding
 from ...serializable_config import generate_dict_config_field
 from ...torch_utility import eval_api, train_api
 from ...types import GymEnv, NDArray, Observation
@@ -73,6 +73,11 @@ def inner_update(
     def token_embeddings(self) -> Dict[str, TokenEmbedding]:
         pass
 
+    @property
+    @abstractmethod
+    def separator_token_embedding(self) -> SeparatorTokenEmbedding:
+        pass
+
 
 @dataclasses.dataclass()
 class GatoBaseConfig(LearnableConfig):
@@ -154,6 +159,8 @@ def predict(self, x: Observation) -> Union[NDArray, int]:
                 self._append_observation_embedding(embedding[i], position)
                 position += 1
 
+        self._append_separator_embedding()
+
         action_token_embedding = token_embeddings[self._action_embedding_key]
         action_values = []
         for i in range(self._action_token_size):
@@ -213,6 +220,13 @@ def _append_action_embedding(self, embedding: torch.Tensor) -> None:
         self._observation_masks.append(0)
         self._action_masks.append(1)
 
+    def _append_separator_embedding(self) -> None:
+        assert self._algo.impl
+        self._embeddings.append(self._algo.impl.separator_token_embedding.data)
+        self._observation_positions.append(0)
+        self._observation_masks.append(0)
+        self._action_masks.append(0)
+
     def reset(self) -> None:
         """Clears stateful information."""
         self._embeddings.clear()
@@ -325,7 +339,11 @@ def fit(
             self._impl.wrap_models_by_ddp()
 
         # create GatoReplayBuffer
-        replay_buffer = GatoReplayBuffer(datasets, self._impl.token_embeddings)
+        replay_buffer = GatoReplayBuffer(
+            replay_buffers=datasets,
+            token_embeddings=self._impl.token_embeddings,
+            separator_token_embedding=self._impl.separator_token_embedding,
+        )
 
         # save hyperparameters
         save_config(self, logger)
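At inference time, the stateful wrapper appends the separator right after each step's observation embeddings and before action tokens are sampled. Its position and both masks are appended as zeros, so it should serve as pure context rather than a prediction target. A hedged sketch of the per-step append order (`wrapper`, `num_observation_tokens`, and the loop shape are stand-in names):

    # per environment step, simplified from predict() above
    for i in range(num_observation_tokens):
        wrapper._append_observation_embedding(embedding[i], position)
        position += 1
    wrapper._append_separator_embedding()  # position 0, masks 0
    # ...action tokens are then sampled and appended one by one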
31 changes: 24 additions & 7 deletions d3rlpy/algos/gato/dataset.py
@@ -11,7 +11,7 @@
     ReplayBufferBase,
     slice_observations,
 )
-from ...models.torch import TokenEmbedding
+from ...models.torch import SeparatorTokenEmbedding, TokenEmbedding
 from ...torch_utility import torch_batch_pad_array
 from ...types import NDArray, Observation
 
@@ -42,7 +42,7 @@ def from_episode(
         token_embeddings: Dict[str, TokenEmbedding],
         task_id: str,
     ) -> "GatoTokenEpisode":
-        block_size = 0
+        block_size = 1  # +1 for separator token
 
         observations: List[NDArray]
         if isinstance(observation_to_embedding_keys, str):
@@ -127,7 +127,7 @@ def create_embedding_metadata(
         [
             observation_positions,
             torch.zeros(
-                (num_steps, action_token_size),
+                (num_steps, action_token_size + 1),  # +1 for separator
                 device=device,
                 dtype=torch.int32,
             ),
@@ -143,7 +143,7 @@
                 dtype=torch.float32,
             ),
             torch.zeros(
-                (num_steps, action_token_size, 1),
+                (num_steps, action_token_size + 1, 1),  # +1 for separator
                 device=device,
                 dtype=torch.float32,
             ),
@@ -154,7 +154,7 @@
     action_masks = torch.cat(
         [
             torch.zeros(
-                (num_steps, observation_token_size, 1),
+                (num_steps, observation_token_size + 1, 1),  # +1 for separator
                 device=device,
                 dtype=torch.float32,
             ),
@@ -199,6 +199,7 @@ def __call__(
         end_step: int,
         token_size: int,
         token_embeddings: Dict[str, TokenEmbedding],
+        separator_token_embedding: SeparatorTokenEmbedding,
     ) -> GatoTrainingInputEmbedding:
         num_steps = token_size // episode.one_step_block_size
         end_step = end_step + 1
@@ -230,11 +231,20 @@
             dtype=torch.int32,
         )
 
+        # create separator token
+        # (T, 1, N)
+        separator_embeddings = separator_token_embedding(action_embedding)
+
         # concat observations and actions
         # S = total_obs_num_tokens + action_num_tokens
         # (T, S, N)
         concat_embeddings = torch.cat(
-            [concat_observation_embeddings, action_embedding], dim=1
+            [
+                concat_observation_embeddings,
+                separator_embeddings,
+                action_embedding,
+            ],
+            dim=1,
         )
         metadata = create_embedding_metadata(
             observation_embeddings=concat_observation_embeddings,
@@ -244,7 +254,10 @@
         action_tokens = torch.cat(
             [
                 torch.zeros(
-                    (actual_num_steps, observation_token_size),
+                    (
+                        actual_num_steps,
+                        observation_token_size + 1,
+                    ),  # +1 for separator
                     device=device,
                     dtype=torch.int32,
                 ),
@@ -348,14 +361,17 @@ class GatoReplayBuffer:
     _episodes_per_task: DefaultDict[str, List[GatoTokenEpisode]]
     _token_slicer: GatoTokenSlicer
     _token_embeddings: Dict[str, TokenEmbedding]
+    _separator_token_embedding: SeparatorTokenEmbedding
 
     def __init__(
         self,
         replay_buffers: Sequence[ReplayBufferWithEmbeddingKeys],
         token_embeddings: Dict[str, TokenEmbedding],
+        separator_token_embedding: SeparatorTokenEmbedding,
     ):
         self._token_slicer = GatoTokenSlicer()
         self._token_embeddings = token_embeddings
+        self._separator_token_embedding = separator_token_embedding
         self._episodes = []
         self._episodes_per_task = defaultdict(list)
         for replay_buffer in replay_buffers:
@@ -383,6 +399,7 @@ def sample_embedding_sequence(
             end_step=end_step,
             token_size=length,
             token_embeddings=self._token_embeddings,
+            separator_token_embedding=self._separator_token_embedding,
        )
 
     def sample_embedding_mini_batch(
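Every per-step token count in this file grows by one: `block_size` now starts at 1, and each padded metadata tensor gains `+ 1` on its token axis, matching the separator inserted between the observation and action embeddings. A worked example under assumed sizes (the numbers are made up):

    observation_token_size = 16
    action_token_size = 1
    one_step_block_size = observation_token_size + 1 + action_token_size  # 18
    # a 1024-token training window then holds
    num_steps = 1024 // one_step_block_size  # 56 full steps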
10 changes: 9 additions & 1 deletion d3rlpy/algos/gato/gato.py
@@ -6,6 +6,7 @@
 from ...constants import ActionSpace
 from ...models.builders import create_gato_transformer
 from ...models.optimizers import OptimizerFactory, make_optimizer_field
+from ...models.torch import SeparatorTokenEmbedding
 from ...types import Shape
 from .base import GatoAlgoBase, GatoBaseConfig
 from .torch import GatoImpl, GatoModules
@@ -74,15 +75,22 @@ def inner_create_impl(
             for key, factory in self._config.token_embeddings.items()
         }
 
+        # create separator token embedding
+        separator_token_embedding = SeparatorTokenEmbedding(
+            self._config.layer_width
+        )
+
         optim = self._config.optim_factory.create(
             list(transformer.named_modules())
-            + list(embedding_modules.named_modules()),
+            + list(embedding_modules.named_modules())
+            + list(separator_token_embedding.named_modules()),
            lr=self._config.initial_learning_rate,
         )
 
         modules = GatoModules(
             transformer=transformer,
             embedding_modules=embedding_modules,
+            separator_token_embedding=separator_token_embedding,
             optim=optim,
         )
 
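The separator embedding is sized to the transformer's layer width, and its parameters are handed to the same optimizer as everything else; otherwise it would silently stay at its zero initialization. A standalone PyTorch sketch of the same wiring (placeholder model and sizes, not d3rlpy's factory API):

    import torch

    embed_size = 512  # stand-in for config.layer_width
    separator = torch.nn.Parameter(torch.zeros(embed_size))
    model = torch.nn.Linear(embed_size, embed_size)  # placeholder transformer
    optim = torch.optim.AdamW(
        list(model.parameters()) + [separator], lr=3e-4
    )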
14 changes: 12 additions & 2 deletions d3rlpy/algos/gato/torch/gato_impl.py
@@ -8,7 +8,11 @@
 from torch import nn
 from torch.optim import Optimizer
 
-from ....models.torch import GatoTransformer, TokenEmbedding
+from ....models.torch import (
+    GatoTransformer,
+    SeparatorTokenEmbedding,
+    TokenEmbedding,
+)
 from ....torch_utility import Modules
 from ..base import GatoAlgoImplBase
 from ..dataset import GatoEmbeddingMiniBatch, GatoInputEmbedding
@@ -20,6 +24,7 @@
 class GatoModules(Modules):
     transformer: GatoTransformer
     embedding_modules: nn.ModuleDict
+    separator_token_embedding: SeparatorTokenEmbedding
     optim: Optimizer
 
 
@@ -74,7 +79,8 @@ def inner_update(
 
         torch.nn.utils.clip_grad_norm_(
             list(self._modules.transformer.parameters())
-            + list(self._modules.embedding_modules.parameters()),
+            + list(self._modules.embedding_modules.parameters())
+            + list(self._modules.separator_token_embedding.parameters()),
             self._clip_grad_norm,
         )
 
@@ -118,3 +124,7 @@ def compute_loss(self, batch: GatoEmbeddingMiniBatch) -> torch.Tensor:
     @property
     def token_embeddings(self) -> Dict[str, TokenEmbedding]:
         return self._token_embeddings
+
+    @property
+    def separator_token_embedding(self) -> SeparatorTokenEmbedding:
+        return self._modules.separator_token_embedding
28 changes: 27 additions & 1 deletion d3rlpy/models/torch/embeddings.py
@@ -4,8 +4,13 @@
 
 from ...tokenizers import Tokenizer
 from ...types import Int32NDArray, NDArray
+from .parameters import Parameter
 
-__all__ = ["TokenEmbedding", "TokenEmbeddingWithTokenizer"]
+__all__ = [
+    "TokenEmbedding",
+    "TokenEmbeddingWithTokenizer",
+    "SeparatorTokenEmbedding",
+]
 
 
 class TokenEmbedding(nn.Module):  # type: ignore
@@ -46,3 +51,24 @@ def get_tokens(self, x: NDArray) -> Int32NDArray:
 
     def decode(self, x: Int32NDArray) -> NDArray:
         return self._tokenizer.decode(x)
+
+
+class SeparatorTokenEmbedding(nn.Module):  # type: ignore
+    _data: Parameter
+
+    def __init__(self, embed_size: int):
+        super().__init__()
+        self._data = Parameter(torch.zeros(embed_size, dtype=torch.float32))
+
+    def __call__(self, x: torch.Tensor) -> torch.Tensor:
+        return super().__call__(x)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        assert x.ndim == 3
+        assert x.shape[-1] == self._data.data.shape[0]
+        data = self._data.data.view(1, 1, -1)
+        return torch.tile(data, [x.shape[0], 1, 1])
+
+    @property
+    def data(self) -> torch.Tensor:
+        return self._data.data
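A quick shape check for the new module (a sketch; the sizes are made up, and the import assumes `SeparatorTokenEmbedding` is re-exported from `d3rlpy.models.torch`, as the imports in the diffs above suggest):

    import torch
    from d3rlpy.models.torch import SeparatorTokenEmbedding

    sep = SeparatorTokenEmbedding(embed_size=512)
    x = torch.rand(4, 3, 512)  # (T, S, N) action embeddings
    out = sep(x)               # the learned vector, tiled once per step
    assert out.shape == (4, 1, 512)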
2 changes: 1 addition & 1 deletion reproductions/offline/gato.py
@@ -50,7 +50,7 @@ def main() -> None:
         maximum_learning_rate=1e-4,
         warmup_steps=15000,
         final_steps=100000,
-        optim_factory=d3rlpy.models.AdamWFactory(
+        optim_factory=d3rlpy.models.GPTAdamWFactory(
             weight_decay=0.1, betas=(0.9, 0.95)
         ),
         action_vocab_size=1024,
