Skip to content

Commit 09ffa9e

Browse files
committed
Make iter to map conversion more lazy
1 parent 73d3aa9 commit 09ffa9e

File tree

2 files changed

+48
-15
lines changed

2 files changed

+48
-15
lines changed

test/test_iterdatapipe.py

+12
Original file line numberDiff line numberDiff line change
@@ -1030,6 +1030,18 @@ def test_itertomap_mapdatapipe(self):
10301030
self.assertEqual(len(wa), 1)
10311031
self.assertRegex(str(wa[0].message), r"Found duplicate key")
10321032

1033+
# More lazily: load only until necessary
1034+
source_dp = IterableWrapper(list(zip(keys, values)))
1035+
lazy_map_dp = source_dp.to_map_datapipe()
1036+
_ = lazy_map_dp["k" + str(4)]
1037+
self.assertEqual(len(lazy_map_dp._map), 5)
1038+
_ = lazy_map_dp["k" + str(7)]
1039+
self.assertEqual(len(lazy_map_dp._map), 8)
1040+
try:
1041+
_ = lazy_map_dp["k" + str(20)]
1042+
except IndexError:
1043+
self.assertEqual(len(lazy_map_dp._map), 10)
1044+
10331045
def test_mux_longest_iterdatapipe(self):
10341046

10351047
# Functional Test: Elements are yielded one at a time from each DataPipe, until they are all exhausted

torchdata/datapipes/iter/util/converter.py

+36-15
Original file line numberDiff line numberDiff line change
@@ -68,32 +68,53 @@ def __init__(self, datapipe: IterDataPipe, key_value_fn: Optional[Callable] = No
6868
_check_unpickable_fn(key_value_fn)
6969
self.key_value_fn = key_value_fn # type: ignore[assignment]
7070
self._map = None
71+
self._itr = None
72+
self._depleted = False
7173

7274
def _load_map(self):
73-
self._map = {}
74-
for d in self.datapipe:
75-
inp = d if self.key_value_fn is None else self.key_value_fn(d)
75+
if self._map is None:
76+
self._map = {}
77+
self._itr = iter(self.datapipe)
78+
while not self._depleted:
7679
try:
77-
length = len(inp)
78-
except TypeError:
79-
raise TypeError(f"Cannot convert dictionary update element {type(inp)} ({inp}) to a sequence")
80-
if length != 2:
81-
raise ValueError(f"dictionary update sequence element has length {length}, 2 is required")
82-
key, value = inp
83-
if key in self._map:
84-
warnings.warn(f"Found duplicate key {key}. Please check your `key_value_fn`")
85-
self._map[key] = value
80+
self._load_next_item()
81+
except StopIteration:
82+
self._depleted = True
8683

8784
def __getitem__(self, index):
8885
try:
8986
if self._map is None:
90-
self._load_map()
91-
return self._map[index] # type: ignore[index]
87+
self._map = {}
88+
self._itr = iter(self.datapipe)
89+
raise KeyError
90+
return self._map[index]
9291
except KeyError:
92+
while not self._depleted:
93+
try:
94+
key, value = self._load_next_item()
95+
if key == index:
96+
return value
97+
except StopIteration:
98+
self._depleted = True
9399
raise IndexError(f"Index {index} is invalid for IterToMapConverter.")
94100

101+
def _load_next_item(self):
102+
elem = next(self._itr)
103+
inp = elem if self.key_value_fn is None else self.key_value_fn(elem)
104+
try:
105+
length = len(inp)
106+
except TypeError:
107+
raise TypeError(f"Cannot convert dictionary update element {type(inp)} ({inp}) to a sequence")
108+
if length != 2:
109+
raise ValueError(f"dictionary update sequence element has length {length}, 2 is required")
110+
key, value = inp
111+
if key in self._map:
112+
warnings.warn(f"Found duplicate key {key}. Please check your `key_value_fn`")
113+
self._map[key] = value
114+
return key, value
115+
95116
def __len__(self):
96-
if self._map is not None:
117+
if self._depleted:
97118
return len(self._map) # type: ignore[arg-type]
98119
try:
99120
return len(self.datapipe)

0 commit comments

Comments
 (0)