From 5091e16878038f9d028d616a23ac487867dd79f6 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Tue, 25 Feb 2025 21:42:30 +0000 Subject: [PATCH] Update [ghstack-poisoned] --- tensordict/_reductions.py | 6 +++++- tensordict/base.py | 2 +- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/tensordict/_reductions.py b/tensordict/_reductions.py index 42bda8c3e..15013b59d 100644 --- a/tensordict/_reductions.py +++ b/tensordict/_reductions.py @@ -91,7 +91,11 @@ def from_metadata(metadata=metadata, prefix=None): _ = metadata.pop("size", None) d = { - key: NonTensorData(data, batch_size=batch_size, device=device) + key: NonTensorData( + data, + batch_size=batch_size, + device=torch.device(device) if device is not None else None, + ) for (key, (data, batch_size, device)) in non_tensor.items() } for key, (dtype, local_shape, start, stop, pad) in leaves.items(): diff --git a/tensordict/base.py b/tensordict/base.py index 4389ed1ec..428ab3545 100644 --- a/tensordict/base.py +++ b/tensordict/base.py @@ -5056,7 +5056,7 @@ def assign( metadata_dict["non_tensors"][key] = ( value.data, list(value.batch_size), - value.device, + str(value.device) if value.device is not None else None, ) return elif _is_tensor_collection(cls):