Skip to content

Commit

Permalink
Make iter to map conversion more lazy
Browse files Browse the repository at this point in the history
  • Loading branch information
SvenDS9 committed Feb 15, 2023
1 parent 73d3aa9 commit 09ffa9e
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 15 deletions.
12 changes: 12 additions & 0 deletions test/test_iterdatapipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -1030,6 +1030,18 @@ def test_itertomap_mapdatapipe(self):
self.assertEqual(len(wa), 1)
self.assertRegex(str(wa[0].message), r"Found duplicate key")

# More lazily: load only until necessary
source_dp = IterableWrapper(list(zip(keys, values)))
lazy_map_dp = source_dp.to_map_datapipe()
_ = lazy_map_dp["k" + str(4)]
self.assertEqual(len(lazy_map_dp._map), 5)
_ = lazy_map_dp["k" + str(7)]
self.assertEqual(len(lazy_map_dp._map), 8)
try:
_ = lazy_map_dp["k" + str(20)]
except IndexError:
self.assertEqual(len(lazy_map_dp._map), 10)

def test_mux_longest_iterdatapipe(self):

# Functional Test: Elements are yielded one at a time from each DataPipe, until they are all exhausted
Expand Down
51 changes: 36 additions & 15 deletions torchdata/datapipes/iter/util/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,32 +68,53 @@ def __init__(self, datapipe: IterDataPipe, key_value_fn: Optional[Callable] = No
_check_unpickable_fn(key_value_fn)
self.key_value_fn = key_value_fn # type: ignore[assignment]
self._map = None
self._itr = None
self._depleted = False

def _load_map(self):
self._map = {}
for d in self.datapipe:
inp = d if self.key_value_fn is None else self.key_value_fn(d)
if self._map is None:
self._map = {}
self._itr = iter(self.datapipe)
while not self._depleted:
try:
length = len(inp)
except TypeError:
raise TypeError(f"Cannot convert dictionary update element {type(inp)} ({inp}) to a sequence")
if length != 2:
raise ValueError(f"dictionary update sequence element has length {length}, 2 is required")
key, value = inp
if key in self._map:
warnings.warn(f"Found duplicate key {key}. Please check your `key_value_fn`")
self._map[key] = value
self._load_next_item()
except StopIteration:
self._depleted = True

def __getitem__(self, index):
try:
if self._map is None:
self._load_map()
return self._map[index] # type: ignore[index]
self._map = {}
self._itr = iter(self.datapipe)
raise KeyError
return self._map[index]
except KeyError:
while not self._depleted:
try:
key, value = self._load_next_item()
if key == index:
return value
except StopIteration:
self._depleted = True
raise IndexError(f"Index {index} is invalid for IterToMapConverter.")

def _load_next_item(self):
elem = next(self._itr)
inp = elem if self.key_value_fn is None else self.key_value_fn(elem)
try:
length = len(inp)
except TypeError:
raise TypeError(f"Cannot convert dictionary update element {type(inp)} ({inp}) to a sequence")
if length != 2:
raise ValueError(f"dictionary update sequence element has length {length}, 2 is required")
key, value = inp
if key in self._map:
warnings.warn(f"Found duplicate key {key}. Please check your `key_value_fn`")
self._map[key] = value
return key, value

def __len__(self):
if self._map is not None:
if self._depleted:
return len(self._map) # type: ignore[arg-type]
try:
return len(self.datapipe)
Expand Down

0 comments on commit 09ffa9e

Please sign in to comment.