Skip to content

Commit

Permalink
[Feature] TensorDictMap
Browse files Browse the repository at this point in the history
ghstack-source-id: 57d15444a8c0389ce0ebf0651fb52f6655684018
Pull Request resolved: #2306
  • Loading branch information
vmoens committed Jul 22, 2024
1 parent 770a87d commit fed01f1
Show file tree
Hide file tree
Showing 5 changed files with 517 additions and 13 deletions.
125 changes: 124 additions & 1 deletion test/test_storage_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,23 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import functools
import importlib.util

import pytest

import torch

from tensordict import TensorDict
from torchrl.data.map import BinaryToDecimal, QueryModule, RandomProjectionHash, SipHash
from torchrl.data import LazyTensorStorage, ListStorage
from torchrl.data.map import (
BinaryToDecimal,
QueryModule,
RandomProjectionHash,
SipHash,
TensorDictMap,
)
from torchrl.envs import GymEnv

_has_gym = importlib.util.find_spec("gymnasium", None) or importlib.util.find_spec(
"gym", None
Expand Down Expand Up @@ -114,6 +123,120 @@ def test_query(self, clone, index_key):
for i in range(1, 3):
assert res[index_key][i].item() != res[index_key][i + 1].item()

def test_query_module(self):
query_module = QueryModule(
in_keys=["key1", "key2"],
index_key="index",
hash_module=SipHash(),
)

embedding_storage = LazyTensorStorage(23)

tensor_dict_storage = TensorDictMap(
query_module=query_module,
storage=embedding_storage,
)

index = TensorDict(
{
"key1": torch.Tensor([[-1], [1], [3], [-3]]),
"key2": torch.Tensor([[0], [2], [4], [-4]]),
},
batch_size=(4,),
)

value = TensorDict(
{"index": torch.Tensor([[10], [20], [30], [40]])}, batch_size=(4,)
)

tensor_dict_storage[index] = value
assert torch.sum(tensor_dict_storage.contains(index)).item() == 4

new_index = index.clone(True)
new_index["key3"] = torch.Tensor([[4], [5], [6], [7]])
retrieve_value = tensor_dict_storage[new_index]

assert (retrieve_value["index"] == value["index"]).all()


class TesttTensorDictMap:
@pytest.mark.parametrize(
"storage_type",
[
functools.partial(ListStorage, 1000),
functools.partial(LazyTensorStorage, 1000),
],
)
def test_map(self, storage_type):
query_module = QueryModule(
in_keys=["key1", "key2"],
index_key="index",
hash_module=SipHash(),
)

embedding_storage = storage_type()

tensor_dict_storage = TensorDictMap(
query_module=query_module,
storage=embedding_storage,
)

index = TensorDict(
{
"key1": torch.Tensor([[-1], [1], [3], [-3]]),
"key2": torch.Tensor([[0], [2], [4], [-4]]),
},
batch_size=(4,),
)

value = TensorDict(
{"index": torch.Tensor([[10], [20], [30], [40]])}, batch_size=(4,)
)
assert not hasattr(tensor_dict_storage, "out_keys")

tensor_dict_storage[index] = value
if isinstance(embedding_storage, LazyTensorStorage):
assert hasattr(tensor_dict_storage, "out_keys")
else:
assert not hasattr(tensor_dict_storage, "out_keys")
assert tensor_dict_storage._has_lazy_out_keys()
assert torch.sum(tensor_dict_storage.contains(index)).item() == 4

new_index = index.clone(True)
new_index["key3"] = torch.Tensor([[4], [5], [6], [7]])
retrieve_value = tensor_dict_storage[new_index]

assert (retrieve_value["index"] == value["index"]).all()

@pytest.mark.skipif(not _has_gym, reason="gym not installed")
def test_map_rollout(self):
torch.manual_seed(0)
env = GymEnv("CartPole-v1")
env.set_seed(0)
rollout = env.rollout(100)
source, dest = rollout.exclude("next"), rollout.get("next")
storage = TensorDictMap.from_tensordict_pair(
source,
dest,
in_keys=["observation", "action"],
)
storage_indices = TensorDictMap.from_tensordict_pair(
source,
dest,
in_keys=["observation"],
out_keys=["_index"],
)
# maps the (obs, action) tuple to a corresponding next state
storage[source] = dest
storage_indices[source] = source
contains = storage.contains(source)
assert len(contains) == rollout.shape[-1]
assert contains.all()
contains = storage.contains(torch.cat([source, source + 1]))
assert len(contains) == rollout.shape[-1] * 2
assert contains[: rollout.shape[-1]].all()
assert not contains[rollout.shape[-1] :].any()


if __name__ == "__main__":
args, unknown = argparse.ArgumentParser().parse_known_args()
Expand Down
10 changes: 9 additions & 1 deletion torchrl/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,15 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from .map import BinaryToDecimal, HashToInt, QueryModule, RandomProjectionHash, SipHash
from .map import (
BinaryToDecimal,
HashToInt,
QueryModule,
RandomProjectionHash,
SipHash,
TensorDictMap,
TensorMap,
)
from .postprocs import MultiStep
from .replay_buffers import (
Flat2TED,
Expand Down
1 change: 1 addition & 0 deletions torchrl/data/map/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@

from .hash import BinaryToDecimal, RandomProjectionHash, SipHash
from .query import HashToInt, QueryModule
from .tdstorage import TensorDictMap, TensorMap
Loading

0 comments on commit fed01f1

Please sign in to comment.