18
18
import torch
19
19
from aiostream .aiter_utils import AsyncIteratorContext
20
20
from aiostream .stream import iterate
21
+ from kilroy_server_py_utils .utils import batchify , background
21
22
from torch import Tensor , nn
22
23
from torch .nn .utils .rnn import (
23
24
PackedSequence ,
27
28
)
28
29
29
30
from kilroy_module_pytorch_py_sdk .models .abc import SequentialModel
30
- from kilroy_server_py_utils .utils import batchify , background
31
31
32
32
T = TypeVar ("T" )
33
33
@@ -138,19 +138,22 @@ class CachingAsyncIterable(AsyncIterable[T], Generic[T]):
138
138
_iterator : AsyncIterator [T ]
139
139
_cache : MutableMapping [str , T ]
140
140
_prefix : str
141
+ _shuffle : bool
141
142
_watermark : int
143
+ _length : Optional [int ]
142
144
_lock : Lock
143
145
144
146
def __init__ (
145
147
self ,
146
148
iterable : AsyncIterable [T ],
147
149
cache : Optional [MutableMapping [str , T ]] = None ,
148
150
prefix : Optional [str ] = None ,
151
+ shuffle : bool = True ,
149
152
):
150
153
self ._ctx = iterate (iterable ).stream ()
151
154
self ._cache = cache if cache is not None else {}
152
155
self ._prefix = prefix if prefix is not None else uuid4 ().hex
153
- self ._watermark = 0
156
+ self ._shuffle = shuffle
154
157
155
158
def _make_key (self , i : int ) -> str :
156
159
return f"{ self ._prefix } -{ i } "
@@ -169,17 +172,39 @@ async def _get_at(self, i: int) -> T:
169
172
170
173
return self ._cache [self ._make_key (self ._watermark - 1 )]
171
174
172
- async def __aiter__ (self ) -> AsyncIterator [T ]:
175
+ async def _iter_cached (self ) -> AsyncIterator [T ]:
176
+ indices = (
177
+ torch .randperm (self ._length ).tolist ()
178
+ if self ._shuffle
179
+ else range (self ._length )
180
+ )
181
+
182
+ for i in indices :
183
+ yield await self ._get_at (i )
184
+
185
+ async def _iter_uncached (self ) -> AsyncIterator [T ]:
173
186
i = 0
174
187
while True :
175
188
try :
176
189
yield await self ._get_at (i )
177
190
except StopAsyncIteration :
191
+ async with self ._lock :
192
+ self ._length = i
178
193
return
179
194
i += 1
180
195
196
+ async def __aiter__ (self ) -> AsyncIterator [T ]:
197
+ async with self ._lock :
198
+ is_cached = self ._length is not None
199
+
200
+ ait = self ._iter_cached () if is_cached else self ._iter_uncached ()
201
+ async for x in ait :
202
+ yield x
203
+
181
204
async def __aenter__ (self ) -> "CachingAsyncIterable[T]" :
182
205
self ._lock = Lock ()
206
+ self ._watermark = 0
207
+ self ._length = None
183
208
await self ._ctx .__aenter__ ()
184
209
self ._iterator = self ._ctx .__aiter__ ()
185
210
return self
@@ -194,5 +219,6 @@ async def __aexit__(
194
219
key = self ._make_key (i )
195
220
del self ._cache [key ]
196
221
self ._watermark = 0
222
+ self ._length = None
197
223
await self ._ctx .__aexit__ (exc_type , exc , traceback )
198
224
return None
0 commit comments