Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Feb 25, 2025
1 parent 1706004 commit 5091e16
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 2 deletions.
6 changes: 5 additions & 1 deletion tensordict/_reductions.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,11 @@ def from_metadata(metadata=metadata, prefix=None):
_ = metadata.pop("size", None)

d = {
key: NonTensorData(data, batch_size=batch_size, device=device)
key: NonTensorData(
data,
batch_size=batch_size,
device=torch.device(device) if device is not None else None,
)
for (key, (data, batch_size, device)) in non_tensor.items()
}
for key, (dtype, local_shape, start, stop, pad) in leaves.items():
Expand Down
2 changes: 1 addition & 1 deletion tensordict/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5056,7 +5056,7 @@ def assign(
metadata_dict["non_tensors"][key] = (
value.data,
list(value.batch_size),
value.device,
str(value.device) if value.device is not None else None,
)
return
elif _is_tensor_collection(cls):
Expand Down

0 comments on commit 5091e16

Please sign in to comment.