Skip to content
This repository was archived by the owner on Apr 29, 2024. It is now read-only.

Commit 0852a26

Browse files
authored
Added data shuffling (#35)
1 parent 4fcf864 commit 0852a26

File tree

1 file changed

+29 -3 lines changed
  • kilroy_module_pytorch_py_sdk/src/kilroy_module_pytorch_py_sdk

1 file changed

+29 -3 lines changed

kilroy_module_pytorch_py_sdk/src/kilroy_module_pytorch_py_sdk/utils.py

+29 -3
Original file line number | Diff line number | Diff line change
@@ -18,6 +18,7 @@
1818
import torch
1919
from aiostream.aiter_utils import AsyncIteratorContext
2020
from aiostream.stream import iterate
21+
from kilroy_server_py_utils.utils import batchify, background
2122
from torch import Tensor, nn
2223
from torch.nn.utils.rnn import (
2324
PackedSequence,
@@ -27,7 +28,6 @@
2728
)
2829

2930
from kilroy_module_pytorch_py_sdk.models.abc import SequentialModel
30-
from kilroy_server_py_utils.utils import batchify, background
3131

3232
T = TypeVar("T")
3333

@@ -138,19 +138,22 @@ class CachingAsyncIterable(AsyncIterable[T], Generic[T]):
138138
_iterator: AsyncIterator[T]
139139
_cache: MutableMapping[str, T]
140140
_prefix: str
141+
_shuffle: bool
141142
_watermark: int
143+
_length: Optional[int]
142144
_lock: Lock
143145

144146
def __init__(
145147
self,
146148
iterable: AsyncIterable[T],
147149
cache: Optional[MutableMapping[str, T]] = None,
148150
prefix: Optional[str] = None,
151+
shuffle: bool = True,
149152
):
150153
self._ctx = iterate(iterable).stream()
151154
self._cache = cache if cache is not None else {}
152155
self._prefix = prefix if prefix is not None else uuid4().hex
153-
self._watermark = 0
156+
self._shuffle = shuffle
154157

155158
def _make_key(self, i: int) -> str:
156159
return f"{self._prefix}-{i}"
@@ -169,17 +172,39 @@ async def _get_at(self, i: int) -> T:
169172

170173
return self._cache[self._make_key(self._watermark - 1)]
171174

172-
async def __aiter__(self) -> AsyncIterator[T]:
175+
async def _iter_cached(self) -> AsyncIterator[T]:
176+
indices = (
177+
torch.randperm(self._length).tolist()
178+
if self._shuffle
179+
else range(self._length)
180+
)
181+
182+
for i in indices:
183+
yield await self._get_at(i)
184+
185+
async def _iter_uncached(self) -> AsyncIterator[T]:
173186
i = 0
174187
while True:
175188
try:
176189
yield await self._get_at(i)
177190
except StopAsyncIteration:
191+
async with self._lock:
192+
self._length = i
178193
return
179194
i += 1
180195

196+
async def __aiter__(self) -> AsyncIterator[T]:
197+
async with self._lock:
198+
is_cached = self._length is not None
199+
200+
ait = self._iter_cached() if is_cached else self._iter_uncached()
201+
async for x in ait:
202+
yield x
203+
181204
async def __aenter__(self) -> "CachingAsyncIterable[T]":
182205
self._lock = Lock()
206+
self._watermark = 0
207+
self._length = None
183208
await self._ctx.__aenter__()
184209
self._iterator = self._ctx.__aiter__()
185210
return self
@@ -194,5 +219,6 @@ async def __aexit__(
194219
key = self._make_key(i)
195220
del self._cache[key]
196221
self._watermark = 0
222+
self._length = None
197223
await self._ctx.__aexit__(exc_type, exc, traceback)
198224
return None

0 commit comments

Comments
 (0)