Skip to content

Commit 7d16b7e

Browse files
author
pytorchbot
committed
2025-03-20 nightly release (b54da34)
1 parent 711030a commit 7d16b7e

File tree

2 files changed

+19
-9
lines changed

2 files changed

+19
-9
lines changed

torchdata/nodes/map.py

+18-9
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,7 @@ def __init__(
181181
self._sem,
182182
self._stop,
183183
),
184+
name="read_thread(target=_populate_queue)",
184185
daemon=True,
185186
)
186187
self._workers: List[Union[threading.Thread, mp.Process]] = []
@@ -193,7 +194,12 @@ def __init__(
193194
self._stop if self.method == "thread" else self._mp_stop,
194195
)
195196
self._workers.append(
196-
threading.Thread(target=_apply_udf, args=args, daemon=True)
197+
threading.Thread(
198+
target=_apply_udf,
199+
args=args,
200+
daemon=True,
201+
name=f"worker_thread_{worker_id}(target=_apply_udf)",
202+
)
197203
if self.method == "thread"
198204
else mp_context.Process(target=_apply_udf, args=args, daemon=True)
199205
)
@@ -205,6 +211,7 @@ def __init__(
205211
target=_sort_worker,
206212
args=(self._intermed_q, self._sort_q, self._stop),
207213
daemon=True,
214+
name="sort_thread(target=_sort_worker)",
208215
)
209216
self._out_q = self._sort_q
210217

@@ -231,10 +238,10 @@ def __iter__(self) -> Iterator[T]:
231238

232239
def __next__(self) -> T:
233240
while True:
234-
if self._stop.is_set():
241+
if self._stop.is_set() or self._mp_stop.is_set():
235242
raise StopIteration()
236243
elif self._done and self._sem._value == self._max_tasks:
237-
# Don't stop if we still have items in the queue
244+
# _done is set, and semaphore is back at initial value, so we can stop
238245
self._stop.set()
239246
self._mp_stop.set()
240247
raise StopIteration()
@@ -468,7 +475,7 @@ class _SingleThreadedMapper(Iterator[T]):
468475
Prefetcher and PinMemory.
469476
470477
A thread is started on __init__ and stopped on __del__/_shutdown.
471-
The thread runs _populate_queue, which acquires a BoundedSemaphore with initial value
478+
The thread runs worker, which acquires a BoundedSemaphore with initial value
472479
of `prefetch_factor`.
473480
474481
When next() is called on this iterator, it will block until an item is available on _q.
@@ -478,14 +485,15 @@ class _SingleThreadedMapper(Iterator[T]):
478485
- any other item: return the item
479486
480487
A Bounded semaphore is used to limit concurrency and memory utilization.
481-
If N items have been pulled from the source, and M items have been yielded by this iterator,
482-
we maintain the invariant that semaphore.value + (N - M) == prefetch_factor (modulo
488+
If N items have been pulled from the source (i.e. acquire the semaphore),
489+
and M items have been yielded by this iterator (i.e. release the semaphore),
490+
we maintain the invariant that semaphore.value + (N - M) == prefetch_factor (modulo
483491
non-atomicness of operations).
484492
485-
_populate_queue calls semaphore.acquire. When we pull an item from the queue, we
486-
call semaphore.release (unless it's a StartupExceptionWrapper, because _populate_queue
493+
worker calls semaphore.acquire. When we pull an item from the queue, we
494+
call semaphore.release (unless it's a StartupExceptionWrapper, because worker
487495
does not acquire semaphores in this case). All outstanding items are either being
488-
processed in _populate_queue, in the _q, or about to be returned by an in-flight next() call.
496+
processed in worker, in the _q, or about to be returned by an in-flight next() call.
489497
"""
490498

491499
def __init__(
@@ -526,6 +534,7 @@ def __init__(
526534
self._stop_event,
527535
),
528536
daemon=True,
537+
name=f"worker_thread(target={self.worker.__name__})",
529538
)
530539
self._thread.start()
531540

torchdata/nodes/snapshot_store.py

+1
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ def append(self, snapshot: Any, version: int) -> None:
5858
def pop_version(self, version: int) -> Optional[Any]:
5959
ver, val = None, None
6060
with self._lock:
61+
# pop all items whose version index is less than or equal to `version`
6162
while self._q.queue and version >= self._q.queue[0][0]:
6263
ver, val = self._q.get_nowait()
6364

0 commit comments

Comments
 (0)