Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Oct 17, 2024
1 parent a356152 commit a0e08c8
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 5 deletions.
3 changes: 2 additions & 1 deletion test/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from _utils_internal import get_default_devices
from tensordict import TensorDictBase
from torch import autograd, nn
from torch.utils._pytree import tree_map
from torchrl.modules import (
NormalParamWrapper,
OneHotCategorical,
Expand Down Expand Up @@ -182,7 +183,7 @@ class TestTruncatedNormal:
@pytest.mark.parametrize("device", get_default_devices())
def test_truncnormal(self, min, max, vecs, upscale, shape, device):
torch.manual_seed(0)
*vecs, min, max, vecs, upscale = torch.utils._pytree.tree_map(
*vecs, min, max, vecs, upscale = tree_map(
lambda t: torch.as_tensor(t, device=device),
(*vecs, min, max, vecs, upscale),
)
Expand Down
5 changes: 2 additions & 3 deletions torchrl/data/replay_buffers/replay_buffers.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from tensordict.nn.utils import _set_dispatch_td_nn_modules
from tensordict.utils import expand_as_right, expand_right
from torch import Tensor
from torch.utils._pytree import tree_map

from torchrl._utils import _make_ordinal_device, accept_remote_rref_udf_invocation
from torchrl.data.replay_buffers.samplers import (
Expand Down Expand Up @@ -319,9 +320,7 @@ def dim_extend(self, value):
def _transpose(self, data):
if is_tensor_collection(data):
return data.transpose(self.dim_extend, 0)
return torch.utils._pytree.tree_map(
lambda x: x.transpose(self.dim_extend, 0), data
)
return tree_map(lambda x: x.transpose(self.dim_extend, 0), data)

def _get_collate_fn(self, collate_fn):
self._collate_fn = (
Expand Down
22 changes: 21 additions & 1 deletion torchrl/data/replay_buffers/storages.py
Original file line number Diff line number Diff line change
Expand Up @@ -1367,16 +1367,36 @@ def _collate_list_tensordict(x):
return out


@implement_for("torch", "2.4")
def _stack_anything(data):
if is_tensor_collection(data[0]):
return LazyStackedTensorDict.maybe_dense_stack(data)
return torch.utils._pytree.tree_map(
return tree_map(
lambda *x: torch.stack(x),
*data,
is_leaf=lambda x: isinstance(x, torch.Tensor) or is_tensor_collection(x),
)


@implement_for("torch", None, "2.4")
def _stack_anything(data): # noqa: F811
from tensordict import _pytree

if not _pytree.PYTREE_REGISTERED_TDS:
raise RuntimeError(
"TensorDict is not registered within PyTree. "
"If you see this error, it means tensordicts instances cannot be natively stacked using tree_map. "
"To solve this issue, (a) upgrade pytorch to a version > 2.4, or (b) make sure TensorDict is registered in PyTree. "
"If this error persists, open an issue on https://github.com/pytorch/rl/issues"
)
if is_tensor_collection(data[0]):
return LazyStackedTensorDict.maybe_dense_stack(data)
return tree_map(
lambda *x: torch.stack(x),
*data,
)


def _collate_id(x):
return x

Expand Down

0 comments on commit a0e08c8

Please sign in to comment.