@@ -68,32 +68,53 @@ def __init__(self, datapipe: IterDataPipe, key_value_fn: Optional[Callable] = No
68
68
_check_unpickable_fn (key_value_fn )
69
69
self .key_value_fn = key_value_fn # type: ignore[assignment]
70
70
self ._map = None
71
+ self ._itr = None
72
+ self ._depleted = False
71
73
72
74
def _load_map (self ):
73
- self ._map = {}
74
- for d in self .datapipe :
75
- inp = d if self .key_value_fn is None else self .key_value_fn (d )
75
+ if self ._map is None :
76
+ self ._map = {}
77
+ self ._itr = iter (self .datapipe )
78
+ while not self ._depleted :
76
79
try :
77
- length = len (inp )
78
- except TypeError :
79
- raise TypeError (f"Cannot convert dictionary update element { type (inp )} ({ inp } ) to a sequence" )
80
- if length != 2 :
81
- raise ValueError (f"dictionary update sequence element has length { length } , 2 is required" )
82
- key , value = inp
83
- if key in self ._map :
84
- warnings .warn (f"Found duplicate key { key } . Please check your `key_value_fn`" )
85
- self ._map [key ] = value
80
+ self ._load_next_item ()
81
+ except StopIteration :
82
+ self ._depleted = True
86
83
87
84
def __getitem__ (self , index ):
88
85
try :
89
86
if self ._map is None :
90
- self ._load_map ()
91
- return self ._map [index ] # type: ignore[index]
87
+ self ._map = {}
88
+ self ._itr = iter (self .datapipe )
89
+ raise KeyError
90
+ return self ._map [index ]
92
91
except KeyError :
92
+ while not self ._depleted :
93
+ try :
94
+ key , value = self ._load_next_item ()
95
+ if key == index :
96
+ return value
97
+ except StopIteration :
98
+ self ._depleted = True
93
99
raise IndexError (f"Index { index } is invalid for IterToMapConverter." )
94
100
101
+ def _load_next_item (self ):
102
+ elem = next (self ._itr )
103
+ inp = elem if self .key_value_fn is None else self .key_value_fn (elem )
104
+ try :
105
+ length = len (inp )
106
+ except TypeError :
107
+ raise TypeError (f"Cannot convert dictionary update element { type (inp )} ({ inp } ) to a sequence" )
108
+ if length != 2 :
109
+ raise ValueError (f"dictionary update sequence element has length { length } , 2 is required" )
110
+ key , value = inp
111
+ if key in self ._map :
112
+ warnings .warn (f"Found duplicate key { key } . Please check your `key_value_fn`" )
113
+ self ._map [key ] = value
114
+ return key , value
115
+
95
116
def __len__ (self ):
96
- if self ._map is not None :
117
+ if self ._depleted :
97
118
return len (self ._map ) # type: ignore[arg-type]
98
119
try :
99
120
return len (self .datapipe )
0 commit comments