Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Oct 18, 2024
1 parent a0e08c8 commit 4b688cb
Showing 1 changed file with 12 additions and 4 deletions.
16 changes: 12 additions & 4 deletions torchrl/data/replay_buffers/storages.py
Original file line number Diff line number Diff line change
Expand Up @@ -1391,10 +1391,18 @@ def _stack_anything(data): # noqa: F811
)
if is_tensor_collection(data[0]):
return LazyStackedTensorDict.maybe_dense_stack(data)
return tree_map(
lambda *x: torch.stack(x),
*data,
)
flat_trees = []
spec = None
for d in data:
flat_tree, spec = tree_flatten(d)
flat_trees.append(flat_tree)

leaves = []
for leaf in zip(*flat_trees):
leaf = torch.stack(leaf)
leaves.append(leaf)

return tree_unflatten(leaves, spec)


def _collate_id(x):
Expand Down

0 comments on commit 4b688cb

Please sign in to comment.