Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Feb 25, 2025
1 parent 2aa9170 commit 1706004
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 2 deletions.
4 changes: 2 additions & 2 deletions tensordict/_reductions.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,8 @@ def from_metadata(metadata=metadata, prefix=None):
_ = metadata.pop("size", None)

d = {
key: NonTensorData(data, batch_size=batch_size)
for (key, (data, batch_size)) in non_tensor.items()
key: NonTensorData(data, batch_size=batch_size, device=device)
for (key, (data, batch_size, device)) in non_tensor.items()
}
for key, (dtype, local_shape, start, stop, pad) in leaves.items():
dtype = _STRDTYPE2DTYPE[dtype]
Expand Down
1 change: 1 addition & 0 deletions tensordict/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5056,6 +5056,7 @@ def assign(
metadata_dict["non_tensors"][key] = (
value.data,
list(value.batch_size),
value.device,
)
return
elif _is_tensor_collection(cls):
Expand Down

0 comments on commit 1706004

Please sign in to comment.