diff --git a/core/src/main/scalajvm/scalapb/zio_grpc/client/StreamingClientCallListener.scala b/core/src/main/scalajvm/scalapb/zio_grpc/client/StreamingClientCallListener.scala index 9b884f23..a9b23cef 100644 --- a/core/src/main/scalajvm/scalapb/zio_grpc/client/StreamingClientCallListener.scala +++ b/core/src/main/scalajvm/scalapb/zio_grpc/client/StreamingClientCallListener.scala @@ -9,34 +9,33 @@ class StreamingClientCallListener[Res]( prefetch: Option[Int], runtime: Runtime[Any], call: ZClientCall[?, Res], - queue: Queue[ResponseFrame[Res]], - buffered: Ref[Int] + queue: Queue[ResponseFrame[Res]] ) extends ClientCall.Listener[Res] { - private val increment = if (prefetch.isDefined) buffered.update(_ + 1) else ZIO.unit - private val fetchOne = if (prefetch.isDefined) ZIO.unit else call.request(1) - private val fetchMore = prefetch match { - case None => ZIO.unit - case Some(n) => buffered.get.flatMap(b => call.request(n - b).when(n > b)) - } + private val fetchOne = + ZIO.whenDiscard(prefetch.isEmpty)(call.request(1)) + + private def fetchMore(n: Int) = + ZIO.whenDiscard(prefetch.isDefined)(call.request(n)) private def unsafeRun(task: IO[Any, Unit]): Unit = Unsafe.unsafe(implicit u => runtime.unsafe.run(task).getOrThrowFiberFailure()) - private def handle(promise: Promise[StatusException, Unit])( - chunk: Chunk[ResponseFrame[Res]] - ) = (chunk.lastOption match { - case Some(ResponseFrame.Trailers(status, trailers)) => - val exit = if (status.isOk) Exit.unit else Exit.fail(new StatusException(status, trailers)) - promise.done(exit) *> queue.shutdown - case _ => - buffered.update(_ - chunk.size) *> fetchMore - }).as(chunk) + private def handle(promise: Promise[StatusException, Unit])(chunk: Chunk[ResponseFrame[Res]]) = + ZIO.unlessDiscard(chunk.isEmpty)(chunk.last match { + case ResponseFrame.Trailers(status, trailers) => + val exit = + if (status.isOk) Exit.unit + else Exit.fail(new StatusException(status, trailers)) + promise.done(exit) *> queue.shutdown + case _ => + fetchMore(chunk.size) + }) override def onHeaders(headers: Metadata): Unit = - unsafeRun(queue.offer(ResponseFrame.Headers(headers)) *> increment) + unsafeRun(queue.offer(ResponseFrame.Headers(headers)).unit) override def onMessage(message: Res): Unit = - unsafeRun(queue.offer(ResponseFrame.Message(message)) *> increment *> fetchOne) + unsafeRun(queue.offer(ResponseFrame.Message(message)) *> fetchOne) override def onClose(status: Status, trailers: Metadata): Unit = unsafeRun(queue.offer(ResponseFrame.Trailers(status, trailers)).unit) @@ -45,15 +44,14 @@ class StreamingClientCallListener[Res]( ZStream.fromZIO(Promise.make[StatusException, Unit]).flatMap { promise => ZStream .fromQueue(queue, prefetch.getOrElse(ZStream.DefaultChunkSize)) - .mapChunksZIO(handle(promise)) + .tapChunks(handle(promise)) .concat(ZStream.execute(promise.await)) } } object StreamingClientCallListener { def make[Res](call: ZClientCall[?, Res], prefetch: Option[Int]): UIO[StreamingClientCallListener[Res]] = for { - runtime <- ZIO.runtime[Any] - queue <- Queue.unbounded[ResponseFrame[Res]] - buffered <- Ref.make(0) - } yield new StreamingClientCallListener(prefetch, runtime, call, queue, buffered) + runtime <- ZIO.runtime[Any] + queue <- Queue.unbounded[ResponseFrame[Res]] + } yield new StreamingClientCallListener(prefetch, runtime, call, queue) } diff --git a/e2e/protos/src/main/protobuf/testservice.proto b/e2e/protos/src/main/protobuf/testservice.proto index 9165ff65..c8040b6c 100644 --- a/e2e/protos/src/main/protobuf/testservice.proto +++ b/e2e/protos/src/main/protobuf/testservice.proto @@ -5,38 +5,37 @@ package scalapb.zio_grpc; import "scalapb/scalapb.proto"; message Request { - enum Scenario { - OK = 0; - ERROR_NOW = 1; // fail with an error - ERROR_AFTER = 2; // for server streaming, error after two responses - DELAY = 3; // do not return a response. for testing cancellations - DIE = 4; // fail - UNAVAILABLE = 5; // fail with UNAVAILABLE, to test client retries - } - Scenario scenario = 1; - int32 in = 2; + enum Scenario { + OK = 0; + ERROR_NOW = 1; // fail with an error + ERROR_AFTER = 2; // for server streaming, error after two responses + DELAY = 3; // do not return a response, to test cancellations + LARGE_STREAM = 4; // stream of large elements, to test backpressure + DIE = 5; // fail + UNAVAILABLE = 6; // fail with UNAVAILABLE, to test client retries + } + Scenario scenario = 1; + int32 in = 2; } -message Response { - string out = 1; -} +message Response { string out = 1; } message ResponseTypeMapped { - option (scalapb.message).type = "scalapb.zio_grpc.WrappedString"; + option (scalapb.message).type = "scalapb.zio_grpc.WrappedString"; - string out = 1; + string out = 1; } service TestService { - rpc Unary(Request) returns (Response); + rpc Unary(Request) returns (Response); - rpc UnaryTypeMapped(Request) returns (ResponseTypeMapped); + rpc UnaryTypeMapped(Request) returns (ResponseTypeMapped); - rpc ServerStreaming(Request) returns (stream Response); + rpc ServerStreaming(Request) returns (stream Response); - rpc ServerStreamingTypeMapped(Request) returns (stream ResponseTypeMapped); + rpc ServerStreamingTypeMapped(Request) returns (stream ResponseTypeMapped); - rpc ClientStreaming(stream Request) returns (Response); + rpc ClientStreaming(stream Request) returns (Response); - rpc BidiStreaming(stream Request) returns (stream Response); + rpc BidiStreaming(stream Request) returns (stream Response); } diff --git a/e2e/src/main/scalajvm/scalapb/zio_grpc/TestServiceImpl.scala b/e2e/src/main/scalajvm/scalapb/zio_grpc/TestServiceImpl.scala index a39b9e3a..a26067c7 100644 --- a/e2e/src/main/scalajvm/scalapb/zio_grpc/TestServiceImpl.scala +++ b/e2e/src/main/scalajvm/scalapb/zio_grpc/TestServiceImpl.scala @@ -1,14 +1,11 @@ package scalapb.zio_grpc import scalapb.zio_grpc.testservice.Request -import zio.{Clock, Console, Exit, Promise, ZIO, ZLayer} +import zio._ import scalapb.zio_grpc.testservice.Response import io.grpc.{Status, StatusException} import scalapb.zio_grpc.testservice.Request.Scenario import zio.stream.{Stream, ZStream} -import zio.ZEnvironment - -import java.util.concurrent.atomic.AtomicInteger package object server { @@ -19,15 +16,19 @@ package object server { object TestServiceImpl { class Service( - requestReceived: zio.Promise[Nothing, Unit], - delayReceived: zio.Promise[Nothing, Unit], - exit: zio.Promise[Nothing, Exit[StatusException, Response]] + requestReceived: Promise[Nothing, Unit], + delayReceived: Promise[Nothing, Unit], + exit: Promise[Nothing, Exit[StatusException, Response]], + responseCounter: Ref[Int], + rpcCounter: Ref[Int] )(clock: Clock, console: Console) extends testservice.ZioTestservice.TestService { - val rpcRunsCounter: AtomicInteger = new AtomicInteger(0) - def unary(request: Request): ZIO[Any, StatusException, Response] = - (requestReceived.succeed(()) *> ZIO.succeed(rpcRunsCounter.incrementAndGet()) *> (request.scenario match { + // A response of size 100KB to saturate the byte buffers and observe backpressure. + private val largeResponse = Response("*" * 100000) + + def unary(request: Request): ZIO[Any, StatusException, Response] = { + def run(request: Request) = request.scenario match { case Scenario.OK => ZIO.succeed(Response(out = "Res" + request.in.toString)) case Scenario.ERROR_NOW => @@ -37,119 +38,128 @@ package object server { case Scenario.DIE => ZIO.die(new RuntimeException("FOO")) case Scenario.UNAVAILABLE => - ZIO.fail(Status.UNAVAILABLE.withDescription(rpcRunsCounter.get().toString).asException()) + rpcCounter.get.flatMap[Any, StatusException, Nothing] { n => + ZIO.fail(Status.UNAVAILABLE.withDescription(n.toString).asException()) + } case _ => ZIO.fail(Status.UNKNOWN.asException()) - })).onExit(exit.succeed(_)) + } + + (requestReceived.succeed(()) *> rpcCounter.incrementAndGet *> run(request) <* responseCounter.incrementAndGet) + .onExit(exit.succeed(_)) + } def unaryTypeMapped(request: Request): ZIO[Any, StatusException, WrappedString] = unary(request).map(r => WrappedString(r.out)) - def serverStreaming( - request: Request - ): ZStream[Any, StatusException, Response] = - ZStream - .acquireReleaseExitWith(requestReceived.succeed(()) *> ZIO.succeed(rpcRunsCounter.incrementAndGet())) { - (_, ex) => - ex.foldExit( - failed => - if (failed.isInterrupted || failed.isInterruptedOnly) - exit.succeed(Exit.fail(Status.CANCELLED.asException())) - else exit.succeed(Exit.fail(Status.UNKNOWN.asException())), - _ => exit.succeed(Exit.succeed(Response())) - ) - } - .flatMap { _ => - request.scenario match { - case Scenario.OK => - ZStream(Response(out = "X1"), Response(out = "X2")) - case Scenario.ERROR_NOW => - ZStream.fail(Status.INTERNAL.withDescription("FOO!").asException()) - case Scenario.ERROR_AFTER => - ZStream(Response(out = "X1"), Response(out = "X2")) ++ ZStream - .fail( - Status.INTERNAL.withDescription("FOO!").asException() - ) - case Scenario.DELAY => - ZStream( - Response(out = "X1"), - Response(out = "X2") - ) ++ ZStream.never - case Scenario.DIE => ZStream.die(new RuntimeException("FOO")) - case _ => ZStream.fail(Status.UNKNOWN.asException()) - } - } + def serverStreaming(request: Request): ZStream[Any, StatusException, Response] = { + def run(request: Request) = request.scenario match { + case Scenario.OK => + ZStream(Response(out = "X1"), Response(out = "X2")) + case Scenario.ERROR_NOW => + ZStream.fail(Status.INTERNAL.withDescription("FOO!").asException()) + case Scenario.ERROR_AFTER => + ZStream(Response(out = "X1"), Response(out = "X2")) ++ + ZStream.fail(Status.INTERNAL.withDescription("FOO!").asException()) + case Scenario.DELAY => + ZStream(Response(out = "X1"), Response(out = "X2")) ++ ZStream.never + case Scenario.LARGE_STREAM => + ZStream.fromIterator(Iterator.fill(100)(largeResponse), 1).orDie + case Scenario.DIE => + ZStream.die(new RuntimeException("FOO")) + case _ => + ZStream.fail(Status.UNKNOWN.asException()) + } + + ZStream.acquireReleaseExitWith(requestReceived.succeed(()) *> rpcCounter.incrementAndGet) { (_, ex) => + ex.foldExit( + { failed => + val status = if (failed.isInterrupted) Status.CANCELLED else Status.UNKNOWN + exit.succeed(Exit.fail(status.asException)) + }, + _ => exit.succeed(Exit.succeed(Response())) + ) + } *> run(request).tapChunks(chunk => responseCounter.update(_ + chunk.size)) + } def serverStreamingTypeMapped(request: Request): ZStream[Any, StatusException, WrappedString] = serverStreaming(request).map(r => WrappedString(r.out)) - def clientStreaming( - request: Stream[StatusException, Request] - ): ZIO[Any, StatusException, Response] = - requestReceived.succeed(()) *> ZIO.succeed(rpcRunsCounter.incrementAndGet()) *> - request - .runFoldZIO(0)((state, req) => - req.scenario match { - case Scenario.OK => ZIO.succeed(state + req.in) - case Scenario.DELAY => delayReceived.succeed(()) *> ZIO.never - case Scenario.DIE => ZIO.die(new RuntimeException("foo")) - case Scenario.ERROR_NOW => ZIO.fail((Status.INTERNAL.withDescription("InternalError").asException())) - case _: Scenario => ZIO.fail(Status.UNKNOWN.asException()) - } - ) - .map(r => Response(r.toString)) - .onExit(exit.succeed(_)) + def clientStreaming(request: Stream[StatusException, Request]): ZIO[Any, StatusException, Response] = { + def run(state: Int, request: Request) = request.scenario match { + case Scenario.OK => + ZIO.succeed(state + request.in) + case Scenario.DELAY => + delayReceived.succeed(()) *> ZIO.never + case Scenario.DIE => + ZIO.die(new RuntimeException("foo")) + case Scenario.ERROR_NOW => + ZIO.fail(Status.INTERNAL.withDescription("InternalError").asException()) + case _: Scenario => + ZIO.fail(Status.UNKNOWN.asException()) + } + + requestReceived.succeed(()) *> rpcCounter.incrementAndGet *> request + .runFoldZIO(0)(run) + .map(r => Response(r.toString)) + .zipLeft(responseCounter.incrementAndGet) + .onExit(exit.succeed(_)) + } def bidiStreaming( request: Stream[StatusException, Request] - ): Stream[StatusException, Response] = - (ZStream.fromZIO(requestReceived.succeed(()) *> ZIO.succeed(rpcRunsCounter.incrementAndGet())).drain ++ - (request.flatMap { r => - r.scenario match { - case Scenario.OK => - ZStream(Response(r.in.toString)) - .repeat(Schedule.recurs(r.in - 1)) - case Scenario.DELAY => ZStream.never - case Scenario.DIE => ZStream.die(new RuntimeException("FOO")) - case Scenario.ERROR_NOW => - ZStream.fail(Status.INTERNAL.withDescription("Intentional error").asException()) - case _ => - ZStream.fail( - Status.INVALID_ARGUMENT.withDescription(s"Got request: ${r.toProtoString}").asException() - ) - } - } ++ ZStream(Response("DONE"))) - .ensuring(exit.succeed(Exit.succeed(Response())))) - .provideEnvironment(ZEnvironment(clock, console)) + ): Stream[StatusException, Response] = { + def run(request: Request) = request.scenario match { + case Scenario.OK => + ZStream(Response(request.in.toString)).repeat(Schedule.recurs(request.in - 1)) + case Scenario.DELAY => + ZStream.never + case Scenario.DIE => + ZStream.die(new RuntimeException("FOO")) + case Scenario.ERROR_NOW => + ZStream.fail(Status.INTERNAL.withDescription("Intentional error").asException()) + case _ => + ZStream.fail( + Status.INVALID_ARGUMENT.withDescription(s"Got request: ${request.toProtoString}").asException() + ) + } - def awaitReceived = requestReceived.await + ZStream.execute(requestReceived.succeed(()) *> rpcCounter.incrementAndGet) ++ request + .flatMap(run) + .concat(ZStream(Response("DONE"))) + .tapChunks(chunk => responseCounter.update(_ + chunk.size)) + .ensuring(exit.succeed(Exit.succeed(Response()))) + .provideEnvironment(ZEnvironment(clock, console)) + } + def awaitReceived = requestReceived.await def awaitDelayReceived = delayReceived.await - - def awaitExit = exit.await + def awaitExit = exit.await + def responsesSent = responseCounter.get } def make( clock: Clock, console: Console - ): zio.IO[Nothing, TestServiceImpl.Service] = - for { - p1 <- Promise.make[Nothing, Unit] - p2 <- Promise.make[Nothing, Unit] - p3 <- Promise.make[Nothing, Exit[StatusException, Response]] - } yield new Service(p1, p2, p3)(clock, console) - - def makeFromEnv: ZIO[Any, Nothing, Service] = - for { - clock <- ZIO.clock - console <- ZIO.console - service <- make(clock, console) - } yield service + ): IO[Nothing, TestServiceImpl.Service] = for { + p1 <- Promise.make[Nothing, Unit] + p2 <- Promise.make[Nothing, Unit] + p3 <- Promise.make[Nothing, Exit[StatusException, Response]] + c1 <- Ref.make(0) + c2 <- Ref.make(0) + } yield new Service(p1, p2, p3, c1, c2)(clock, console) + + def makeFromEnv: ZIO[Any, Nothing, Service] = for { + clock <- ZIO.clock + console <- ZIO.console + service <- make(clock, console) + } yield service val live: ZLayer[Any, Nothing, TestServiceImpl] = ZLayer.scoped(makeFromEnv) - val any: ZLayer[TestServiceImpl, Nothing, TestServiceImpl] = ZLayer.environment + val any: ZLayer[TestServiceImpl, Nothing, TestServiceImpl] = + ZLayer.environment def awaitReceived: ZIO[TestServiceImpl, Nothing, Unit] = ZIO.environmentWithZIO(_.get.awaitReceived) @@ -159,5 +169,8 @@ package object server { def awaitExit: ZIO[TestServiceImpl, Nothing, Exit[StatusException, Response]] = ZIO.environmentWithZIO(_.get.awaitExit) + + def responsesSent: ZIO[TestServiceImpl, Nothing, Int] = + ZIO.environmentWithZIO(_.get.responsesSent) } } diff --git a/e2e/src/test/scalajvm/scalapb/zio_grpc/TestServiceSpec.scala b/e2e/src/test/scalajvm/scalapb/zio_grpc/TestServiceSpec.scala index 982342b5..3fcbee60 100644 --- a/e2e/src/test/scalajvm/scalapb/zio_grpc/TestServiceSpec.scala +++ b/e2e/src/test/scalajvm/scalapb/zio_grpc/TestServiceSpec.scala @@ -6,7 +6,7 @@ import scalapb.zio_grpc.server.TestServiceImpl import scalapb.zio_grpc.testservice.Request.Scenario import scalapb.zio_grpc.testservice.ZioTestservice.TestServiceClient import scalapb.zio_grpc.testservice._ -import zio.{durationInt, Fiber, Queue, ZIO, ZLayer} +import zio._ import zio.stream.{Stream, ZStream} import zio.test.Assertion._ import zio.test.TestAspect.{flaky, timeout, withLiveClock} @@ -20,7 +20,7 @@ object TestServiceSpec extends ZIOSpecDefault with CommonTestServiceSpec { // https://github.com/grpc/proposal/blob/master/A6-client-retries.md val serviceConfig = Map( "methodConfig" -> List( - Map( + Map[String, Object]( "name" -> List(Map("service" -> "scalapb.zio_grpc.TestService", "method" -> "Unary").asJava).asJava, "retryPolicy" -> Map[String, Any]( "maxAttempts" -> "5", @@ -71,21 +71,30 @@ object TestServiceSpec extends ZIOSpecDefault with CommonTestServiceSpec { } ) - def serverStreamingSuiteJVM = + def serverStreamingSuiteJVM(backpressure: Boolean) = suite("server streaming request")( test("catches client cancellations") { assertZIO(for { fb <- TestServiceClient - .serverStreaming( - Request(Request.Scenario.DELAY, in = 12) - ) - .runCollect + .serverStreaming(Request(Request.Scenario.DELAY, in = 12)) + .runDrain .fork _ <- TestServiceImpl.awaitReceived _ <- fb.interrupt exit <- TestServiceImpl.awaitExit } yield exit)(fails(hasStatusCode(Status.CANCELLED))) - } + }, + test(if (backpressure) "backpressures" else "does not backpressure") { + assertZIO(for { + _ <- TestServiceClient + .serverStreaming(Request(Request.Scenario.LARGE_STREAM, in = 100)) + .take(5) + // Sleep for real to give the server enough time to send a lot of data + .tap(_ => Live.live(ZIO.sleep(100.millis))) + .runDrain + sent <- TestServiceImpl.responsesSent + } yield sent)(if (backpressure) isLessThan(50) else equalTo(100)) + } @@ TestAspect.ignore // this test doesn't work well on CI because it depends on timing ) def clientStreamingSuite = @@ -265,7 +274,7 @@ object TestServiceSpec extends ZIOSpecDefault with CommonTestServiceSpec { unarySuite, unarySuiteJVM, serverStreamingSuite, - serverStreamingSuiteJVM, + serverStreamingSuiteJVM(backpressure = false), clientStreamingSuite, bidiStreamingSuite ).provideSomeLayer[Deps](clientLayer(None)), @@ -273,7 +282,7 @@ object TestServiceSpec extends ZIOSpecDefault with CommonTestServiceSpec { unarySuite, unarySuiteJVM, serverStreamingSuite, - serverStreamingSuiteJVM, + serverStreamingSuiteJVM(backpressure = true), clientStreamingSuite, bidiStreamingSuite ).provideSomeLayer[Deps](clientLayer(Some(2))) diff --git a/examples/fullapp/project/plugins.sbt b/examples/fullapp/project/plugins.sbt index bc7b5b56..26147705 100644 --- a/examples/fullapp/project/plugins.sbt +++ b/examples/fullapp/project/plugins.sbt @@ -8,11 +8,11 @@ val zioGrpcVersion = "0.6.2" libraryDependencies ++= Seq( "com.thesamet.scalapb.zio-grpc" %% "zio-grpc-codegen" % zioGrpcVersion, - "com.thesamet.scalapb" %% "compilerplugin" % "0.11.15" + "com.thesamet.scalapb" %% "compilerplugin" % "0.11.15" ) // For Scala.js: -addSbtPlugin("org.scala-js" % "sbt-scalajs" % "1.17.0") +addSbtPlugin("org.scala-js" % "sbt-scalajs" % "1.20.1") addSbtPlugin("org.portable-scala" % "sbt-scalajs-crossproject" % "1.2.0") diff --git a/project/Versions.scala b/project/Versions.scala index d7957bd7..f6aed2d7 100644 --- a/project/Versions.scala +++ b/project/Versions.scala @@ -1,4 +1,4 @@ object Version { - val zio = "2.1.13" + val zio = "2.1.21" val grpc = "1.64.0" } diff --git a/project/plugins.sbt b/project/plugins.sbt index 33d3879d..231c2db0 100644 --- a/project/plugins.sbt +++ b/project/plugins.sbt @@ -10,7 +10,7 @@ addSbtPlugin("com.github.sbt" % "sbt-ci-release" % "1.9.3") addSbtPlugin("com.eed3si9n" % "sbt-buildinfo" % "0.13.1") -addSbtPlugin("org.scala-js" % "sbt-scalajs" % "1.17.0") +addSbtPlugin("org.scala-js" % "sbt-scalajs" % "1.20.1") addSbtPlugin("org.portable-scala" % "sbt-scalajs-crossproject" % "1.2.0")