Skip to content

Commit 1fb28ec

Browse files
committed
Ensure structured concurrency exits when all tasks have completed.
1 parent 899292d commit 1fb28ec

File tree

4 files changed

+108
-5
lines changed

4 files changed

+108
-5
lines changed
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
/*****************************************************************************
2+
* ------------------------------------------------------------------------- *
3+
* Licensed under the Apache License, Version 2.0 (the "License"); *
4+
* you may not use this file except in compliance with the License. *
5+
* You may obtain a copy of the License at *
6+
* *
7+
* http://www.apache.org/licenses/LICENSE-2.0 *
8+
* *
9+
* Unless required by applicable law or agreed to in writing, software *
10+
* distributed under the License is distributed on an "AS IS" BASIS, *
11+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. *
12+
* See the License for the specific language governing permissions and *
13+
* limitations under the License. *
14+
*****************************************************************************/
15+
package com.google.mu.util.concurrent;
16+
17+
import java.util.concurrent.Phaser;
18+
19+
/** Helper to ensure that all started tasks must have run to completion. */
20+
final class Completion implements AutoCloseable {
21+
private final Phaser phaser = new Phaser(1);
22+
23+
void run(Runnable task) {
24+
wrap(task).run();
25+
}
26+
27+
Runnable wrap(Runnable task) {
28+
phaser.register();
29+
return () -> {
30+
try {
31+
task.run();
32+
} finally {
33+
phaser.arrive();
34+
}
35+
};
36+
}
37+
38+
@Override public void close() {
39+
phaser.arriveAndAwaitAdvance();
40+
}
41+
}

mug/src/main/java/com/google/mu/util/concurrent/Fanout.java

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -481,16 +481,18 @@ Scope add(Runnable... tasks) {
481481
}
482482

483483
void run() throws StructuredConcurrencyInterruptedException {
484-
try {
485-
withUnlimitedConcurrency().parallelize(runnables.stream());
484+
try (Completion completion = new Completion()){
485+
withUnlimitedConcurrency().parallelize(runnables.stream().map(completion::wrap));
486486
} catch (InterruptedException e) {
487487
throw new StructuredConcurrencyInterruptedException(e);
488488
}
489489
}
490490

491491
@Deprecated
492492
void runUninterruptibly() {
493-
withUnlimitedConcurrency().parallelizeUninterruptibly(runnables.stream());
493+
try (Completion completion = new Completion()){
494+
withUnlimitedConcurrency().parallelizeUninterruptibly(runnables.stream().map(completion::wrap));
495+
}
494496
}
495497
}
496498

mug/src/main/java/com/google/mu/util/concurrent/Parallelizer.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -478,10 +478,10 @@ public void parallelizeUninterruptibly(Stream<? extends Runnable> tasks) {
478478
inputs -> {
479479
List<O> outputs = new ArrayList<>(inputs.size());
480480
outputs.addAll(Collections.nCopies(inputs.size(), null));
481-
try {
481+
try (Completion completion = new Completion()){
482482
parallelize(
483483
IntStream.range(0, inputs.size()).boxed(),
484-
i -> outputs.set(i, concurrentFunction.apply(inputs.get(i))));
484+
i -> completion.run(() -> outputs.set(i, concurrentFunction.apply(inputs.get(i)))));
485485
} catch (InterruptedException e) {
486486
throw new StructuredConcurrencyInterruptedException(e);
487487
}
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
package com.google.mu.util.concurrent;
2+
3+
import static com.google.common.truth.Truth.assertThat;
4+
5+
import java.util.concurrent.atomic.AtomicBoolean;
6+
7+
import org.junit.Test;
8+
import org.junit.runner.RunWith;
9+
import org.junit.runners.JUnit4;
10+
11+
@RunWith(JUnit4.class)
12+
public class CompletionTest {
13+
@Test public void noTaskStarted() throws Exception {
14+
try (Completion completion = new Completion()) {}
15+
}
16+
17+
@Test public void singleTask_succeeded() throws Exception {
18+
try (Completion completion = new Completion()) {
19+
completion.run(() -> {});
20+
}
21+
}
22+
23+
@Test public void singleTask_failed() throws Exception {
24+
try {
25+
try (Completion completion = new Completion()) {
26+
completion.run(() -> {
27+
throw new RuntimeException("test");
28+
});
29+
}
30+
} catch (RuntimeException e) {
31+
assertThat(e).hasMessageThat().contains("test");
32+
}
33+
}
34+
35+
@Test public void twoTasks_succeeded() throws Exception {
36+
AtomicBoolean done = new AtomicBoolean();
37+
try (Completion completion = new Completion()) {
38+
completion.run(() -> {});
39+
new Thread(completion.wrap(() -> {
40+
done.set(true);
41+
})).start();
42+
}
43+
assertThat(done.get()).isTrue();
44+
}
45+
46+
@Test public void twoTasks_failed() throws Exception {
47+
AtomicBoolean done = new AtomicBoolean();
48+
try (Completion completion = new Completion()) {
49+
new Thread(completion.wrap(() -> {
50+
done.set(true);
51+
})).start();
52+
completion.run(() -> {
53+
throw new RuntimeException("test");
54+
});
55+
} catch (RuntimeException e) {
56+
assertThat(e).hasMessageThat().contains("test");
57+
}
58+
assertThat(done.get()).isTrue();
59+
}
60+
}

0 commit comments

Comments
 (0)