@@ -1467,11 +1467,14 @@ class ThreadPoolTempl : public onnxruntime::concurrency::ExtendedThreadPoolInter
14671467 status = ThreadStatus::Spinning;
14681468 }
14691469
1470- void SetBlocked (std::function<bool ()> should_block,
1470+ bool SetBlocked (std::function<bool ()> should_block,
14711471 std::function<void()> post_block) {
14721472 std::unique_lock<std::mutex> lk (mutex);
1473- assert (GetStatus () == ThreadStatus::Spinning);
1474- status.store (ThreadStatus::Blocking, std::memory_order_relaxed);
1473+ auto old_status = status.exchange (ThreadStatus::Blocking, std::memory_order_seq_cst);
1474+ if (old_status != ThreadStatus::Spinning) {
1475+ // Encountered a logical error
1476+ return false ;
1477+ }
14751478 if (should_block ()) {
14761479 status.store (ThreadStatus::Blocked, std::memory_order_relaxed);
14771480 do {
@@ -1480,6 +1483,7 @@ class ThreadPoolTempl : public onnxruntime::concurrency::ExtendedThreadPoolInter
14801483 post_block ();
14811484 }
14821485 status.store (ThreadStatus::Spinning, std::memory_order_relaxed);
1486+ return true ;
14831487 }
14841488
14851489 private:
@@ -1558,62 +1562,66 @@ class ThreadPoolTempl : public onnxruntime::concurrency::ExtendedThreadPoolInter
15581562
15591563 // Attempt to block
15601564 if (!t) {
1561- td.SetBlocked ( // Pre-block test
1562- [&]() -> bool {
1563- bool should_block = true ;
1564- // Check whether work was pushed to us while attempting to block. We make
1565- // this test while holding the per-thread status lock, and after setting
1566- // our status to ThreadStatus::Blocking.
1567- //
1568- // This synchronizes with ThreadPool::Schedule which pushes work to the queue
1569- // and then tests for ThreadStatus::Blocking/Blocked (via EnsureAwake):
1570- //
1571- // Main thread: Worker:
1572- // #1 Push work #A Set status blocking
1573- // #2 Read worker status #B Check queue
1574- // #3 Wake if blocking/blocked
1575- //
1576- // If #A is before #2 then main sees worker blocked and wakes
1577- //
1578- // If #A if after #2 then #B will see #1, and we abandon blocking
1579- assert (!t);
1580- t = q.PopFront ();
1581- if (t) {
1582- should_block = false ;
1583- }
1584-
1585- // No work pushed to us, continue attempting to block. The remaining
1586- // test is to synchronize with termination requests. If we are
1587- // shutting down and all worker threads blocked without work, that's
1588- // we are done.
1589- if (should_block) {
1590- blocked_++;
1591- if (done_ && blocked_ == num_threads_) {
1592- should_block = false ;
1593- // Almost done, but need to re-check queues.
1594- // Consider that all queues are empty and all worker threads are preempted
1595- // right after incrementing blocked_ above. Now a free-standing thread
1596- // submits work and calls destructor (which sets done_). If we don't
1597- // re-check queues, we will exit leaving the work unexecuted.
1598- if (NonEmptyQueueIndex () != -1 ) {
1599- // Note: we must not pop from queues before we decrement blocked_,
1600- // otherwise the following scenario is possible. Consider that instead
1601- // of checking for emptiness we popped the only element from queues.
1602- // Now other worker threads can start exiting, which is bad if the
1603- // work item submits other work. So we just check emptiness here,
1604- // which ensures that all worker threads exit at the same time.
1605- blocked_--;
1606- } else {
1607- should_exit = true ;
1565+ if (!td.SetBlocked ( // Pre-block test
1566+ [&]() -> bool {
1567+ bool should_block = true ;
1568+ // Check whether work was pushed to us while attempting to block. We make
1569+ // this test while holding the per-thread status lock, and after setting
1570+ // our status to ThreadStatus::Blocking.
1571+ //
1572+ // This synchronizes with ThreadPool::Schedule which pushes work to the queue
1573+ // and then tests for ThreadStatus::Blocking/Blocked (via EnsureAwake):
1574+ //
1575+ // Main thread: Worker:
1576+ // #1 Push work #A Set status blocking
1577+ // #2 Read worker status #B Check queue
1578+ // #3 Wake if blocking/blocked
1579+ //
1580+ // If #A is before #2 then main sees worker blocked and wakes
1581+ //
1582+ // If #A if after #2 then #B will see #1, and we abandon blocking
1583+ assert (!t);
1584+ t = q.PopFront ();
1585+ if (t) {
1586+ should_block = false ;
1587+ }
1588+
1589+ // No work pushed to us, continue attempting to block. The remaining
1590+ // test is to synchronize with termination requests. If we are
1591+ // shutting down and all worker threads blocked without work, that's
1592+ // we are done.
1593+ if (should_block) {
1594+ blocked_++;
1595+ if (done_ && blocked_ == num_threads_) {
1596+ should_block = false ;
1597+ // Almost done, but need to re-check queues.
1598+ // Consider that all queues are empty and all worker threads are preempted
1599+ // right after incrementing blocked_ above. Now a free-standing thread
1600+ // submits work and calls destructor (which sets done_). If we don't
1601+ // re-check queues, we will exit leaving the work unexecuted.
1602+ if (NonEmptyQueueIndex () != -1 ) {
1603+ // Note: we must not pop from queues before we decrement blocked_,
1604+ // otherwise the following scenario is possible. Consider that instead
1605+ // of checking for emptiness we popped the only element from queues.
1606+ // Now other worker threads can start exiting, which is bad if the
1607+ // work item submits other work. So we just check emptiness here,
1608+ // which ensures that all worker threads exit at the same time.
1609+ blocked_--;
1610+ } else {
1611+ should_exit = true ;
1612+ }
1613+ }
16081614 }
1609- }
1610- }
1611- return should_block;
1612- },
1613- // Post-block update (executed only if we blocked)
1614- [&]() {
1615- blocked_--;
1616- });
1615+ return should_block;
1616+ },
1617+ // Post-block update (executed only if we blocked)
1618+ [&]() {
1619+ blocked_--;
1620+ })) {
1621+ // Encountered a fatal logic error in SetBlocked
1622+ should_exit = true ;
1623+ break ;
1624+ }
16171625 // Thread just unblocked. Unless we picked up work while
16181626 // blocking, or are exiting, then either work was pushed to
16191627 // us, or it was pushed to an overloaded queue
0 commit comments