[Feature] TensorDictMap

ghstack-source-id: 4cb945e4ba9036b0dcc8373d45a3546bb0cf384a
Pull Request resolved: #2306

vmoens committed Oct 15, 2024
1 parent 4736fac commit d894358
Showing 6 changed files with 538 additions and 13 deletions.
20 changes: 20 additions & 0 deletions docs/source/reference/data.rst
@@ -972,6 +972,26 @@ The following classes are deprecated and just point to the classes above:
    UnboundedContinuousTensorSpec
    UnboundedDiscreteTensorSpec

Trees and Forests
-----------------

TorchRL offers a set of classes and functions that can be used to represent trees and forests efficiently.

.. currentmodule:: torchrl.data

.. autosummary::
    :toctree: generated/
    :template: rl_template.rst

    BinaryToDecimal
    HashToInt
    QueryModule
    RandomProjectionHash
    SipHash
    TensorDictMap
    TensorMap
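
A minimal usage sketch, assembled from the tests added later in this commit (the key names, values and storage size mirror ``test_query_module``)::

    # QueryModule hashes the ("key1", "key2") entries into an integer index;
    # TensorDictMap then uses that index to write and read values in the storage.
    import torch
    from tensordict import TensorDict
    from torchrl.data import LazyTensorStorage, QueryModule, SipHash, TensorDictMap

    query_module = QueryModule(
        in_keys=["key1", "key2"], index_key="index", hash_module=SipHash()
    )
    tdmap = TensorDictMap(query_module=query_module, storage=LazyTensorStorage(23))

    keys = TensorDict(
        {"key1": torch.tensor([[-1.0], [1.0]]), "key2": torch.tensor([[0.0], [2.0]])},
        batch_size=(2,),
    )
    values = TensorDict({"index": torch.tensor([[10.0], [20.0]])}, batch_size=(2,))

    tdmap[keys] = values                     # keys are hashed, values stored
    assert tdmap.contains(keys).all()        # membership check, as in the tests
    assert (tdmap[keys]["index"] == values["index"]).all()  # read back by key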


Reinforcement Learning From Human Feedback (RLHF)
-------------------------------------------------

125 changes: 124 additions & 1 deletion test/test_storage_map.py
@@ -3,14 +3,23 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import functools
import importlib.util

import pytest

import torch

from tensordict import TensorDict
from torchrl.data.map import BinaryToDecimal, QueryModule, RandomProjectionHash, SipHash
from torchrl.data import LazyTensorStorage, ListStorage
from torchrl.data.map import (
    BinaryToDecimal,
    QueryModule,
    RandomProjectionHash,
    SipHash,
    TensorDictMap,
)
from torchrl.envs import GymEnv

_has_gym = importlib.util.find_spec("gymnasium", None) or importlib.util.find_spec(
"gym", None
@@ -114,6 +123,120 @@ def test_query(self, clone, index_key):
        for i in range(1, 3):
            assert res[index_key][i].item() != res[index_key][i + 1].item()

    def test_query_module(self):
        query_module = QueryModule(
            in_keys=["key1", "key2"],
            index_key="index",
            hash_module=SipHash(),
        )

        embedding_storage = LazyTensorStorage(23)

        tensor_dict_storage = TensorDictMap(
            query_module=query_module,
            storage=embedding_storage,
        )

        index = TensorDict(
            {
                "key1": torch.Tensor([[-1], [1], [3], [-3]]),
                "key2": torch.Tensor([[0], [2], [4], [-4]]),
            },
            batch_size=(4,),
        )

        value = TensorDict(
            {"index": torch.Tensor([[10], [20], [30], [40]])}, batch_size=(4,)
        )

        tensor_dict_storage[index] = value
        assert torch.sum(tensor_dict_storage.contains(index)).item() == 4

        new_index = index.clone(True)
        new_index["key3"] = torch.Tensor([[4], [5], [6], [7]])
        retrieve_value = tensor_dict_storage[new_index]

        assert (retrieve_value["index"] == value["index"]).all()


class TestTensorDictMap:
    @pytest.mark.parametrize(
        "storage_type",
        [
            functools.partial(ListStorage, 1000),
            functools.partial(LazyTensorStorage, 1000),
        ],
    )
    def test_map(self, storage_type):
        query_module = QueryModule(
            in_keys=["key1", "key2"],
            index_key="index",
            hash_module=SipHash(),
        )

        embedding_storage = storage_type()

        tensor_dict_storage = TensorDictMap(
            query_module=query_module,
            storage=embedding_storage,
        )

        index = TensorDict(
            {
                "key1": torch.Tensor([[-1], [1], [3], [-3]]),
                "key2": torch.Tensor([[0], [2], [4], [-4]]),
            },
            batch_size=(4,),
        )

        value = TensorDict(
            {"index": torch.Tensor([[10], [20], [30], [40]])}, batch_size=(4,)
        )
        assert not hasattr(tensor_dict_storage, "out_keys")

        tensor_dict_storage[index] = value
        if isinstance(embedding_storage, LazyTensorStorage):
            assert hasattr(tensor_dict_storage, "out_keys")
        else:
            assert not hasattr(tensor_dict_storage, "out_keys")
            assert tensor_dict_storage._has_lazy_out_keys()
        assert torch.sum(tensor_dict_storage.contains(index)).item() == 4

        new_index = index.clone(True)
        new_index["key3"] = torch.Tensor([[4], [5], [6], [7]])
        retrieve_value = tensor_dict_storage[new_index]

        assert (retrieve_value["index"] == value["index"]).all()

    @pytest.mark.skipif(not _has_gym, reason="gym not installed")
    def test_map_rollout(self):
        torch.manual_seed(0)
        env = GymEnv("CartPole-v1")
        env.set_seed(0)
        rollout = env.rollout(100)
        source, dest = rollout.exclude("next"), rollout.get("next")
        storage = TensorDictMap.from_tensordict_pair(
            source,
            dest,
            in_keys=["observation", "action"],
        )
        storage_indices = TensorDictMap.from_tensordict_pair(
            source,
            dest,
            in_keys=["observation"],
            out_keys=["_index"],
        )
        # maps the (obs, action) tuple to a corresponding next state
        storage[source] = dest
        storage_indices[source] = source
        contains = storage.contains(source)
        assert len(contains) == rollout.shape[-1]
        assert contains.all()
        contains = storage.contains(torch.cat([source, source + 1]))
        assert len(contains) == rollout.shape[-1] * 2
        assert contains[: rollout.shape[-1]].all()
        assert not contains[rollout.shape[-1] :].any()


if __name__ == "__main__":
    args, unknown = argparse.ArgumentParser().parse_known_args()
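
For orientation, a condensed sketch of the pattern ``test_map_rollout`` exercises: build a map from (observation, action) pairs to the corresponding "next" entries of a rollout, then query it by key. It reuses the same environment and calls as the test; the final retrieval step is an assumption based on the ``__getitem__`` behaviour checked in ``test_map`` above.

    from torchrl.data import TensorDictMap
    from torchrl.envs import GymEnv

    env = GymEnv("CartPole-v1")
    rollout = env.rollout(10)
    source, dest = rollout.exclude("next"), rollout.get("next")

    # from_tensordict_pair infers the map layout from an example (source, dest) pair
    storage = TensorDictMap.from_tensordict_pair(
        source, dest, in_keys=["observation", "action"]
    )
    storage[source] = dest                 # store every transition of the rollout
    assert storage.contains(source).all()  # every (obs, action) key is now present
    next_steps = storage[source]           # look the stored "next" entries up by key
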
10 changes: 9 additions & 1 deletion torchrl/data/__init__.py
@@ -3,7 +3,15 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from .map import BinaryToDecimal, HashToInt, QueryModule, RandomProjectionHash, SipHash
from .map import (
    BinaryToDecimal,
    HashToInt,
    QueryModule,
    RandomProjectionHash,
    SipHash,
    TensorDictMap,
    TensorMap,
)
from .postprocs import MultiStep
from .replay_buffers import (
    Flat2TED,
1 change: 1 addition & 0 deletions torchrl/data/map/__init__.py
@@ -5,3 +5,4 @@

from .hash import BinaryToDecimal, RandomProjectionHash, SipHash
from .query import HashToInt, QueryModule
from .tdstorage import TensorDictMap, TensorMap
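
As the ``__init__.py`` changes show, the new classes are re-exported at the ``torchrl.data`` level, so either import path below resolves to the same objects (a quick check under that assumption):

    from torchrl.data import QueryModule, SipHash, TensorDictMap, TensorMap
    from torchrl.data.map import BinaryToDecimal, HashToInt, RandomProjectionHash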
