Skip to content

Commit 0f96b2d

Browse files
committed
Merge remote-tracking branch 'origin/main' into fs-eire/dl2
2 parents c579317 + 9115682 commit 0f96b2d

File tree

10 files changed

+331
-216
lines changed

10 files changed

+331
-216
lines changed

include/onnxruntime/core/platform/EigenNonBlockingThreadPool.h

Lines changed: 66 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -1467,11 +1467,14 @@ class ThreadPoolTempl : public onnxruntime::concurrency::ExtendedThreadPoolInter
14671467
status = ThreadStatus::Spinning;
14681468
}
14691469

1470-
void SetBlocked(std::function<bool()> should_block,
1470+
bool SetBlocked(std::function<bool()> should_block,
14711471
std::function<void()> post_block) {
14721472
std::unique_lock<std::mutex> lk(mutex);
1473-
assert(GetStatus() == ThreadStatus::Spinning);
1474-
status.store(ThreadStatus::Blocking, std::memory_order_relaxed);
1473+
auto old_status = status.exchange(ThreadStatus::Blocking, std::memory_order_seq_cst);
1474+
if (old_status != ThreadStatus::Spinning) {
1475+
// Encountered a logical error
1476+
return false;
1477+
}
14751478
if (should_block()) {
14761479
status.store(ThreadStatus::Blocked, std::memory_order_relaxed);
14771480
do {
@@ -1480,6 +1483,7 @@ class ThreadPoolTempl : public onnxruntime::concurrency::ExtendedThreadPoolInter
14801483
post_block();
14811484
}
14821485
status.store(ThreadStatus::Spinning, std::memory_order_relaxed);
1486+
return true;
14831487
}
14841488

14851489
private:
@@ -1558,62 +1562,66 @@ class ThreadPoolTempl : public onnxruntime::concurrency::ExtendedThreadPoolInter
15581562

15591563
// Attempt to block
15601564
if (!t) {
1561-
td.SetBlocked( // Pre-block test
1562-
[&]() -> bool {
1563-
bool should_block = true;
1564-
// Check whether work was pushed to us while attempting to block. We make
1565-
// this test while holding the per-thread status lock, and after setting
1566-
// our status to ThreadStatus::Blocking.
1567-
//
1568-
// This synchronizes with ThreadPool::Schedule which pushes work to the queue
1569-
// and then tests for ThreadStatus::Blocking/Blocked (via EnsureAwake):
1570-
//
1571-
// Main thread: Worker:
1572-
// #1 Push work #A Set status blocking
1573-
// #2 Read worker status #B Check queue
1574-
// #3 Wake if blocking/blocked
1575-
//
1576-
// If #A is before #2 then main sees worker blocked and wakes
1577-
//
1578-
// If #A is after #2 then #B will see #1, and we abandon blocking
1579-
assert(!t);
1580-
t = q.PopFront();
1581-
if (t) {
1582-
should_block = false;
1583-
}
1584-
1585-
// No work pushed to us, continue attempting to block. The remaining
1586-
// test is to synchronize with termination requests. If we are
1587-
// shutting down and all worker threads blocked without work, then
1588-
// we are done.
1589-
if (should_block) {
1590-
blocked_++;
1591-
if (done_ && blocked_ == num_threads_) {
1592-
should_block = false;
1593-
// Almost done, but need to re-check queues.
1594-
// Consider that all queues are empty and all worker threads are preempted
1595-
// right after incrementing blocked_ above. Now a free-standing thread
1596-
// submits work and calls destructor (which sets done_). If we don't
1597-
// re-check queues, we will exit leaving the work unexecuted.
1598-
if (NonEmptyQueueIndex() != -1) {
1599-
// Note: we must not pop from queues before we decrement blocked_,
1600-
// otherwise the following scenario is possible. Consider that instead
1601-
// of checking for emptiness we popped the only element from queues.
1602-
// Now other worker threads can start exiting, which is bad if the
1603-
// work item submits other work. So we just check emptiness here,
1604-
// which ensures that all worker threads exit at the same time.
1605-
blocked_--;
1606-
} else {
1607-
should_exit = true;
1565+
if (!td.SetBlocked( // Pre-block test
1566+
[&]() -> bool {
1567+
bool should_block = true;
1568+
// Check whether work was pushed to us while attempting to block. We make
1569+
// this test while holding the per-thread status lock, and after setting
1570+
// our status to ThreadStatus::Blocking.
1571+
//
1572+
// This synchronizes with ThreadPool::Schedule which pushes work to the queue
1573+
// and then tests for ThreadStatus::Blocking/Blocked (via EnsureAwake):
1574+
//
1575+
// Main thread: Worker:
1576+
// #1 Push work #A Set status blocking
1577+
// #2 Read worker status #B Check queue
1578+
// #3 Wake if blocking/blocked
1579+
//
1580+
// If #A is before #2 then main sees worker blocked and wakes
1581+
//
1582+
// If #A is after #2 then #B will see #1, and we abandon blocking
1583+
assert(!t);
1584+
t = q.PopFront();
1585+
if (t) {
1586+
should_block = false;
1587+
}
1588+
1589+
// No work pushed to us, continue attempting to block. The remaining
1590+
// test is to synchronize with termination requests. If we are
1591+
// shutting down and all worker threads blocked without work, then
1592+
// we are done.
1593+
if (should_block) {
1594+
blocked_++;
1595+
if (done_ && blocked_ == num_threads_) {
1596+
should_block = false;
1597+
// Almost done, but need to re-check queues.
1598+
// Consider that all queues are empty and all worker threads are preempted
1599+
// right after incrementing blocked_ above. Now a free-standing thread
1600+
// submits work and calls destructor (which sets done_). If we don't
1601+
// re-check queues, we will exit leaving the work unexecuted.
1602+
if (NonEmptyQueueIndex() != -1) {
1603+
// Note: we must not pop from queues before we decrement blocked_,
1604+
// otherwise the following scenario is possible. Consider that instead
1605+
// of checking for emptiness we popped the only element from queues.
1606+
// Now other worker threads can start exiting, which is bad if the
1607+
// work item submits other work. So we just check emptiness here,
1608+
// which ensures that all worker threads exit at the same time.
1609+
blocked_--;
1610+
} else {
1611+
should_exit = true;
1612+
}
1613+
}
16081614
}
1609-
}
1610-
}
1611-
return should_block;
1612-
},
1613-
// Post-block update (executed only if we blocked)
1614-
[&]() {
1615-
blocked_--;
1616-
});
1615+
return should_block;
1616+
},
1617+
// Post-block update (executed only if we blocked)
1618+
[&]() {
1619+
blocked_--;
1620+
})) {
1621+
// Encountered a fatal logic error in SetBlocked
1622+
should_exit = true;
1623+
break;
1624+
}
16171625
// Thread just unblocked. Unless we picked up work while
16181626
// blocking, or are exiting, then either work was pushed to
16191627
// us, or it was pushed to an overloaded queue

js/web/test/data/ops/conv.jsonc

Lines changed: 80 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -391,48 +391,48 @@
391391
}
392392
]
393393
},
394-
{
395-
"name": "conv - vectorize group - B",
396-
"operator": "Conv",
397-
"inputShapeDefinitions": "rankOnly",
398-
"opset": { "domain": "", "version": 17 },
399-
"attributes": [
400-
{ "name": "kernel_shape", "data": [2, 2], "type": "ints" },
401-
{ "name": "group", "data": 3, "type": "int" }
402-
],
403-
"cases": [
404-
{
405-
"name": "T[0]",
406-
"inputs": [
407-
{
408-
"data": [
409-
0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0,
410-
19.0, 20.0, 21.0, 22.0, 23.0, 0, 0, 0
411-
],
412-
"dims": [1, 3, 3, 3],
413-
"type": "float32"
414-
},
415-
{
416-
"data": [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0],
417-
"dims": [3, 1, 2, 2],
418-
"type": "float32"
419-
},
420-
{
421-
"data": [0.1, 0.2, 0.3],
422-
"dims": [3],
423-
"type": "float32"
424-
}
425-
],
426-
"outputs": [
427-
{
428-
"data": [27.1, 37.1, 57.1, 67.1, 293.2, 319.2, 371.2, 397.2, 847.3, 889.3, 409.3, 428.3],
429-
"dims": [1, 3, 2, 2],
430-
"type": "float32"
431-
}
432-
]
433-
}
434-
]
435-
},
394+
// {
395+
// "name": "conv - vectorize group - B",
396+
// "operator": "Conv",
397+
// "inputShapeDefinitions": "rankOnly",
398+
// "opset": { "domain": "", "version": 17 },
399+
// "attributes": [
400+
// { "name": "kernel_shape", "data": [2, 2], "type": "ints" },
401+
// { "name": "group", "data": 3, "type": "int" }
402+
// ],
403+
// "cases": [
404+
// {
405+
// "name": "T[0]",
406+
// "inputs": [
407+
// {
408+
// "data": [
409+
// 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0,
410+
// 19.0, 20.0, 21.0, 22.0, 23.0, 0, 0, 0
411+
// ],
412+
// "dims": [1, 3, 3, 3],
413+
// "type": "float32"
414+
// },
415+
// {
416+
// "data": [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0],
417+
// "dims": [3, 1, 2, 2],
418+
// "type": "float32"
419+
// },
420+
// {
421+
// "data": [0.1, 0.2, 0.3],
422+
// "dims": [3],
423+
// "type": "float32"
424+
// }
425+
// ],
426+
// "outputs": [
427+
// {
428+
// "data": [27.1, 37.1, 57.1, 67.1, 293.2, 319.2, 371.2, 397.2, 847.3, 889.3, 409.3, 428.3],
429+
// "dims": [1, 3, 2, 2],
430+
// "type": "float32"
431+
// }
432+
// ]
433+
// }
434+
// ]
435+
// },
436436
{
437437
"name": "conv - vectorize group - C",
438438
"operator": "Conv",
@@ -470,44 +470,44 @@
470470
}
471471
]
472472
},
473-
{
474-
"name": "conv - vectorize group - D",
475-
"operator": "Conv",
476-
"inputShapeDefinitions": "rankOnly",
477-
"opset": { "domain": "", "version": 17 },
478-
"attributes": [
479-
{ "name": "kernel_shape", "data": [2, 2], "type": "ints" },
480-
{ "name": "group", "data": 3, "type": "int" },
481-
{ "name": "strides", "data": [2, 2], "type": "ints" }
482-
],
483-
"cases": [
484-
{
485-
"name": "T[0] strides = [2, 2]",
486-
"inputs": [
487-
{
488-
"data": [
489-
0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0,
490-
19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0, 31.0, 32.0, 33.0, 34.0, 35.0
491-
],
492-
"dims": [1, 3, 3, 4],
493-
"type": "float32"
494-
},
495-
{
496-
"data": [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0],
497-
"dims": [3, 1, 2, 2],
498-
"type": "float32"
499-
}
500-
],
501-
"outputs": [
502-
{
503-
"data": [34, 54, 386, 438, 1122, 1206],
504-
"dims": [1, 3, 1, 2],
505-
"type": "float32"
506-
}
507-
]
508-
}
509-
]
510-
},
473+
// {
474+
// "name": "conv - vectorize group - D",
475+
// "operator": "Conv",
476+
// "inputShapeDefinitions": "rankOnly",
477+
// "opset": { "domain": "", "version": 17 },
478+
// "attributes": [
479+
// { "name": "kernel_shape", "data": [2, 2], "type": "ints" },
480+
// { "name": "group", "data": 3, "type": "int" },
481+
// { "name": "strides", "data": [2, 2], "type": "ints" }
482+
// ],
483+
// "cases": [
484+
// {
485+
// "name": "T[0] strides = [2, 2]",
486+
// "inputs": [
487+
// {
488+
// "data": [
489+
// 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0,
490+
// 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0, 31.0, 32.0, 33.0, 34.0, 35.0
491+
// ],
492+
// "dims": [1, 3, 3, 4],
493+
// "type": "float32"
494+
// },
495+
// {
496+
// "data": [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0],
497+
// "dims": [3, 1, 2, 2],
498+
// "type": "float32"
499+
// }
500+
// ],
501+
// "outputs": [
502+
// {
503+
// "data": [34, 54, 386, 438, 1122, 1206],
504+
// "dims": [1, 3, 1, 2],
505+
// "type": "float32"
506+
// }
507+
// ]
508+
// }
509+
// ]
510+
// },
511511
{
512512
"name": "conv - pointwise",
513513
"operator": "Conv",

0 commit comments

Comments
 (0)