Skip to content

Commit

Permalink
improve dict/Mapping check in copy_data_to_device (#958)
Browse files Browse the repository at this point in the history
Summary: Pull Request resolved: #958

Reviewed By: diego-urgell

Differential Revision: D67962833

fbshipit-source-id: 347f8fd222b96f582f8c7a4e780e29750057b885
  • Loading branch information
JKSenthil authored and facebook-github-bot committed Jan 9, 2025
1 parent 6e6824c commit 52b5568
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 2 deletions.
17 changes: 16 additions & 1 deletion tests/utils/test_device_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import dataclasses
import os
import unittest
from collections import defaultdict, namedtuple
from collections import defaultdict, namedtuple, UserDict
from dataclasses import dataclass
from typing import Any, Dict
from unittest import mock
Expand Down Expand Up @@ -104,6 +104,21 @@ def test_copy_data_to_device_dict(self) -> None:
for key in new_dict.keys():
self.assertEqual(new_dict[key].device.type, "cuda")

@skip_if_not_gpu
def test_copy_data_to_device_mapping(self) -> None:
cuda_0 = torch.device("cuda:0")
f = torch.tensor([1, 2, 3])
g = torch.tensor([4, 5, 6])

# Use UserDict instead of a regular dictionary
original_dict = UserDict({"f": f, "g": g})

self.assertEqual(f.device.type, "cpu")
self.assertEqual(g.device.type, "cpu")
new_dict = copy_data_to_device(original_dict, cuda_0)
for key in new_dict.keys():
self.assertEqual(new_dict[key].device.type, "cuda")

@skip_if_not_gpu
def test_copy_data_to_device_named_tuple(self) -> None:
cuda_0 = torch.device("cuda:0")
Expand Down
8 changes: 7 additions & 1 deletion torchtnt/utils/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,10 +86,16 @@ def copy_data_to_device(
for k, v in data.items()
},
)
elif issubclass(data_type, dict):
elif (
hasattr(data, "items")
and hasattr(data, "__getitem__")
and hasattr(data, "__iter__")
):
# pyre-ignore: Too many arguments [19]: Call `object.__init__` expects 0 positional arguments, 1
return data_type(
{
k: copy_data_to_device(v, device, *args, **kwargs)
# pyre-ignore: Undefined attribute [16]: `Variable[T]` has no attribute `items`.
for k, v in data.items()
}
)
Expand Down

0 comments on commit 52b5568

Please sign in to comment.