From 606ff72c70951197f31da1f4909ed8174942f397 Mon Sep 17 00:00:00 2001 From: SvenDS9 Date: Thu, 16 Feb 2023 11:25:23 +0100 Subject: [PATCH] Fix state methods and actually fix mypy issues --- torchdata/datapipes/iter/util/converter.py | 21 +++++++-------------- 1 file changed, 7 insertions(+), 14 deletions(-) diff --git a/torchdata/datapipes/iter/util/converter.py b/torchdata/datapipes/iter/util/converter.py index a01fc1d78..be84c574d 100644 --- a/torchdata/datapipes/iter/util/converter.py +++ b/torchdata/datapipes/iter/util/converter.py @@ -73,9 +73,6 @@ def __init__(self, datapipe: IterDataPipe, key_value_fn: Optional[Callable] = No self._depleted = False def _load_map(self): - if self._map is None: - self._map = {} - self._itr = iter(self.datapipe) while not self._depleted: try: self._load_next_item() @@ -84,10 +81,7 @@ def _load_map(self): def __getitem__(self, index): try: - if self._map is None: - self._map = {} - self._itr = iter(self.datapipe) - else: + if self._map is not None: return self._map[index] except KeyError: pass @@ -101,7 +95,10 @@ def __getitem__(self, index): raise IndexError(f"Index {index} is invalid for IterToMapConverter.") def _load_next_item(self): - elem = next(self._itr) + if self._map is None: + self._map = {} + self._itr = iter(self.datapipe) + elem = next(self._itr) # type: ignore[arg-type] inp = elem if self.key_value_fn is None else self.key_value_fn(elem) try: length = len(inp) @@ -135,14 +132,10 @@ def __getstate__(self): dill_key_value_fn = dill.dumps(self.key_value_fn) else: dill_key_value_fn = self.key_value_fn - return ( - self.datapipe, - dill_key_value_fn, - self._map, - ) + return (self.datapipe, dill_key_value_fn, self._map, self._itr, self._depleted) def __setstate__(self, state): - (self.datapipe, dill_key_value_fn, self._map) = state + (self.datapipe, dill_key_value_fn, self._map, self._itr, self._depleted) = state if DILL_AVAILABLE: self.key_value_fn = dill.loads(dill_key_value_fn) # type: ignore[assignment] else: