From a0e08c888e2c0b800b7e28d8615c3ae1fd9bbab9 Mon Sep 17 00:00:00 2001
From: Vincent Moens
Date: Thu, 17 Oct 2024 17:44:52 +0100
Subject: [PATCH] Update [ghstack-poisoned]

---
 test/test_distributions.py                    |  3 ++-
 torchrl/data/replay_buffers/replay_buffers.py |  5 ++---
 torchrl/data/replay_buffers/storages.py       | 22 ++++++++++++++++++-
 3 files changed, 25 insertions(+), 5 deletions(-)

diff --git a/test/test_distributions.py b/test/test_distributions.py
index 8a5b651531e..a69beb7309b 100644
--- a/test/test_distributions.py
+++ b/test/test_distributions.py
@@ -13,6 +13,7 @@
 from _utils_internal import get_default_devices
 from tensordict import TensorDictBase
 from torch import autograd, nn
+from torch.utils._pytree import tree_map
 from torchrl.modules import (
     NormalParamWrapper,
     OneHotCategorical,
@@ -182,7 +183,7 @@ class TestTruncatedNormal:
     @pytest.mark.parametrize("device", get_default_devices())
     def test_truncnormal(self, min, max, vecs, upscale, shape, device):
         torch.manual_seed(0)
-        *vecs, min, max, vecs, upscale = torch.utils._pytree.tree_map(
+        *vecs, min, max, vecs, upscale = tree_map(
             lambda t: torch.as_tensor(t, device=device),
             (*vecs, min, max, vecs, upscale),
         )
diff --git a/torchrl/data/replay_buffers/replay_buffers.py b/torchrl/data/replay_buffers/replay_buffers.py
index 2e0eeb80705..aa88dd9d186 100644
--- a/torchrl/data/replay_buffers/replay_buffers.py
+++ b/torchrl/data/replay_buffers/replay_buffers.py
@@ -31,6 +31,7 @@
 from tensordict.nn.utils import _set_dispatch_td_nn_modules
 from tensordict.utils import expand_as_right, expand_right
 from torch import Tensor
+from torch.utils._pytree import tree_map

 from torchrl._utils import _make_ordinal_device, accept_remote_rref_udf_invocation
 from torchrl.data.replay_buffers.samplers import (
@@ -319,9 +320,7 @@ def dim_extend(self, value):
     def _transpose(self, data):
         if is_tensor_collection(data):
             return data.transpose(self.dim_extend, 0)
-        return torch.utils._pytree.tree_map(
-            lambda x: x.transpose(self.dim_extend, 0), data
-        )
+        return tree_map(lambda x: x.transpose(self.dim_extend, 0), data)

     def _get_collate_fn(self, collate_fn):
         self._collate_fn = (
diff --git a/torchrl/data/replay_buffers/storages.py b/torchrl/data/replay_buffers/storages.py
index cee2a4f7726..70117497642 100644
--- a/torchrl/data/replay_buffers/storages.py
+++ b/torchrl/data/replay_buffers/storages.py
@@ -1367,16 +1367,36 @@ def _collate_list_tensordict(x):
     return out


+@implement_for("torch", "2.4")
 def _stack_anything(data):
     if is_tensor_collection(data[0]):
         return LazyStackedTensorDict.maybe_dense_stack(data)
-    return torch.utils._pytree.tree_map(
+    return tree_map(
         lambda *x: torch.stack(x),
         *data,
         is_leaf=lambda x: isinstance(x, torch.Tensor) or is_tensor_collection(x),
     )


+@implement_for("torch", None, "2.4")
+def _stack_anything(data):  # noqa: F811
+    from tensordict import _pytree
+
+    if not _pytree.PYTREE_REGISTERED_TDS:
+        raise RuntimeError(
+            "TensorDict is not registered within PyTree. "
+            "If you see this error, it means tensordicts instances cannot be natively stacked using tree_map. "
+            "To solve this issue, (a) upgrade pytorch to a version > 2.4, or (b) make sure TensorDict is registered in PyTree. "
+            "If this error persists, open an issue on https://github.com/pytorch/rl/issues"
+        )
+    if is_tensor_collection(data[0]):
+        return LazyStackedTensorDict.maybe_dense_stack(data)
+    return tree_map(
+        lambda *x: torch.stack(x),
+        *data,
+    )
+
+
 def _collate_id(x):
     return x
