
Commit e24af06

Introduce Task monad
This commit introduces Task, a data structure that represents a recipe, or program, for producing a value of type T (or failing with an exception). It is similar in semantics to RunnableGraph[T], but is intended as a first-class building block. It has the following properties:

- A task can have resources associated with it, which are guaranteed to be released if the task is cancelled or fails
- Tasks can be forked, so multiple tasks can run concurrently
- Such forked tasks can be cancelled

A Task can be created from a RunnableGraph that has a KillSwitch, by connecting a Source and a Sink through a KillSwitch, or directly from lambda functions.
1 parent 9b1f823 commit e24af06
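For orientation, here is a minimal usage sketch of the javadsl API introduced by this commit. It is not part of the commit itself; the Task, Runtime and Materializer names are taken from the files in the diff below, and the ActorSystem bootstrapping is standard Pekko.

// Sketch only: composing and running a Task with the API added in this commit.
import java.util.concurrent.TimeUnit;

import org.apache.pekko.actor.ActorSystem;
import org.apache.pekko.stream.Materializer;
import org.apache.pekko.task.Runtime;
import org.apache.pekko.task.javadsl.Task;

public class TaskExample {
  public static void main(String[] args) throws Exception {
    ActorSystem system = ActorSystem.create("task-example");
    Runtime runtime = new Runtime(Materializer.createMaterializer(system));

    // A Task is only a description; nothing runs until the runtime interprets it.
    Task<Integer> parsed = Task.run(() -> "25").map(Integer::parseInt);
    Task<Integer> doubled = parsed.flatMap(i -> Task.run(() -> i * 2));

    // zipPar runs two tasks concurrently and combines their results.
    Task<Integer> sum = parsed.zipPar(doubled, (a, b) -> a + b);

    System.out.println(runtime.runAsync(sum).get(2, TimeUnit.SECONDS)); // 75

    system.terminate();
  }
}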

File tree

5 files changed (+515, -2 lines)


actor/src/main/scala/org/apache/pekko/japi/function/Function.scala

Lines changed: 18 additions & 2 deletions
@@ -25,9 +25,17 @@ import scala.annotation.nowarn
 @nowarn("msg=@SerialVersionUID has no effect")
 @SerialVersionUID(1L)
 @FunctionalInterface
-trait Function[-T, +R] extends java.io.Serializable {
+trait Function[-T, +R] extends java.io.Serializable { outer =>
   @throws(classOf[Exception])
   def apply(param: T): R
+
+  /** Returns a function that applies [fn] to the result of this function. */
+  def andThen[U](fn: Function[R, U]): Function[T, U] = new Function[T, U] {
+    override def apply(param: T) = fn(outer.apply(param))
+  }
+
+  /** Returns a Scala function representation for this function. */
+  def toScala[T1 <: T, R1 >: R]: T1 => R1 = t => apply(t)
 }
 
 object Function {
@@ -98,11 +106,19 @@ trait Predicate[-T] extends java.io.Serializable {
 @nowarn("msg=@SerialVersionUID has no effect")
 @SerialVersionUID(1L)
 @FunctionalInterface
-trait Creator[+T] extends Serializable {
+trait Creator[+T] extends Serializable { outer =>
 
   /**
    * This method must return a different instance upon every call.
    */
   @throws(classOf[Exception])
   def create(): T
+
+  /** Returns a creator that applies [fn] to the result of this creator. */
+  def andThen[U](fn: Function[T, U]): Creator[U] = new Creator[U] {
+    override def create() = fn(outer.create())
+  }
+
+  /** Returns a Scala function representation for this creator. */
+  def toScala[T1 >: T]: () => T1 = () => create()
 }
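The new andThen and toScala combinators are what let the javadsl Task below bridge Java lambdas into Scala functions (see Task.run). A minimal sketch of how they behave from Java, assuming the japi interfaces compile as shown in the diff above:

// Sketch only (not part of the commit): exercising the new japi combinators from Java.
import org.apache.pekko.japi.function.Creator;
import org.apache.pekko.japi.function.Function;

public class JapiFunctionExample {
  public static void main(String[] args) throws Exception {
    // andThen chains a second japi Function onto the result of the first.
    Function<String, Integer> parse = Integer::parseInt;
    Function<String, Integer> parsePlusOne = parse.andThen(i -> i + 1);
    System.out.println(parsePlusOne.apply("41")); // 42

    // toScala bridges a japi Creator to a Scala Function0, as Task.run does internally.
    Creator<String> hello = () -> "Hello";
    scala.Function0<String> scalaHello = hello.toScala();
    System.out.println(scalaHello.apply()); // Hello
  }
}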
Lines changed: 121 additions & 0 deletions
@@ -0,0 +1,121 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.pekko.task.javadsl;

import org.apache.pekko.stream.StreamTest;
import org.apache.pekko.testkit.PekkoJUnitActorSystemResource;
import org.apache.pekko.testkit.PekkoSpec;
import org.apache.pekko.stream.Materializer;
import org.apache.pekko.Done;

import org.junit.ClassRule;
import org.junit.Test;

import org.apache.pekko.task.Runtime;
import org.apache.pekko.japi.function.Creator;
import org.apache.pekko.stream.javadsl.Sink;
import org.apache.pekko.stream.javadsl.Source;

import java.util.concurrent.TimeUnit;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.atomic.AtomicLong;

import java.util.Optional;
import java.time.Duration;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;

public class TaskTest extends StreamTest {
  private final Runtime runtime = new Runtime(Materializer.createMaterializer(system));

  public TaskTest() {
    super(actorSystemResource);
  }

  @ClassRule
  public static PekkoJUnitActorSystemResource actorSystemResource =
      new PekkoJUnitActorSystemResource("TaskTest", PekkoSpec.testConf());

  private <T> T run(Task<T> task) throws Exception {
    return runtime.runAsync(task).get(2, TimeUnit.SECONDS);
  }

  @Test
  public void can_run_task_from_lambda() throws Exception {
    assertEquals("Hello", run(Task.run(() -> "Hello")));
  }

  @Test
  public void can_map() throws Exception {
    assertEquals(25, run(Task.run(() -> "25").map(Integer::parseInt)).intValue());
  }

  @Test
  public void can_flatMap_to_run() throws Exception {
    assertEquals(25, run(Task.run(() -> "25").flatMap(s -> Task.run(() -> Integer.parseInt(s)))).intValue());
  }

  @Test
  public void can_zipPar_two_tasks() throws Exception {
    Task<String> task = Task.run(() -> {
      Thread.sleep(100);
      return "Hello";
    });
    long start = System.currentTimeMillis();
    assertEquals("HelloHello", run(task.zipPar(task, (s1, s2) -> s1 + s2)));
    long end = System.currentTimeMillis();
    // FIXME there's probably a less flaky way to test this later on.
    assertTrue((end - start) < 150);
  }

  @Test
  public void can_cancel_forked_task() throws Exception {
    AtomicLong check = new AtomicLong();
    Task<Long> task = Task.run(() -> Thread.sleep(500)).map(d -> check.incrementAndGet());
    run(task.fork().flatMap(fiber ->
        fiber.cancel().map(cancelled ->
            "cancelled"
        )
    ));
    assertEquals(0, check.get());
  }

  @Test(expected = ExecutionException.class)
  public void joining_cancelled_fiber_yields_exception() throws Exception {
    Task<Long> task = Task.run(() -> Thread.sleep(500)).map(d -> 42L);
    run(task.fork().flatMap(fiber ->
        fiber.cancel().flatMap(cancelled ->
            fiber.join()
        )
    ));
  }

  @Test
  public void can_run_graph() throws Exception {
    assertEquals(Optional.of("hello"),
        run(Task.connectCancellable(Source.single("hello"), Sink.headOption()).flatMap(fiber -> fiber.join())));
  }

  @Test
  public void can_cancel_graph() throws Exception {
    assertEquals(Done.getInstance(),
        run(Task.connectCancellable(Source.tick(Duration.ofSeconds(1), Duration.ofSeconds(1), "hello"), Sink.headOption()).flatMap(fiber -> fiber.cancel())));
  }
}
Lines changed: 39 additions & 0 deletions
@@ -0,0 +1,39 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.pekko.task.javadsl;

import org.apache.pekko.task.FiberImpl;
import org.apache.pekko.task.JoinDef;
import org.apache.pekko.task.CancelDef;
import org.apache.pekko.Done;

/** A handle to a forked, running Task, which can be joined or cancelled. */
public class Fiber<T> {
  private final FiberImpl<T> impl;

  public Fiber(FiberImpl<T> impl) {
    this.impl = impl;
  }

  /** Returns a task that completes with this fiber's result (or failure). */
  public Task<T> join() {
    return new Task<>(new JoinDef<>(impl));
  }

  /** Returns a task that cancels this fiber, completing with Done. */
  public Task<Done> cancel() {
    return new Task<>(new CancelDef<>(impl)).asDone();
  }
}
Lines changed: 136 additions & 0 deletions
@@ -0,0 +1,136 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.pekko.task.javadsl;

import java.util.concurrent.CompletionStage;

import org.apache.pekko.Done;
import org.apache.pekko.japi.function.Creator;
import org.apache.pekko.japi.function.Effect;
import org.apache.pekko.japi.function.Function2;
import org.apache.pekko.japi.function.Function;
import org.apache.pekko.stream.KillSwitch;
import org.apache.pekko.stream.KillSwitches;
import org.apache.pekko.stream.javadsl.Keep;
import org.apache.pekko.stream.javadsl.Sink;
import org.apache.pekko.stream.javadsl.Source;
import org.apache.pekko.task.FiberImpl;
import org.apache.pekko.task.FlatMapDef;
import org.apache.pekko.task.ForkDef;
import org.apache.pekko.task.GraphDef;
import org.apache.pekko.task.MapDef;
import org.apache.pekko.task.TaskDef;
import org.apache.pekko.task.ValueDef;

import scala.Tuple2;
import scala.util.Failure;
import scala.util.Success;
import scala.util.Try;

import static scala.jdk.javaapi.FutureConverters.*;

public class Task<T> extends org.apache.pekko.task.Task<T> {
  public static <T> Task<T> succeed(T value) {
    return run(() -> value);
  }

  public static <T> Task<T> run(Creator<T> fn) {
    return new Task<>(new ValueDef<>(fn.andThen(t -> new Success<>(t)).toScala()));
  }

  public static Task<Done> run(Effect fn) {
    return run(() -> {
      fn.apply();
      return Done.getInstance();
    });
  }

  /** Returns a Task that connects the given source through a KillSwitch to the given sink. */
  public static <A, T> Task<Fiber<T>> connectCancellable(Source<A, ?> source, Sink<A, ? extends CompletionStage<T>> sink) {
    return connect(source.viaMat(KillSwitches.single(), Keep.right()), sink);
  }

  /** Returns a Task that runs the given cancellable source through the given sink. */
  public static <A, T> Task<Fiber<T>> connect(Source<A, ? extends KillSwitch> source, Sink<A, ? extends CompletionStage<T>> sink) {
    Task<FiberImpl<T>> res = new Task<FiberImpl<T>>(new GraphDef<T>(source.toMat(sink, (killswitch, cs) -> Tuple2.apply(killswitch, asScala(cs)))));

    return res.map(Fiber::new);
  }

  private final TaskDef<T> definition;

  Task(TaskDef<T> definition) {
    this.definition = definition;
  }

  public TaskDef<T> definition() {
    return definition;
  }

  /** Returns a task that maps this task's value through the given function. */
  public <U> Task<U> map(Function<? super T, ? extends U> fn) {
    return new Task<U>(new MapDef<T, U>(definition, res -> {
      if (res.isFailure()) {
        return new Failure<>(res.failed().get());
      } else {
        return tryOf(() -> fn.apply(res.get()));
      }
    }));
  }

  public <U> Task<U> as(U value) {
    return map(t -> value);
  }

  public Task<Done> asDone() {
    return as(Done.getInstance());
  }

  /** Returns a task that maps this task's value through the given function, and runs the resulting task after this one. */
  public <U> Task<U> flatMap(Function<? super T, Task<? extends U>> fn) {
    return new Task<U>(new FlatMapDef<T, U>(definition, res -> {
      if (res.isFailure()) {
        return new ValueDef<>(() -> new Failure<>(res.failed().get()));
      } else {
        try {
          return TaskDef.narrow(fn.apply(res.get()).definition());
        } catch (Exception x) {
          return new ValueDef<>(() -> new Failure<>(x));
        }
      }
    }));
  }

  public Task<Fiber<T>> fork() {
    return new Task<>(new ForkDef<>(definition)).map(Fiber::new);
  }

  public <U, R> Task<R> zipPar(Task<? extends U> that, Function2<? super T, ? super U, ? extends R> combine) {
    return that.fork().flatMap(fiber ->
        this.flatMap(t ->
            fiber.join().map(u ->
                combine.apply(t, u)
            )
        )
    );
  }

  private static <T> Try<T> tryOf(Creator<T> fn) {
    return Try.apply(fn.toScala());
  }
}
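To illustrate the stream integration, here is a sketch (again not part of the commit) of how connectCancellable turns a Source-to-Sink graph with a KillSwitch into a cancellable Fiber, mirroring the can_run_graph and can_cancel_graph tests above; the ActorSystem and Runtime setup is assumed.

// Sketch only: running and cancelling streams as fibers.
import java.time.Duration;
import java.util.Optional;
import java.util.concurrent.TimeUnit;

import org.apache.pekko.Done;
import org.apache.pekko.actor.ActorSystem;
import org.apache.pekko.stream.Materializer;
import org.apache.pekko.stream.javadsl.Sink;
import org.apache.pekko.stream.javadsl.Source;
import org.apache.pekko.task.Runtime;
import org.apache.pekko.task.javadsl.Task;

public class GraphTaskExample {
  public static void main(String[] args) throws Exception {
    ActorSystem system = ActorSystem.create("graph-task-example");
    Runtime runtime = new Runtime(Materializer.createMaterializer(system));

    // connectCancellable inserts a KillSwitch between source and sink;
    // joining the fiber yields the sink's materialized value.
    Task<Optional<String>> first =
        Task.connectCancellable(Source.single("hello"), Sink.<String>headOption())
            .flatMap(fiber -> fiber.join());

    // An infinite tick stream: cancel() completes with Done once the KillSwitch fires.
    Task<Done> cancelled =
        Task.connectCancellable(
                Source.tick(Duration.ofSeconds(1), Duration.ofSeconds(1), "tick"),
                Sink.<String>headOption())
            .flatMap(fiber -> fiber.cancel());

    System.out.println(runtime.runAsync(first).get(2, TimeUnit.SECONDS));     // Optional[hello]
    System.out.println(runtime.runAsync(cancelled).get(2, TimeUnit.SECONDS)); // Done

    system.terminate();
  }
}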
