diff --git a/tensordict/_reductions.py b/tensordict/_reductions.py index 1816143ed..42bda8c3e 100644 --- a/tensordict/_reductions.py +++ b/tensordict/_reductions.py @@ -91,8 +91,8 @@ def from_metadata(metadata=metadata, prefix=None): _ = metadata.pop("size", None) d = { - key: NonTensorData(data, batch_size=batch_size) - for (key, (data, batch_size)) in non_tensor.items() + key: NonTensorData(data, batch_size=batch_size, device=device) + for (key, (data, batch_size, device)) in non_tensor.items() } for key, (dtype, local_shape, start, stop, pad) in leaves.items(): dtype = _STRDTYPE2DTYPE[dtype] diff --git a/tensordict/base.py b/tensordict/base.py index 6f3b1926c..4389ed1ec 100644 --- a/tensordict/base.py +++ b/tensordict/base.py @@ -5056,6 +5056,7 @@ def assign( metadata_dict["non_tensors"][key] = ( value.data, list(value.batch_size), + value.device, ) return elif _is_tensor_collection(cls):