Skip to content

Commit 1472157

Browse files
ejguanfacebook-github-bot
authored andcommitted
Fix pin_memory_fn for NamedTuple (#1086)
Summary: Fixes #1085 Per title. And, even though I can add a test, this test won't be executed as we don't have a GPU CI machine yet. I have tested on my local machine though Pull Request resolved: #1086 Reviewed By: NivekT Differential Revision: D44094225 Pulled By: ejguan fbshipit-source-id: 9c8414c31b76c93cee7e31c4e2da14076e9792bf
1 parent f2a1051 commit 1472157

File tree

2 files changed

+11
-2
lines changed

2 files changed

+11
-2
lines changed

test/test_iterdatapipe.py

+10-1
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313

1414
from collections import defaultdict
1515
from functools import partial
16-
from typing import Dict
16+
from typing import Dict, NamedTuple
1717

1818
import expecttest
1919
import torch
@@ -91,6 +91,11 @@ async def _async_x_mul_y(x, y):
9191
return x * y
9292

9393

94+
class NamedTensors(NamedTuple):
95+
x: torch.Tensor
96+
y: torch.Tensor
97+
98+
9499
class TestIterDataPipe(expecttest.TestCase):
95100
def test_in_memory_cache_holder_iterdatapipe(self) -> None:
96101
source_dp = IterableWrapper(range(10))
@@ -1521,6 +1526,10 @@ def test_pin_memory(self):
15211526
dp = IterableWrapper([{str(i): (i, i + 1)} for i in range(10)]).map(_convert_to_tensor).pin_memory()
15221527
self.assertTrue(all(v.is_pinned() for d in dp for v in d.values()))
15231528

1529+
# NamedTuple
1530+
dp = IterableWrapper([NamedTensors(torch.tensor(i), torch.tensor(i + 1)) for i in range(10)]).pin_memory()
1531+
self.assertTrue(all(v.is_pinned() for d in dp for v in d))
1532+
15241533
# Dict of List of Tensors
15251534
dp = (
15261535
IterableWrapper([{str(i): [(i - 1, i), (i, i + 1)]} for i in range(10)])

torchdata/datapipes/utils/pin_memory.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ def pin_memory_fn(data, device=None):
2727
elif isinstance(data, collections.abc.Sequence):
2828
pinned_data = [pin_memory_fn(sample, device) for sample in data] # type: ignore[assignment]
2929
try:
30-
type(data)(*pinned_data)
30+
return type(data)(*pinned_data)
3131
except TypeError:
3232
# The sequence type may not support `__init__(iterable)` (e.g., `range`).
3333
return pinned_data

0 commit comments

Comments
 (0)