diff --git a/kroto-plus-coroutines/build.gradle b/kroto-plus-coroutines/build.gradle index dc16d86..c6b90ef 100644 --- a/kroto-plus-coroutines/build.gradle +++ b/kroto-plus-coroutines/build.gradle @@ -80,6 +80,16 @@ protobuf { } } + test { + jacoco { + // These deprecated extensions are no longer covered by tests since + // they're not referenced in generated sources anymore. They will be + // removed in the future but remain as to provide backwards compatibility + // with existing generated sources. + excludes += ['com/github/marcoferrer/krotoplus/coroutines/server/ServerCallsKt.class'] + } + } + jacoco { toolVersion = "0.8.5" } diff --git a/kroto-plus-coroutines/src/main/kotlin/com/github/marcoferrer/krotoplus/coroutines/call/CallExts.kt b/kroto-plus-coroutines/src/main/kotlin/com/github/marcoferrer/krotoplus/coroutines/call/CallExts.kt index b419e29..002c771 100644 --- a/kroto-plus-coroutines/src/main/kotlin/com/github/marcoferrer/krotoplus/coroutines/call/CallExts.kt +++ b/kroto-plus-coroutines/src/main/kotlin/com/github/marcoferrer/krotoplus/coroutines/call/CallExts.kt @@ -16,7 +16,9 @@ package com.github.marcoferrer.krotoplus.coroutines.call +import com.github.marcoferrer.krotoplus.coroutines.CALL_OPTION_COROUTINE_CONTEXT import com.github.marcoferrer.krotoplus.coroutines.asContextElement +import io.grpc.CallOptions import io.grpc.ClientCall import io.grpc.MethodDescriptor import io.grpc.Status @@ -88,6 +90,16 @@ internal fun Throwable.toRpcException(): Throwable = internal fun MethodDescriptor<*, *>.getCoroutineName(): CoroutineName = CoroutineName(fullMethodName) +internal fun newRpcScope( + callOptions: CallOptions, + methodDescriptor: MethodDescriptor<*, *>, + grpcContext: io.grpc.Context = io.grpc.Context.current() +): CoroutineScope = newRpcScope( + callOptions.getOption(CALL_OPTION_COROUTINE_CONTEXT), + methodDescriptor, + grpcContext +) + internal fun newRpcScope( coroutineContext: CoroutineContext, methodDescriptor: MethodDescriptor<*, *>, diff --git a/kroto-plus-coroutines/src/main/kotlin/com/github/marcoferrer/krotoplus/coroutines/call/CallReadyObserver.kt b/kroto-plus-coroutines/src/main/kotlin/com/github/marcoferrer/krotoplus/coroutines/call/CallReadyObserver.kt new file mode 100644 index 0000000..cd89eca --- /dev/null +++ b/kroto-plus-coroutines/src/main/kotlin/com/github/marcoferrer/krotoplus/coroutines/call/CallReadyObserver.kt @@ -0,0 +1,81 @@ +/* + * Copyright 2019 Kroto+ Contributors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.github.marcoferrer.krotoplus.coroutines.call + +import io.grpc.stub.CallStreamObserver +import kotlinx.coroutines.channels.Channel + +internal fun CallStreamObserver<*>.newCallReadyObserver(): CallReadyObserver = + CallReadyObserver(this) + +internal class CallReadyObserver( + callStreamObserver: CallStreamObserver<*> +) : Runnable { + + private val notificationChannel = Channel(1) + + private var hasRan = false + + private val callStreamObserver: CallStreamObserver<*> = callStreamObserver + .apply { setOnReadyHandler(this@CallReadyObserver) } + + suspend fun isReady(): Boolean { + // Suspend until the call is ready. + // If the call is cancelled before then, an exception + // will be thrown. + awaitReady() + return true + } + + suspend fun awaitReady() { + // If our handler hasnt run yet we will want to + // suspend immediately since its early enough that + // calls to `callStreamObserver.isReady` will throw + // and NPE + if(!hasRan) + notificationChannel.receive() + // By the time the on ready handler is invoked, calls + // to `callStreamObserver.isReady` could return false + // Here we will continue to poll notifications until + // the call is ready. For more details reference the + // documentation for `callStreamObserver.setOnReadyHandler()` + while(!callStreamObserver.isReady){ + notificationChannel.receive() + } + } + + fun cancel(t: Throwable? = null){ + notificationChannel.close(t) + } + + private fun signalReady() = notificationChannel.offer(READY_TOKEN) + + @Deprecated( + message = "This method should not be called directly", + level = DeprecationLevel.HIDDEN) + override fun run() { + if(!hasRan) { + hasRan = true + } + signalReady() + } + + companion object{ + private object READY_TOKEN + } + +} \ No newline at end of file diff --git a/kroto-plus-coroutines/src/main/kotlin/com/github/marcoferrer/krotoplus/coroutines/client/ClientCalls.kt b/kroto-plus-coroutines/src/main/kotlin/com/github/marcoferrer/krotoplus/coroutines/client/ClientCalls.kt index cb9b192..e2541c0 100644 --- a/kroto-plus-coroutines/src/main/kotlin/com/github/marcoferrer/krotoplus/coroutines/client/ClientCalls.kt +++ b/kroto-plus-coroutines/src/main/kotlin/com/github/marcoferrer/krotoplus/coroutines/client/ClientCalls.kt @@ -18,33 +18,32 @@ package com.github.marcoferrer.krotoplus.coroutines.client import com.github.marcoferrer.krotoplus.coroutines.CALL_OPTION_COROUTINE_CONTEXT import com.github.marcoferrer.krotoplus.coroutines.call.bindScopeCancellationToCall +import com.github.marcoferrer.krotoplus.coroutines.call.completeSafely import com.github.marcoferrer.krotoplus.coroutines.call.newRpcScope import com.github.marcoferrer.krotoplus.coroutines.withCoroutineContext import io.grpc.CallOptions import io.grpc.MethodDescriptor -import io.grpc.Status import io.grpc.stub.AbstractStub -import io.grpc.stub.ClientCallStreamObserver import io.grpc.stub.ClientCalls.asyncBidiStreamingCall import io.grpc.stub.ClientCalls.asyncClientStreamingCall import io.grpc.stub.ClientCalls.asyncServerStreamingCall import io.grpc.stub.ClientCalls.asyncUnaryCall import io.grpc.stub.ClientResponseObserver import kotlinx.coroutines.CancellableContinuation -import kotlinx.coroutines.CancellationException +import kotlinx.coroutines.CompletableDeferred +import kotlinx.coroutines.CoroutineScope +import kotlinx.coroutines.Deferred import kotlinx.coroutines.Job -import kotlinx.coroutines.cancel import kotlinx.coroutines.channels.Channel -import kotlinx.coroutines.channels.ProducerScope import kotlinx.coroutines.channels.ReceiveChannel +import kotlinx.coroutines.channels.SendChannel +import kotlinx.coroutines.channels.actor import kotlinx.coroutines.flow.buffer import kotlinx.coroutines.flow.callbackFlow -import kotlinx.coroutines.flow.emitAll +import kotlinx.coroutines.flow.collect import kotlinx.coroutines.flow.flow -import kotlinx.coroutines.flow.onEach import kotlinx.coroutines.flow.produceIn import kotlinx.coroutines.suspendCancellableCoroutine -import java.util.concurrent.atomic.AtomicBoolean import kotlin.coroutines.resume import kotlin.coroutines.resumeWithException @@ -121,41 +120,27 @@ public fun clientCallServerStreaming( callOptions: CallOptions = CallOptions.DEFAULT ): ReceiveChannel { - val observerAdapter = ResponseObserverChannelAdapter() + val responseObserver = ServerStreamingResponseObserver() val rpcScope = newRpcScope(callOptions.getOption(CALL_OPTION_COROUTINE_CONTEXT), method) - val responseFlow = callbackFlow flow@ { - observerAdapter.scope = this + val responseFlow = callbackFlow flow@{ + responseObserver.responseProducerScope = this val call = grpcChannel .newCall(method, callOptions.withCoroutineContext(coroutineContext)) .beforeCancellation { message, cause -> - observerAdapter.beforeCallCancellation(message, cause) + responseObserver.beforeCallCancellation(message, cause) } - val job = coroutineContext[Job]!! - // Start the RPC Call - asyncServerStreamingCall(call, request, observerAdapter) - - // If our parent job is cancelled before we can - // start the call then we need to propagate the - // cancellation to the underlying call - job.invokeOnCompletion { error -> - // Our job can be cancelled after completion due to the inner machinery - // of kotlinx.coroutines.flow.Channels.kt.emitAll(). Its final operation - // after receiving a close is a call to channel.cancelConsumed(cause). - // Even if it doesnt encounter an exception it will cancel with null. - // We will only invoke cancel on the call - if(job.isCancelled && observerAdapter.isActive){ - call.cancel(MESSAGE_CLIENT_CANCELLED_CALL, error) - } - } + asyncServerStreamingCall(call, request, responseObserver) + + bindScopeCompletionToCall(responseObserver) suspendCancellableCoroutine { cont -> // Here we need to handle not only parent job cancellation // but calls to `channel.cancel(...)` as well. cont.invokeOnCancellation { error -> - if (observerAdapter.isActive) { + if (responseObserver.isActive) { call.cancel(MESSAGE_CLIENT_CANCELLED_CALL, error) } } @@ -168,12 +153,12 @@ public fun clientCallServerStreaming( } // Use buffer UNLIMITED so that we dont drop any inbound messages - return flow { emitAll(responseFlow.buffer(Channel.UNLIMITED)) } - .onEach { - if(observerAdapter.isActive){ - observerAdapter.callStreamObserver.request(1) - } + return flow { + responseFlow.buffer(Channel.UNLIMITED).collect{ message -> + emit(message) + responseObserver.callStreamObserver.request(1) } + } // We use buffer RENDEZVOUS on the outer flow so that our // `onEach` operator is only invoked each time a message is // collected instead of each time a message is received from @@ -188,22 +173,26 @@ public fun > T.clientCallBidiStreaming( ): ClientBidiCallChannel = clientCallBidiStreaming(method, channel, callOptions) + + public fun clientCallBidiStreaming( method: MethodDescriptor, channel: io.grpc.Channel, callOptions: CallOptions = CallOptions.DEFAULT ): ClientBidiCallChannel { - val initialContext = callOptions.getOption(CALL_OPTION_COROUTINE_CONTEXT) - with(newRpcScope(initialContext, method)) { + val rpcScope = newRpcScope(callOptions, method) + val responseObserver = BidiStreamingResponseObserver(rpcScope) - val call = channel.newCall(method, callOptions.withCoroutineContext(coroutineContext)) - val callChannel = ClientBidiCallChannelImpl(coroutineContext) - asyncBidiStreamingCall(call, callChannel) - bindScopeCancellationToCall(call) + val call = channel + .newCall(method, callOptions.withCoroutineContext(rpcScope.coroutineContext)) + .beforeCancellation { message, cause -> + responseObserver.beforeCallCancellation(message, cause) + } - return callChannel - } + asyncBidiStreamingCall(call, responseObserver) + + return responseObserver.asClientBidiCallChannel() } public fun > T.clientCallClientStreaming( @@ -216,14 +205,62 @@ public fun clientCallClientStreaming( channel: io.grpc.Channel, callOptions: CallOptions = CallOptions.DEFAULT ): ClientStreamingCallChannel { - val initialContext = callOptions.getOption(CALL_OPTION_COROUTINE_CONTEXT) - with(newRpcScope(initialContext, method)) { - val call = channel.newCall(method, callOptions.withCoroutineContext(coroutineContext)) - val callChannel = ClientStreamingCallChannelImpl(coroutineContext) - asyncClientStreamingCall(call, callChannel) - bindScopeCancellationToCall(call) - return callChannel + val rpcScope = newRpcScope(callOptions, method) + val response = CompletableDeferred(parent = rpcScope.coroutineContext[Job]) + val requestChannel = rpcScope.actor(capacity = Channel.RENDEZVOUS) { + val responseObserver = ClientStreamingResponseObserver( + this@actor.channel, response + ) + + val call = channel + .newCall(method, callOptions.withCoroutineContext(coroutineContext)) + .beforeCancellation { message, cause -> + responseObserver.beforeCallCancellation(message, cause) + } + + val requestObserver = asyncClientStreamingCall(call, responseObserver) + + bindScopeCompletionToCall(responseObserver) + + var error: Throwable? = null + try { + val iter = this@actor.channel.iterator() + while(responseObserver.isReady() && iter.hasNext()){ + requestObserver.onNext(iter.next()) + } + } catch (e: Throwable) { + error = e + } finally { + if(responseObserver.isActive) { + requestObserver.completeSafely(error, convertError = false) + } + } + } + + return object : ClientStreamingCallChannel, SendChannel by requestChannel { + override val requestChannel: SendChannel + get() = requestChannel + override val response: Deferred + get() = response } } +internal fun CoroutineScope.bindScopeCompletionToCall( + observer: StatefulClientResponseObserver<*, *> +){ + val job = coroutineContext[Job]!! + // If our parent job is cancelled before we can + // start the call then we need to propagate the + // cancellation to the underlying call + job.invokeOnCompletion { error -> + // Our job can be cancelled after completion due to the inner machinery + // of kotlinx.coroutines.flow.Channels.kt.emitAll(). Its final operation + // after receiving a close is a call to channel.cancelConsumed(cause). + // Even if it doesnt encounter an exception it will cancel with null. + // We will only invoke cancel on the call + if (job.isCancelled && observer.isActive) { + observer.callStreamObserver.cancel(MESSAGE_CLIENT_CANCELLED_CALL, error) + } + } +} \ No newline at end of file diff --git a/kroto-plus-coroutines/src/main/kotlin/com/github/marcoferrer/krotoplus/coroutines/client/ClientResponseObservers.kt b/kroto-plus-coroutines/src/main/kotlin/com/github/marcoferrer/krotoplus/coroutines/client/ClientResponseObservers.kt new file mode 100644 index 0000000..5a5f97e --- /dev/null +++ b/kroto-plus-coroutines/src/main/kotlin/com/github/marcoferrer/krotoplus/coroutines/client/ClientResponseObservers.kt @@ -0,0 +1,267 @@ +/* + * Copyright 2019 Kroto+ Contributors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.github.marcoferrer.krotoplus.coroutines.client + +import com.github.marcoferrer.krotoplus.coroutines.call.CallReadyObserver +import com.github.marcoferrer.krotoplus.coroutines.call.completeSafely +import com.github.marcoferrer.krotoplus.coroutines.call.newCallReadyObserver +import io.grpc.Status +import io.grpc.stub.ClientCallStreamObserver +import io.grpc.stub.ClientResponseObserver +import kotlinx.coroutines.CancellationException +import kotlinx.coroutines.CompletableDeferred +import kotlinx.coroutines.CoroutineScope +import kotlinx.coroutines.Job +import kotlinx.coroutines.cancel +import kotlinx.coroutines.channels.Channel +import kotlinx.coroutines.channels.ProducerScope +import kotlinx.coroutines.channels.ReceiveChannel +import kotlinx.coroutines.channels.SendChannel +import kotlinx.coroutines.channels.actor +import kotlinx.coroutines.channels.consumeEach +import kotlinx.coroutines.flow.buffer +import kotlinx.coroutines.flow.flow +import kotlinx.coroutines.flow.produceIn +import java.util.concurrent.atomic.AtomicBoolean +import kotlin.coroutines.coroutineContext + +internal abstract class StatefulClientResponseObserver : ClientResponseObserver { + + protected val isAborted = AtomicBoolean() + + protected val isCompleted = AtomicBoolean() + + val isActive: Boolean + get() = !(isAborted.get() || isCompleted.get()) + + lateinit var callStreamObserver: ClientCallStreamObserver + protected set + + fun cancel(message: String?, cause: Throwable?) { + if(isAborted.compareAndSet(false, true)){ + callStreamObserver.cancel(message, cause) + } + } +} + +internal class ServerStreamingResponseObserver: StatefulClientResponseObserver() { + + lateinit var responseProducerScope: ProducerScope + + override fun beforeStart(requestStream: ClientCallStreamObserver) { + require(::responseProducerScope.isInitialized){ "Producer scope was not initialized" } + callStreamObserver = requestStream.apply { disableAutoInboundFlowControl() } + } + + fun beforeCallCancellation(message: String?, cause: Throwable?){ + if(isAborted.compareAndSet(false, true)) { + if(cause is CancellationException){ + responseProducerScope.cancel(cause) + }else { + val ex = Status.CANCELLED + .withDescription(message) + .withCause(cause) + .asRuntimeException() + responseProducerScope.close(ex) + } + } + } + + override fun onNext(value: RespT) { + responseProducerScope.offer(value) + } + + override fun onError(t: Throwable) { + isAborted.set(true) + responseProducerScope.close(t) + responseProducerScope.cancel(CancellationException(t.message,t)) + } + + override fun onCompleted() { + isCompleted.set(true) + responseProducerScope.close() + } +} + +internal class ClientStreamingResponseObserver( + private val requestChannel: Channel, + private val response: CompletableDeferred +) : StatefulClientResponseObserver() { + + private lateinit var readyObserver: CallReadyObserver + + override fun beforeStart(requestStream: ClientCallStreamObserver) { + callStreamObserver = requestStream.apply { disableAutoInboundFlowControl() } + readyObserver = callStreamObserver.newCallReadyObserver() + } + + fun beforeCallCancellation(message: String?, cause: Throwable?){ + if(isAborted.compareAndSet(false, true)) { + if(cause is CancellationException) { + response.cancel(cause) + }else { + val ex = Status.CANCELLED + .withDescription(message) + .withCause(cause) + .asRuntimeException() + + response.completeExceptionally(ex) + } + } + } + + suspend fun isReady() = readyObserver.isReady() + + override fun onNext(value: RespT) { + response.complete(value) + } + + override fun onError(t: Throwable) { + isAborted.set(true) + if(t is CancellationException){ + requestChannel.cancel(t) + response.cancel(t) + }else{ + requestChannel.close(t) + response.completeExceptionally(t) + } + readyObserver.cancel(t) + } + + override fun onCompleted() { + isCompleted.set(true) + require(response.isCompleted) { + "Stream was completed before onNext was called" + } + } +} + +internal class BidiStreamingResponseObserver( + private val rpcScope: CoroutineScope +): StatefulClientResponseObserver() { + + private lateinit var readyObserver: CallReadyObserver + + private val inboundChannel: Channel = Channel(Channel.UNLIMITED) + + lateinit var requestChannel: SendChannel + private set + + lateinit var responseChannel: ReceiveChannel + private set + + override fun beforeStart(requestStream: ClientCallStreamObserver) { + callStreamObserver = requestStream.apply { disableAutoInboundFlowControl() } + readyObserver = callStreamObserver.newCallReadyObserver() + + requestChannel = rpcScope.actor(capacity = Channel.RENDEZVOUS) { + + var error: Throwable? = null + try { + // We use an iterator to prevent prematurely + // consuming a message from the channel if the + // call is not ready for one. This keeps our + // in-memory buffer from being increased by 1 + val iter = this.channel.iterator() + while(readyObserver.isReady() && iter.hasNext()){ + callStreamObserver.onNext(iter.next()) + } + } catch (e: Throwable) { + error = e + } finally { + if(isActive) { + callStreamObserver.completeSafely(error, convertError = false) + } + } + } + + responseChannel = flow { + var error: Throwable? = null + try { + inboundChannel.consumeEach { message -> + emit(message) + callStreamObserver.request(1) + } + }catch (e: Throwable){ + error = e + throw e + } finally { + if(error != null && coroutineContext[Job]!!.isCancelled && isActive){ + val status = Status.CANCELLED + .withDescription(MESSAGE_CLIENT_CANCELLED_CALL) + .withCause(error) + .asRuntimeException() + requestChannel.close(status) + callStreamObserver.cancel(MESSAGE_CLIENT_CANCELLED_CALL, status) + } + } + } + // We use buffer RENDEZVOUS on the outer flow so that our + // `onEach` operator is only invoked each time a message is + // collected instead of each time a message is received from + // from the underlying call. + .buffer(Channel.RENDEZVOUS) + .produceIn(rpcScope) + } + + fun beforeCallCancellation(message: String?, cause: Throwable?){ + if(isAborted.compareAndSet(false, true)) { + if(cause is CancellationException) { + inboundChannel.cancel(cause) + readyObserver.cancel(cause) + } else { + val ex = Status.CANCELLED + .withDescription(message) + .withCause(cause) + .asRuntimeException() + inboundChannel.close(ex) + readyObserver.cancel(cause) + } + } + } + + override fun onNext(value: RespT) { + inboundChannel.offer(value) + } + + override fun onError(t: Throwable) { + isAborted.set(true) + requestChannel.close(t) + if(t is CancellationException){ + inboundChannel.cancel(t) + }else{ + inboundChannel.close(t) + } + readyObserver.cancel(t) + } + + override fun onCompleted() { + isCompleted.set(true) + inboundChannel.close() + } +} + +internal fun BidiStreamingResponseObserver.asClientBidiCallChannel() + : ClientBidiCallChannel = + object : ClientBidiCallChannel, + SendChannel by requestChannel, + ReceiveChannel by responseChannel { + override val requestChannel: SendChannel + get() = this@asClientBidiCallChannel.requestChannel + override val responseChannel: ReceiveChannel + get() = this@asClientBidiCallChannel.responseChannel + } diff --git a/kroto-plus-coroutines/src/main/kotlin/com/github/marcoferrer/krotoplus/coroutines/client/ClientStreamingCallChannel.kt b/kroto-plus-coroutines/src/main/kotlin/com/github/marcoferrer/krotoplus/coroutines/client/ClientStreamingCallChannel.kt index 0e9449f..876f8b6 100644 --- a/kroto-plus-coroutines/src/main/kotlin/com/github/marcoferrer/krotoplus/coroutines/client/ClientStreamingCallChannel.kt +++ b/kroto-plus-coroutines/src/main/kotlin/com/github/marcoferrer/krotoplus/coroutines/client/ClientStreamingCallChannel.kt @@ -16,19 +16,8 @@ package com.github.marcoferrer.krotoplus.coroutines.client -import com.github.marcoferrer.krotoplus.coroutines.call.MessageHandler -import com.github.marcoferrer.krotoplus.coroutines.call.applyOutboundFlowControl -import com.github.marcoferrer.krotoplus.coroutines.call.attachOutboundChannelCompletionHandler -import io.grpc.stub.ClientCallStreamObserver -import io.grpc.stub.ClientResponseObserver -import kotlinx.coroutines.CancellationException -import kotlinx.coroutines.CompletableDeferred -import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.Deferred -import kotlinx.coroutines.Job -import kotlinx.coroutines.channels.Channel import kotlinx.coroutines.channels.SendChannel -import kotlin.coroutines.CoroutineContext /** * @@ -42,64 +31,4 @@ public interface ClientStreamingCallChannel : SendChannel { public operator fun component1(): SendChannel = requestChannel public operator fun component2(): Deferred = response -} - - -internal class ClientStreamingCallChannelImpl( - - override val coroutineContext: CoroutineContext, - - private val outboundChannel: Channel = Channel(), - - private val completableResponse: CompletableDeferred = CompletableDeferred(parent = coroutineContext[Job]) - -) : ClientResponseObserver, - ClientStreamingCallChannel, - SendChannel by outboundChannel, - CoroutineScope { - - override val requestChannel: SendChannel - get() = outboundChannel - - override val response: Deferred - get() = completableResponse - - private lateinit var callStreamObserver: ClientCallStreamObserver - - private lateinit var outboundMessageHandler: SendChannel - - override fun beforeStart(requestStream: ClientCallStreamObserver) { - callStreamObserver = requestStream - outboundMessageHandler = applyOutboundFlowControl(requestStream, outboundChannel) - - attachOutboundChannelCompletionHandler( - callStreamObserver, outboundChannel, - onSuccess = { outboundMessageHandler.close() } - ) - completableResponse.invokeOnCompletion { - // If the client prematurely cancels the response - // we need to propagate this as a cancellation to the underlying call - if(!outboundChannel.isClosedForSend && coroutineContext[Job]?.isCancelled == false){ - callStreamObserver.cancel("Client has cancelled call", it) - } - } - } - - override fun onNext(value: RespT) { - completableResponse.complete(value) - } - - override fun onError(t: Throwable) { - outboundChannel.close(t) - outboundChannel.cancel(CancellationException(t.message,t)) - completableResponse.completeExceptionally(t) - outboundMessageHandler.close(t) - } - - override fun onCompleted() { - require(completableResponse.isCompleted){ - "Stream was completed before onNext was called" - } - } - } \ No newline at end of file diff --git a/kroto-plus-coroutines/src/main/kotlin/com/github/marcoferrer/krotoplus/coroutines/client/ResponseObserverChannelAdapter.kt b/kroto-plus-coroutines/src/main/kotlin/com/github/marcoferrer/krotoplus/coroutines/client/ResponseObserverChannelAdapter.kt deleted file mode 100644 index bf42ec6..0000000 --- a/kroto-plus-coroutines/src/main/kotlin/com/github/marcoferrer/krotoplus/coroutines/client/ResponseObserverChannelAdapter.kt +++ /dev/null @@ -1,71 +0,0 @@ -/* - * Copyright 2019 Kroto+ Contributors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.github.marcoferrer.krotoplus.coroutines.client - -import io.grpc.Status -import io.grpc.stub.ClientCallStreamObserver -import io.grpc.stub.ClientResponseObserver -import kotlinx.coroutines.CancellationException -import kotlinx.coroutines.cancel -import kotlinx.coroutines.channels.ProducerScope -import java.util.concurrent.atomic.AtomicBoolean - -internal class ResponseObserverChannelAdapter: ClientResponseObserver { - - private val isAborted = AtomicBoolean() - private val isCompleted = AtomicBoolean() - - lateinit var scope: ProducerScope - - lateinit var callStreamObserver: ClientCallStreamObserver - private set - - val isActive: Boolean - get() = !(isAborted.get() || isCompleted.get()) - - override fun beforeStart(requestStream: ClientCallStreamObserver) { - require(::scope.isInitialized){ "Producer scope was not initialized" } - callStreamObserver = requestStream.apply { disableAutoInboundFlowControl() } - } - - fun beforeCallCancellation(message: String?, cause: Throwable?){ - if(!isAborted.getAndSet(true)) { - val cancellationStatus = Status.CANCELLED - .withDescription(message) - .withCause(cause) - .asRuntimeException() - - scope.close(CancellationException(message, cancellationStatus)) - } - } - - override fun onNext(value: RespT) { - scope.offer(value) - } - - override fun onError(t: Throwable) { - isAborted.set(true) - scope.close(t) - scope.cancel(CancellationException(t.message,t)) - } - - override fun onCompleted() { - isCompleted.set(true) - scope.close() - } -} - diff --git a/kroto-plus-coroutines/src/main/kotlin/com/github/marcoferrer/krotoplus/coroutines/server/DeferredCancellationHandler.kt b/kroto-plus-coroutines/src/main/kotlin/com/github/marcoferrer/krotoplus/coroutines/server/DeferredCancellationHandler.kt new file mode 100644 index 0000000..e011d96 --- /dev/null +++ b/kroto-plus-coroutines/src/main/kotlin/com/github/marcoferrer/krotoplus/coroutines/server/DeferredCancellationHandler.kt @@ -0,0 +1,61 @@ +/* + * Copyright 2019 Kroto+ Contributors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.github.marcoferrer.krotoplus.coroutines.server + +import io.grpc.Status +import kotlinx.coroutines.CancellationException +import kotlinx.coroutines.CoroutineScope +import kotlinx.coroutines.cancel +import java.util.concurrent.atomic.AtomicBoolean + +/** + * Used for supporting atomic invocations of server method handlers in cases where they are encapsulated + * within a flow. + * + * TODO(marco): Update usage of atomics to kotlinx.atomicfu once stable + */ +internal class DeferredCancellationHandler (val scope: CoroutineScope) : Runnable { + + private val wasCancelled = AtomicBoolean() + private val handlerStarted = AtomicBoolean() + + override fun run() { + if(handlerStarted.get()){ + cancel() + } + wasCancelled.set(true) + } + + fun onMethodHandlerStart(){ + handlerStarted.set(true) + if(wasCancelled.get()){ + cancel() + } + } + + private fun cancel(){ + scope.cancel(newCancellationException()) + } + + private fun newCancellationException(): CancellationException { + val status = Status.CANCELLED + .withDescription(MESSAGE_SERVER_CANCELLED_CALL) + .asRuntimeException() + + return CancellationException(status.message, status) + } +} \ No newline at end of file diff --git a/kroto-plus-coroutines/src/main/kotlin/com/github/marcoferrer/krotoplus/coroutines/server/MethodHandlers.kt b/kroto-plus-coroutines/src/main/kotlin/com/github/marcoferrer/krotoplus/coroutines/server/MethodHandlers.kt new file mode 100644 index 0000000..31ecd77 --- /dev/null +++ b/kroto-plus-coroutines/src/main/kotlin/com/github/marcoferrer/krotoplus/coroutines/server/MethodHandlers.kt @@ -0,0 +1,48 @@ +/* + * Copyright 2019 Kroto+ Contributors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.github.marcoferrer.krotoplus.coroutines.server + +import kotlinx.coroutines.channels.ReceiveChannel +import kotlinx.coroutines.channels.SendChannel + +/** + * Adaptor to a unary method. + */ +interface UnaryMethod { + suspend operator fun invoke(request: ReqT): RespT +} + +/** + * Adaptor to a server streaming method. + */ +interface ServerStreamingMethod { + suspend operator fun invoke(request: ReqT, responseChannel: SendChannel) +} + +/** + * Adaptor to a client streaming method. + */ +interface ClientStreamingMethod { + suspend operator fun invoke(requestChannel: ReceiveChannel): RespT +} + +/** + * Adaptor to a bidirectional streaming method. + */ +interface BidiStreamingMethod { + suspend operator fun invoke(requestChannel: ReceiveChannel, responseChannel: SendChannel) +} \ No newline at end of file diff --git a/kroto-plus-coroutines/src/main/kotlin/com/github/marcoferrer/krotoplus/coroutines/server/ServerCallHandlers.kt b/kroto-plus-coroutines/src/main/kotlin/com/github/marcoferrer/krotoplus/coroutines/server/ServerCallHandlers.kt new file mode 100644 index 0000000..a2e489e --- /dev/null +++ b/kroto-plus-coroutines/src/main/kotlin/com/github/marcoferrer/krotoplus/coroutines/server/ServerCallHandlers.kt @@ -0,0 +1,310 @@ +/* + * Copyright 2019 Kroto+ Contributors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.github.marcoferrer.krotoplus.coroutines.server + +import com.github.marcoferrer.krotoplus.coroutines.call.applyInboundFlowControl +import com.github.marcoferrer.krotoplus.coroutines.call.bindToClientCancellation +import com.github.marcoferrer.krotoplus.coroutines.call.completeSafely +import com.github.marcoferrer.krotoplus.coroutines.call.newCallReadyObserver +import com.github.marcoferrer.krotoplus.coroutines.call.newRpcScope +import com.github.marcoferrer.krotoplus.coroutines.client.MESSAGE_CLIENT_CANCELLED_CALL +import io.grpc.Metadata +import io.grpc.ServerCall +import io.grpc.ServerCallHandler +import io.grpc.Status +import io.grpc.stub.ServerCallStreamObserver +import io.grpc.stub.ServerCalls +import io.grpc.stub.StreamObserver +import kotlinx.coroutines.CancellationException +import kotlinx.coroutines.CoroutineScope +import kotlinx.coroutines.CoroutineStart +import kotlinx.coroutines.Job +import kotlinx.coroutines.cancel +import kotlinx.coroutines.channels.Channel +import kotlinx.coroutines.channels.ProducerScope +import kotlinx.coroutines.channels.consumeEach +import kotlinx.coroutines.flow.buffer +import kotlinx.coroutines.flow.callbackFlow +import kotlinx.coroutines.flow.collect +import kotlinx.coroutines.flow.flow +import kotlinx.coroutines.flow.produceIn +import kotlinx.coroutines.launch +import java.util.concurrent.atomic.AtomicBoolean +import java.util.concurrent.atomic.AtomicInteger + +public fun ServiceScope.unaryServerCallHandler( + methodHandler: UnaryMethod +): ServerCallHandler = + UnaryServerCallHandler(this, methodHandler) + +internal class UnaryServerCallHandler( + private val serviceScope: ServiceScope, + private val methodHandler: UnaryMethod +) : ServerCallHandler { + + override fun startCall( + call: ServerCall, + headers: Metadata + ): ServerCall.Listener { + val delegate = ServerCalls.asyncUnaryCall { request, responseObserver -> + + with(newRpcScope(serviceScope.initialContext, call.methodDescriptor)) rpcScope@{ + bindToClientCancellation(responseObserver as ServerCallStreamObserver<*>) + launch(start = CoroutineStart.ATOMIC) { + try { + responseObserver.onNext(methodHandler(request)) + responseObserver.onCompleted() + } catch (e: Throwable) { + responseObserver.completeSafely(e) + } + } + } + } + + return delegate.startCall(call, headers) + } + +} + +public fun ServiceScope.bidiStreamingServerCallHandler( + methodHandler: BidiStreamingMethod +): ServerCallHandler = + BidiStreamingServerCallHandler(this, methodHandler) + +internal class BidiStreamingServerCallHandler( + private val serviceScope: ServiceScope, + private val methodHandler: BidiStreamingMethod +) : ServerCallHandler { + + override fun startCall(call: ServerCall, headers: Metadata): ServerCall.Listener { + + val delegate = ServerCalls.asyncBidiStreamingCall { responseObserver -> + + val rpcScope = newRpcScope(serviceScope.initialContext, call.methodDescriptor) + val cancellationHandler = DeferredCancellationHandler(rpcScope) + val serverCallObserver = (responseObserver as ServerCallStreamObserver).apply { + disableAutoInboundFlowControl() + setOnCancelHandler(cancellationHandler) + } + + val readyObserver = serverCallObserver.newCallReadyObserver() + val requestObserver = BidiStreamingRequestObserver(rpcScope) + + val responseFlow = callbackFlow responseFlow@{ + requestObserver.responseProducerScope = this@responseFlow + + val responseChannel = channel + + val requestChannel = flow requestFlow@{ + var error: Throwable? = null + try { + requestObserver.inboundChannel.consumeEach { message -> + emit(message) + serverCallObserver.request(1) + } + } catch (e: Throwable) { + error = e + throw e + } finally { + //TODO: Should we perform a null check on error just like the client variant of this call? + if (coroutineContext[Job]!!.isCancelled && requestObserver.isActive) { + val status = Status.CANCELLED + .withDescription(MESSAGE_CLIENT_CANCELLED_CALL) + .withCause(error) + .asRuntimeException() + requestObserver.inboundChannel.close(status) + responseChannel.close(status) + } + } + } + .buffer(Channel.RENDEZVOUS) + .produceIn(rpcScope) + + try { + cancellationHandler.onMethodHandlerStart() + methodHandler(requestChannel, responseChannel) + } finally { + if (!requestChannel.isClosedForReceive) { + requestChannel.cancel() + } + } + } + + rpcScope.launch(start = CoroutineStart.ATOMIC) { + var error: Throwable? = null + try { + // Must request at least 1 message to start the call + serverCallObserver.request(1) + + responseFlow.buffer(Channel.RENDEZVOUS).collect { message -> + serverCallObserver.onNext(message) + readyObserver.awaitReady() + } + serverCallObserver.onCompleted() + } catch (e: Throwable) { + error = e + } finally { + serverCallObserver.completeSafely(error) + } + } + + requestObserver + } + + return delegate.startCall(call, headers) + } + + class BidiStreamingRequestObserver( + val rpcScope: CoroutineScope + ) : StreamObserver { + + lateinit var responseProducerScope: ProducerScope<*> + + val inboundChannel = Channel(Channel.UNLIMITED) + val isAborted = AtomicBoolean() + val isCompleted = AtomicBoolean() + val isActive: Boolean + get() = !(isAborted.get() || isCompleted.get()) + + override fun onNext(value: ReqT) { + inboundChannel.offer(value) + } + + override fun onError(t: Throwable) { + inboundChannel.close(t) + responseProducerScope.close(t) + rpcScope.cancel(CancellationException(t.message, t)) + } + + override fun onCompleted() { + isCompleted.set(true) + inboundChannel.close() + } + } + +} + +public fun ServiceScope.clientStreamingServerCallHandler( + methodHandler: ClientStreamingMethod +): ServerCallHandler = + ClientStreamingServerCallHandler(this, methodHandler) + +internal class ClientStreamingServerCallHandler( + private val serviceScope: ServiceScope, + private val methodHandler: ClientStreamingMethod +) : ServerCallHandler { + + override fun startCall(call: ServerCall, headers: Metadata): ServerCall.Listener { + + val delegate = with(newRpcScope(serviceScope.initialContext, call.methodDescriptor)) rpcScope@{ + ServerCalls.asyncClientStreamingCall { responseObserver -> + + val activeInboundJobCount = AtomicInteger() + val inboundChannel = Channel() + + val serverCallObserver = (responseObserver as ServerCallStreamObserver) + .apply { applyInboundFlowControl(inboundChannel, activeInboundJobCount) } + + bindToClientCancellation(serverCallObserver) + + val requestChannel = ServerRequestStreamChannel( + coroutineContext = coroutineContext, + inboundChannel = inboundChannel, + transientInboundMessageCount = activeInboundJobCount, + callStreamObserver = serverCallObserver, + onErrorHandler = { + // Call cancellation already cancels the coroutine scope + // and closes the response stream. So we dont need to + // do anything in this case. + if (!serverCallObserver.isCancelled) { + this@rpcScope.cancel() + responseObserver.completeSafely(it) + } + } + ) + + launch(start = CoroutineStart.ATOMIC) { + try { + responseObserver.onNext(methodHandler(requestChannel)) + responseObserver.onCompleted() + } catch (e: Throwable) { + responseObserver.completeSafely(e) + } finally { + if (!requestChannel.isClosedForReceive) { + requestChannel.cancel() + } + } + } + + requestChannel + } + } + + return delegate.startCall(call, headers) + } + +} + +public fun ServiceScope.serverStreamingServerCallHandler( + methodHandler: ServerStreamingMethod +): ServerCallHandler = + ServerStreamingServerCallHandler(this, methodHandler) + + +internal class ServerStreamingServerCallHandler( + private val serviceScope: ServiceScope, + private val methodHandler: ServerStreamingMethod +) : ServerCallHandler { + + override fun startCall(call: ServerCall, headers: Metadata): ServerCall.Listener { + + val delegate = ServerCalls.asyncServerStreamingCall { request, responseObserver -> + + val rpcScope = newRpcScope(serviceScope.initialContext, call.methodDescriptor) + val cancellationHandler = DeferredCancellationHandler(rpcScope) + + val serverCallObserver = (responseObserver as ServerCallStreamObserver) + .apply { setOnCancelHandler(cancellationHandler) } + + val readyObserver = serverCallObserver.newCallReadyObserver() + + val responseFlow = callbackFlow { + cancellationHandler.onMethodHandlerStart() + methodHandler(request, channel) + channel.close() + }.buffer(Channel.RENDEZVOUS) + + rpcScope.launch(start = CoroutineStart.ATOMIC) { + var error: Throwable? = null + try { + responseFlow.collect { message -> + serverCallObserver.onNext(message) + readyObserver.awaitReady() + } + serverCallObserver.onCompleted() + } catch (e: Throwable) { + error = e + } finally { + serverCallObserver.completeSafely(error) + } + } + } + + return delegate.startCall(call, headers) + } + +} \ No newline at end of file diff --git a/kroto-plus-coroutines/src/main/kotlin/com/github/marcoferrer/krotoplus/coroutines/server/ServerCalls.kt b/kroto-plus-coroutines/src/main/kotlin/com/github/marcoferrer/krotoplus/coroutines/server/ServerCalls.kt index 400a5dc..90cef64 100644 --- a/kroto-plus-coroutines/src/main/kotlin/com/github/marcoferrer/krotoplus/coroutines/server/ServerCalls.kt +++ b/kroto-plus-coroutines/src/main/kotlin/com/github/marcoferrer/krotoplus/coroutines/server/ServerCalls.kt @@ -39,7 +39,13 @@ import kotlinx.coroutines.channels.SendChannel import kotlinx.coroutines.launch import java.util.concurrent.atomic.AtomicInteger +internal const val MESSAGE_SERVER_CANCELLED_CALL = "Server has cancelled the call" +private const val DEPRECATION_MESSAGE = + "Delegate based server implementations are deprecated and replaced with ServerCallHandler instances. " + + "This is resolved by re-generating source stubs with Kroto+ v0.7.0 and up" + +@Deprecated(message = DEPRECATION_MESSAGE, level = DeprecationLevel.WARNING) public fun ServiceScope.serverCallUnary( methodDescriptor: MethodDescriptor, responseObserver: StreamObserver, @@ -58,6 +64,7 @@ public fun ServiceScope.serverCallUnary( } } +@Deprecated(message = DEPRECATION_MESSAGE, level = DeprecationLevel.WARNING) public fun ServiceScope.serverCallServerStreaming( methodDescriptor: MethodDescriptor, responseObserver: StreamObserver, @@ -92,6 +99,7 @@ public fun ServiceScope.serverCallServerStreaming( } @UseExperimental(ExperimentalCoroutinesApi::class) +@Deprecated(message = DEPRECATION_MESSAGE, level = DeprecationLevel.WARNING) public fun ServiceScope.serverCallClientStreaming( methodDescriptor: MethodDescriptor, responseObserver: StreamObserver, @@ -141,6 +149,7 @@ public fun ServiceScope.serverCallClientStreaming( @UseExperimental(ExperimentalCoroutinesApi::class) +@Deprecated(message = DEPRECATION_MESSAGE, level = DeprecationLevel.WARNING) public fun ServiceScope.serverCallBidiStreaming( methodDescriptor: MethodDescriptor, responseObserver: StreamObserver, @@ -218,9 +227,9 @@ private fun MethodDescriptor<*, *>.getUnimplementedException(): StatusRuntimeExc * before invoking `onComplete` and closing the call stream. * */ -private fun CoroutineScope.bindScopeCompletionToObserver(streamObserver: StreamObserver<*>) { +internal fun CoroutineScope.bindScopeCompletionToObserver(streamObserver: StreamObserver<*>) { coroutineContext[Job]?.invokeOnCompletion { streamObserver.completeSafely(it) } -} +} \ No newline at end of file diff --git a/kroto-plus-coroutines/src/test/kotlin/com/github/marcoferrer/krotoplus/coroutines/RpcCallTest.kt b/kroto-plus-coroutines/src/test/kotlin/com/github/marcoferrer/krotoplus/coroutines/RpcCallTest.kt index 9b34404..2f4a30a 100644 --- a/kroto-plus-coroutines/src/test/kotlin/com/github/marcoferrer/krotoplus/coroutines/RpcCallTest.kt +++ b/kroto-plus-coroutines/src/test/kotlin/com/github/marcoferrer/krotoplus/coroutines/RpcCallTest.kt @@ -16,14 +16,16 @@ package com.github.marcoferrer.krotoplus.coroutines +import com.github.marcoferrer.krotoplus.coroutines.utils.CALL_TRACE_ENABLED import com.github.marcoferrer.krotoplus.coroutines.utils.ClientCallSpyInterceptor import com.github.marcoferrer.krotoplus.coroutines.utils.RpcStateInterceptor +import io.grpc.BindableService import io.grpc.Channel import io.grpc.ClientCall import io.grpc.MethodDescriptor +import io.grpc.ServerInterceptors import io.grpc.examples.helloworld.GreeterCoroutineGrpc import io.grpc.examples.helloworld.GreeterGrpc -import io.grpc.examples.helloworld.HelloReply import io.grpc.examples.helloworld.HelloRequest import io.grpc.testing.GrpcServerRule import kotlinx.coroutines.CompletableDeferred @@ -58,6 +60,57 @@ abstract class RpcCallTest( @BeforeTest fun setupCall() { callState = RpcStateInterceptor() + CALL_TRACE_ENABLED = true +// mockkObject(Testing) +// +// every { Testing.asyncClientStreamingCallK(any(), any()) } answers answer@ { +// val call = firstArg>() +// val responseObserver = secondArg>() +// +// val reqObserver = ClientCalls.asyncClientStreamingCall(call, object: ClientResponseObserver{ +// override fun onNext(value: RespT) { +// responseObserver.onNext(value) +// } +// +// override fun onError(t: Throwable) { +// println("Client: Response observer onError(${t.toDebugString()})") +// responseObserver.onError(t) +// } +// +// override fun onCompleted() { +// println("Client: Response observer onComplete()") +// responseObserver.onCompleted() +// } +// +// override fun beforeStart(requestStream: ClientCallStreamObserver) { +// responseObserver.beforeStart(requestStream) +// } +// +// } as StreamObserver) +// +// return@answer object : StreamObserver { +// override fun onNext(value: ReqT) { +// reqObserver.onNext(value) +// } +// +// override fun onError(t: Throwable) { +// println("Client: Request observer onError(${t.toDebugString()})") +// reqObserver.onError(t) +// } +// +// override fun onCompleted() { +// println("Client: Request observer onComplete()") +// reqObserver.onCompleted() +// } +// +// } +// } + } + + fun registerService(service: BindableService){ + val interceptedService = ServerInterceptors.intercept(service, callState) + nonDirectGrpcServerRule.serviceRegistry.addService(interceptedService) + grpcServerRule.serviceRegistry.addService(interceptedService) } inner class RpcSpy(channel: Channel) { @@ -71,6 +124,9 @@ abstract class RpcCallTest( val stub = GreeterGrpc.newStub(channel) .withInterceptors(ClientCallSpyInterceptor(_call), callState)!! + val blkStub = GreeterGrpc.newBlockingStub(channel) + .withInterceptors(ClientCallSpyInterceptor(_call), callState)!! + val coStub = GreeterCoroutineGrpc.newStub(channel) .withInterceptors(ClientCallSpyInterceptor(_call), callState)!! @@ -104,10 +160,10 @@ abstract class RpcCallTest( ) : T = try { withTimeout(timeout) { block() } } catch (e: TimeoutCancellationException) { + println(callState.toString()) fail(""" |$message |Timeout after ${timeout}ms - |$callState """.trimMargin()) } diff --git a/kroto-plus-coroutines/src/test/kotlin/com/github/marcoferrer/krotoplus/coroutines/client/ClientCallBidiStreamingTests.kt b/kroto-plus-coroutines/src/test/kotlin/com/github/marcoferrer/krotoplus/coroutines/client/ClientCallBidiStreamingTests.kt index 0a90624..e915e34 100644 --- a/kroto-plus-coroutines/src/test/kotlin/com/github/marcoferrer/krotoplus/coroutines/client/ClientCallBidiStreamingTests.kt +++ b/kroto-plus-coroutines/src/test/kotlin/com/github/marcoferrer/krotoplus/coroutines/client/ClientCallBidiStreamingTests.kt @@ -17,21 +17,18 @@ package com.github.marcoferrer.krotoplus.coroutines.client +import com.github.marcoferrer.krotoplus.coroutines.RpcCallTest import com.github.marcoferrer.krotoplus.coroutines.utils.assertFails +import com.github.marcoferrer.krotoplus.coroutines.utils.assertFailsWithCancellation import com.github.marcoferrer.krotoplus.coroutines.utils.assertFailsWithStatus +import com.github.marcoferrer.krotoplus.coroutines.utils.matchStatus import com.github.marcoferrer.krotoplus.coroutines.utils.matchThrowable import com.github.marcoferrer.krotoplus.coroutines.withCoroutineContext -import io.grpc.CallOptions -import io.grpc.ClientCall import io.grpc.Status -import io.grpc.examples.helloworld.GreeterCoroutineGrpc import io.grpc.examples.helloworld.GreeterGrpc import io.grpc.examples.helloworld.HelloReply import io.grpc.examples.helloworld.HelloRequest import io.grpc.stub.StreamObserver -import io.grpc.testing.GrpcServerRule -import io.mockk.every -import io.mockk.spyk import io.mockk.verify import kotlinx.coroutines.CancellationException import kotlinx.coroutines.CoroutineStart @@ -42,56 +39,34 @@ import kotlinx.coroutines.channels.ReceiveChannel import kotlinx.coroutines.channels.SendChannel import kotlinx.coroutines.delay import kotlinx.coroutines.flow.collect -import kotlinx.coroutines.flow.collectIndexed import kotlinx.coroutines.flow.consumeAsFlow import kotlinx.coroutines.flow.map import kotlinx.coroutines.flow.toList import kotlinx.coroutines.launch import kotlinx.coroutines.runBlocking -import org.junit.Rule import org.junit.Test -import java.util.concurrent.atomic.AtomicInteger -import kotlin.coroutines.CoroutineContext -import kotlin.test.BeforeTest import kotlin.test.assertEquals import kotlin.test.assertFailsWith -class ClientCallBidiStreamingTests { +class ClientCallBidiStreamingTests : + RpcCallTest(GreeterGrpc.getSayHelloStreamingMethod()) { - @[Rule JvmField] - var grpcServerRule = GrpcServerRule().directExecutor() + val expectedCancelMessage = "Cancelled by client with StreamObserver.onError()" - @[Rule JvmField] - var nonDirectGrpcServerRule = GrpcServerRule() - - // @[Rule JvmField] - // public val timeout = CoroutinesTimeout.seconds(COROUTINE_TEST_TIMEOUT) - - - private val methodDescriptor = GreeterGrpc.getSayHelloStreamingMethod() - private val service = spyk(object : GreeterGrpc.GreeterImplBase() {}) - - inner class RpcSpy{ - val stub: GreeterGrpc.GreeterStub - lateinit var call: ClientCall - - init { - val channelSpy = spyk(grpcServerRule.channel) - stub = GreeterGrpc.newStub(channelSpy) - - every { channelSpy.newCall(methodDescriptor, any()) } answers { - spyk(grpcServerRule.channel.newCall(methodDescriptor, secondArg())).also { - this@RpcSpy.call = it - } - } - } + private fun setupServerHandler( + block: (StreamObserver) -> StreamObserver + ){ + registerService(object : GreeterGrpc.GreeterImplBase(){ + override fun sayHelloStreaming(responseObserver: StreamObserver) + : StreamObserver = block(responseObserver) + }) } private fun setupServerHandlerError(){ - every { service.sayHelloStreaming(any()) } answers { - val responseObserver = firstArg>() + setupServerHandler { responseObserver -> object : StreamObserver{ + var reqQty = 0 override fun onNext(value: HelloRequest) { if(reqQty >= 3){ @@ -111,8 +86,7 @@ class ClientCallBidiStreamingTests { } private fun setupServerHandlerSuccess(){ - every { service.sayHelloStreaming(any()) } answers { - val responseObserver = firstArg>() + setupServerHandler { responseObserver -> object : StreamObserver{ var reqQty = 0 override fun onNext(value: HelloRequest) { @@ -129,7 +103,7 @@ class ClientCallBidiStreamingTests { } private fun setupServerHandlerNoop(){ - every { service.sayHelloStreaming(any()) } answers { + setupServerHandler { responseObserver -> object : StreamObserver{ override fun onNext(value: HelloRequest) {} override fun onError(t: Throwable?) {} @@ -138,11 +112,6 @@ class ClientCallBidiStreamingTests { } } - @BeforeTest - fun setupService() { - grpcServerRule.serviceRegistry.addService(service) - } - @Test fun `Call succeeds on server response`() { val rpcSpy = RpcSpy() @@ -152,7 +121,7 @@ class ClientCallBidiStreamingTests { val (requestChannel, responseChannel) = stub .clientCallBidiStreaming(methodDescriptor) - val result = runBlocking(Dispatchers.Default) { + val result = runTest { launch { repeat(3){ requestChannel.send( @@ -166,6 +135,8 @@ class ClientCallBidiStreamingTests { responseChannel.consumeAsFlow().map { it.message }.toList() } + callState.blockUntilClosed() + assertEquals(3,result.size) result.forEachIndexed { index, message -> assertEquals("Req:#$index/Resp:#$index",message) @@ -184,7 +155,7 @@ class ClientCallBidiStreamingTests { val (requestChannel, responseChannel) = stub .clientCallBidiStreaming(methodDescriptor) - runBlocking { + runTest { launch { with(requestChannel){ repeat(4) { @@ -201,7 +172,7 @@ class ClientCallBidiStreamingTests { } } launch { - repeat(2) { + repeat(3) { assertEquals("Req:#$it/Resp:#$it", responseChannel.receive().message) } assertFailsWithStatus(Status.INVALID_ARGUMENT) { @@ -210,6 +181,8 @@ class ClientCallBidiStreamingTests { } } + callState.blockUntilClosed() + assert(requestChannel.isClosedForSend) { "Request channel should be closed for send" } assert(responseChannel.isClosedForReceive) { "Response channel should be closed for receive" } } @@ -226,15 +199,15 @@ class ClientCallBidiStreamingTests { .clientCallBidiStreaming(methodDescriptor) - runBlocking { - launch(Dispatchers.Default) { + runTest { + launch { val job = launch(start = CoroutineStart.ATOMIC) { launch(start = CoroutineStart.UNDISPATCHED){ - assertFailsWithStatus(Status.CANCELLED) { + assertFailsWithCancellation { responseChannel.receive() } } - assertFailsWithStatus(Status.CANCELLED) { + assertFailsWithCancellation { repeat(3) { requestChannel.send( HelloRequest.newBuilder() @@ -252,6 +225,8 @@ class ClientCallBidiStreamingTests { } } + callState.blockUntilCancellation() + verify(exactly = 1) { rpcSpy.call.cancel(any(), any()) } assert(requestChannel.isClosedForSend) { "Request channel should be closed for send" } assert(responseChannel.isClosedForReceive) { "Response channel should be closed for receive" } @@ -266,7 +241,7 @@ class ClientCallBidiStreamingTests { lateinit var requestChannel: SendChannel lateinit var responseChannel: ReceiveChannel assertFails { - runBlocking { + runTest { launch(start = CoroutineStart.UNDISPATCHED) { val callChannel = stub .withCoroutineContext() @@ -278,7 +253,7 @@ class ClientCallBidiStreamingTests { val job = launch { callChannel.responseChannel.receive().message } - assertFailsWithStatus(Status.CANCELLED) { + assertFailsWithCancellation { repeat(3) { requestChannel.send( HelloRequest.newBuilder() @@ -288,7 +263,7 @@ class ClientCallBidiStreamingTests { delay(5) } } - assertFailsWithStatus(Status.CANCELLED) { + assertFailsWithCancellation { job.join() } } @@ -296,7 +271,15 @@ class ClientCallBidiStreamingTests { } } - verify { rpcSpy.call.cancel(any(), any()) } + runTest { + assertFailsWithCancellation { + responseChannel.receive() + } + } + + callState.blockUntilCancellation() + + verify(exactly = 1) { rpcSpy.call.cancel(any(), any()) } assert(requestChannel.isClosedForSend) { "Request channel should be closed for send" } assert(responseChannel.isClosedForReceive) { "Response channel should be closed for receive" } } @@ -310,7 +293,7 @@ class ClientCallBidiStreamingTests { lateinit var requestChannel: SendChannel lateinit var responseChannel: ReceiveChannel assertFailsWith(IllegalStateException::class, "cancel") { - runBlocking { + runBlocking(Dispatchers.Default) { val callChannel = stub .withCoroutineContext() .clientCallBidiStreaming(methodDescriptor) @@ -341,69 +324,17 @@ class ClientCallBidiStreamingTests { } } - verify { rpcSpy.call.cancel(any(), any()) } - assert(requestChannel.isClosedForSend) { "Request channel should be closed for send" } - assert(responseChannel.isClosedForReceive) { "Response channel should be closed for receive" } - } + callState.blockUntilCancellation() - @Test - fun `High volume call succeeds`() { - nonDirectGrpcServerRule.serviceRegistry.addService(object : GreeterCoroutineGrpc.GreeterImplBase() { - override val initialContext: CoroutineContext = Dispatchers.Default - override suspend fun sayHelloStreaming( - requestChannel: ReceiveChannel, - responseChannel: SendChannel - ) { - requestChannel.consumeAsFlow().collectIndexed { index, value -> - if (index % 1000 == 0) { -// println("Server received $index") - } - - responseChannel.send(HelloReply.newBuilder().setMessage(value.name).build()) - } - responseChannel.close() - } - }) - val stub = GreeterCoroutineGrpc.newStub(nonDirectGrpcServerRule.channel) - - val (requestChannel, responseChannel) = stub - .clientCallBidiStreaming(methodDescriptor) - - val numMessages = 100000 - val receivedCount = AtomicInteger() - runBlocking(Dispatchers.Default) { - val req = HelloRequest.newBuilder() - .setName("test").build() - - launch { - repeat(numMessages) { -// if (it % 1000 == 0) println("Client sent $it") - requestChannel.send(req) - } - requestChannel.close() - } - - launch { - repeat(numMessages) { -// if (it % 1000 == 0) println("Client received $it") - responseChannel.receive() - receivedCount.incrementAndGet() - } - } - } - // Sleep so that we can ensure the response channel - // has had enough time to close before being asserted on - Thread.sleep(50) + verify(exactly = 1) { rpcSpy.call.cancel(any(), any()) } assert(requestChannel.isClosedForSend) { "Request channel should be closed for send" } assert(responseChannel.isClosedForReceive) { "Response channel should be closed for receive" } - assertEquals(numMessages, receivedCount.get(), "Must response count must equal request count") } @Test fun `Call is cancelled when request channel closed with error concurrently`() { val rpcSpy = RpcSpy() val stub = rpcSpy.stub - val expectedCancelMessage = "Cancelled by client with StreamObserver.onError()" val expectedException = IllegalStateException("test") setupServerHandlerSuccess() @@ -411,32 +342,32 @@ class ClientCallBidiStreamingTests { .clientCallBidiStreaming(methodDescriptor) val result = mutableListOf() - runBlocking(Dispatchers.Default) { + runTest { launch { - kotlin.runCatching { - repeat(3) { - requestChannel.send( - HelloRequest.newBuilder() - .setName(it.toString()) - .build() - ) - } - requestChannel.close(expectedException) + repeat(3) { + requestChannel.send( + HelloRequest.newBuilder() + .setName(it.toString()) + .build() + ) } + requestChannel.close(expectedException) } - assertFailsWithStatus(Status.CANCELLED,"CANCELLED: $expectedCancelMessage"){ + assertFailsWithStatus(Status.CANCELLED, "CANCELLED: $expectedCancelMessage") { responseChannel.consumeAsFlow() .map { it.message } .collect { result.add(it) } } } + callState.blockUntilCancellation() + assert(result.isNotEmpty()) result.forEachIndexed { index, message -> assertEquals("Req:#$index/Resp:#$index",message) } - verify { rpcSpy.call.cancel(expectedCancelMessage, matchThrowable(expectedException)) } + verify(exactly = 1) { rpcSpy.call.cancel(expectedCancelMessage, matchThrowable(expectedException)) } assert(requestChannel.isClosedForSend) { "Request channel should be closed for send" } assert(responseChannel.isClosedForReceive) { "Response channel should be closed for receive" } } @@ -445,7 +376,6 @@ class ClientCallBidiStreamingTests { fun `Call is cancelled when request channel closed with error sequentially`() { val rpcSpy = RpcSpy() val stub = rpcSpy.stub - val expectedCancelMessage = "Cancelled by client with StreamObserver.onError()" val expectedException = IllegalStateException("test") setupServerHandlerSuccess() @@ -453,26 +383,28 @@ class ClientCallBidiStreamingTests { .clientCallBidiStreaming(methodDescriptor) val result = mutableListOf() - runBlocking(Dispatchers.Default) { + runTest { requestChannel.send( HelloRequest.newBuilder() .setName(0.toString()) .build() ) + result.add(responseChannel.receive().message) requestChannel.close(expectedException) - assertFailsWithStatus(Status.CANCELLED,"CANCELLED: $expectedCancelMessage"){ + assertFailsWithStatus(Status.CANCELLED, "CANCELLED: $expectedCancelMessage") { responseChannel.consumeAsFlow() .collect { result.add(it.message) } } } + callState.client.cancelled.assertBlocking{ "Client must be cancelled" } assertEquals(1, result.size) result.forEachIndexed { index, message -> assertEquals("Req:#$index/Resp:#$index",message) } - verify { rpcSpy.call.cancel(expectedCancelMessage, matchThrowable(expectedException)) } + verify(exactly = 1) { rpcSpy.call.cancel(expectedCancelMessage, matchThrowable(expectedException)) } assert(requestChannel.isClosedForSend) { "Request channel should be closed for send" } assert(responseChannel.isClosedForReceive) { "Response channel should be closed for receive" } } @@ -486,7 +418,7 @@ class ClientCallBidiStreamingTests { val (requestChannel, responseChannel) = stub .clientCallBidiStreaming(methodDescriptor) - runBlocking(Dispatchers.Default) { + runTest { launch { assertFailsWithStatus(Status.CANCELLED) { repeat(10) { @@ -507,7 +439,9 @@ class ClientCallBidiStreamingTests { responseChannel.cancel() } - verify { rpcSpy.call.cancel("Cancelled by client with StreamObserver.onError()",any()) } + callState.blockUntilCancellation() + + verify(exactly = 1) { rpcSpy.call.cancel(MESSAGE_CLIENT_CANCELLED_CALL,matchStatus(Status.CANCELLED)) } assert(requestChannel.isClosedForSend) { "Request channel should be closed for send" } assert(responseChannel.isClosedForReceive) { "Response channel should be closed for receive" } } diff --git a/kroto-plus-coroutines/src/test/kotlin/com/github/marcoferrer/krotoplus/coroutines/client/ClientCallClientStreamingTests.kt b/kroto-plus-coroutines/src/test/kotlin/com/github/marcoferrer/krotoplus/coroutines/client/ClientCallClientStreamingTests.kt index 9db9e25..9218d20 100644 --- a/kroto-plus-coroutines/src/test/kotlin/com/github/marcoferrer/krotoplus/coroutines/client/ClientCallClientStreamingTests.kt +++ b/kroto-plus-coroutines/src/test/kotlin/com/github/marcoferrer/krotoplus/coroutines/client/ClientCallClientStreamingTests.kt @@ -17,68 +17,43 @@ package com.github.marcoferrer.krotoplus.coroutines.client +import com.github.marcoferrer.krotoplus.coroutines.RpcCallTest import com.github.marcoferrer.krotoplus.coroutines.utils.assertExEquals -import com.github.marcoferrer.krotoplus.coroutines.utils.assertFails +import com.github.marcoferrer.krotoplus.coroutines.utils.assertFailsWithCancellation import com.github.marcoferrer.krotoplus.coroutines.utils.assertFailsWithStatus import com.github.marcoferrer.krotoplus.coroutines.utils.matchThrowable +import com.github.marcoferrer.krotoplus.coroutines.utils.toDebugString import com.github.marcoferrer.krotoplus.coroutines.withCoroutineContext -import io.grpc.CallOptions -import io.grpc.ClientCall +import io.grpc.ServerInterceptors import io.grpc.Status import io.grpc.examples.helloworld.GreeterGrpc import io.grpc.examples.helloworld.HelloReply import io.grpc.examples.helloworld.HelloRequest import io.grpc.stub.StreamObserver -import io.grpc.testing.GrpcServerRule -import io.mockk.every +import io.mockk.coVerify import io.mockk.spyk import io.mockk.verify -import kotlinx.coroutines.CancellationException import kotlinx.coroutines.CoroutineStart +import kotlinx.coroutines.Deferred import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.Job import kotlinx.coroutines.cancel import kotlinx.coroutines.channels.SendChannel import kotlinx.coroutines.delay import kotlinx.coroutines.launch -import kotlinx.coroutines.runBlocking -import org.junit.Rule import org.junit.Test import kotlin.test.BeforeTest import kotlin.test.assertEquals import kotlin.test.assertFailsWith -class ClientCallClientStreamingTests { +class ClientCallClientStreamingTests : + RpcCallTest(GreeterGrpc.getSayHelloClientStreamingMethod()) { - @[Rule JvmField] - var grpcServerRule = GrpcServerRule().directExecutor() - - // @[Rule JvmField] - // public val timeout = CoroutinesTimeout.seconds(COROUTINE_TEST_TIMEOUT) - - private val methodDescriptor = GreeterGrpc.getSayHelloClientStreamingMethod() private val service = spyk(object : GreeterGrpc.GreeterImplBase() {}) - inner class RpcSpy{ - val stub: GreeterGrpc.GreeterStub - lateinit var call: ClientCall - - init { - val channelSpy = spyk(grpcServerRule.channel) - stub = GreeterGrpc.newStub(channelSpy) - - every { channelSpy.newCall(methodDescriptor, any()) } answers { - spyk(grpcServerRule.channel.newCall(methodDescriptor, secondArg())).also { - this@RpcSpy.call = it - } - } - } - } - private fun setupServerHandlerError(){ - every { service.sayHelloClientStreaming(any()) } answers { - val responseObserver = firstArg>() + setupServerHandler { responseObserver -> object : StreamObserver{ var reqQty = 0 var responseString = "" @@ -100,16 +75,19 @@ class ClientCallClientStreamingTests { } private fun setupServerHandlerSuccess(){ - every { service.sayHelloClientStreaming(any()) } answers { - val responseObserver = firstArg>() + setupServerHandler { responseObserver -> object : StreamObserver{ var reqQty = 0 var responseString = "" override fun onNext(value: HelloRequest) { + println("Server: onNext(${value.name})") responseString += "Req:#${value.name}/Resp:#${reqQty++}|" } - override fun onError(t: Throwable?) {} + override fun onError(t: Throwable) { + println("Server: onError(${t.toDebugString()})") + } override fun onCompleted() { + println("Server: onCompleted()") responseObserver.onNext(HelloReply.newBuilder() .setMessage(responseString) .build()) @@ -119,8 +97,18 @@ class ClientCallClientStreamingTests { } } + private inline fun setupServerHandler( + crossinline block: (responseObserver: StreamObserver) -> StreamObserver + ){ + registerService(object : GreeterGrpc.GreeterImplBase(){ + override fun sayHelloClientStreaming( + responseObserver: StreamObserver + ): StreamObserver = block(responseObserver) + }) + } + private fun setupServerHandlerNoop(){ - every { service.sayHelloClientStreaming(any()) } answers { + setupServerHandler { object : StreamObserver{ override fun onNext(value: HelloRequest) {} override fun onError(t: Throwable?) {} @@ -131,7 +119,8 @@ class ClientCallClientStreamingTests { @BeforeTest fun setupService() { - grpcServerRule.serviceRegistry.addService(service) + grpcServerRule.serviceRegistry.addService( + ServerInterceptors.intercept(service, callState)) } @Test @@ -142,7 +131,7 @@ class ClientCallClientStreamingTests { setupServerHandlerSuccess() lateinit var requestChannel: SendChannel - runBlocking { + runTest { val (sendChannel, response) = stub .withCoroutineContext() .clientCallClientStreaming(methodDescriptor) @@ -159,8 +148,12 @@ class ClientCallClientStreamingTests { launch{ assertEquals("Req:#0/Resp:#0|Req:#1/Resp:#1|Req:#2/Resp:#2|", response.await().message) } + + callState.client.closed.assert{ "Client call should be closed" } } + callState.blockUntilClosed() + verify(exactly = 0) { rpcSpy.call.cancel(any(), any()) } assert(requestChannel.isClosedForSend) { "Request channel should be closed for send" } @@ -177,31 +170,36 @@ class ClientCallClientStreamingTests { .clientCallClientStreaming(methodDescriptor) var requestsSent = 0 - runBlocking { + runTest { launch { launch(start = CoroutineStart.UNDISPATCHED) { repeat(2) { requestsSent++ requestChannel.send( HelloRequest.newBuilder() - .setName(it.toString()) - .build() - ) - } - assertFailsWithStatus(Status.INVALID_ARGUMENT) { - requestChannel.send( - HelloRequest.newBuilder() - .setName("request") + .setName(0.toString()) .build() ) } } - assertFailsWithStatus(Status.INVALID_ARGUMENT) { - response.await().message - } } } + assertFailsWithStatus(Status.INVALID_ARGUMENT) { + runTest { response.await().message } + } + assertFailsWithStatus(Status.INVALID_ARGUMENT) { + runTest { + requestChannel.send( + HelloRequest.newBuilder() + .setName("request") + .build() + ) + } + } + + callState.blockUntilClosed() + assertEquals(2,requestsSent) assert(requestChannel.isClosedForSend) { "Request channel should be closed for send" } } @@ -217,15 +215,15 @@ class ClientCallClientStreamingTests { .withCoroutineContext(externalJob) .clientCallClientStreaming(methodDescriptor) - runBlocking { + runTest { launch(Dispatchers.Default) { val job = launch { launch(start = CoroutineStart.UNDISPATCHED){ - assertFailsWithStatus(Status.CANCELLED) { + assertFailsWithCancellation { response.await().message } } - assertFailsWithStatus(Status.CANCELLED) { + assertFailsWithCancellation { repeat(3) { delay(5) requestChannel.send( @@ -244,6 +242,8 @@ class ClientCallClientStreamingTests { } } + callState.client.cancelled.assertBlocking { "Client must be cancelled" } + verify(exactly = 1) { rpcSpy.call.cancel(any(), any()) } assert(requestChannel.isClosedForSend) { "Request channel should be closed for send" } } @@ -255,37 +255,42 @@ class ClientCallClientStreamingTests { setupServerHandlerNoop() lateinit var requestChannel: SendChannel - assertFails { - runBlocking { + lateinit var response: Deferred + + assertFailsWithCancellation { + runTest { launch(start = CoroutineStart.UNDISPATCHED) { val callChannel = stub .withCoroutineContext() .clientCallClientStreaming(methodDescriptor) - requestChannel = callChannel.requestChannel + requestChannel = spyk(callChannel.requestChannel) + response = callChannel.response - val job = launch { - callChannel.response.await().message - } - assertFailsWithStatus(Status.CANCELLED) { + assertFailsWithCancellation { repeat(3) { requestChannel.send( HelloRequest.newBuilder() .setName(it.toString()) .build() ) - delay(5) + callState.client.cancelled.await() } } - assertFailsWithStatus(Status.CANCELLED) { - job.join() - } } + callState.client.onReady.await() cancel() } } - verify { rpcSpy.call.cancel(any(), any()) } + callState.client.cancelled.assertBlocking { "Client must be cancelled" } + + assertFailsWithCancellation { + runTest { response.await().message } + } + + coVerify(exactly = 1) { requestChannel.send(any()) } + verify(exactly = 1) { rpcSpy.call.cancel(any(), any()) } assert(requestChannel.isClosedForSend) { "Request channel should be closed for send" } } @@ -298,7 +303,7 @@ class ClientCallClientStreamingTests { lateinit var requestChannel: SendChannel assertFailsWith(IllegalStateException::class, "cancel") { - runBlocking { + runTest { val callChannel = stub .withCoroutineContext() .clientCallClientStreaming(methodDescriptor) @@ -307,7 +312,7 @@ class ClientCallClientStreamingTests { launch(Dispatchers.Default) { launch(start = CoroutineStart.UNDISPATCHED) { - assertFailsWithStatus(Status.CANCELLED) { + assertFailsWithCancellation { repeat(3) { requestChannel.send( HelloRequest.newBuilder() @@ -321,14 +326,16 @@ class ClientCallClientStreamingTests { launch { error("cancel") } - assertFailsWithStatus(Status.CANCELLED) { + assertFailsWithCancellation { callChannel.response.await().message } } } } - verify { rpcSpy.call.cancel(any(), any()) } + callState.client.cancelled.assertBlocking { "Client must be cancelled" } + + verify(exactly = 1) { rpcSpy.call.cancel(any(), any()) } assert(requestChannel.isClosedForSend) { "Request channel should be closed for send" } } @@ -344,18 +351,16 @@ class ClientCallClientStreamingTests { val (requestChannel, response) = stub .clientCallClientStreaming(methodDescriptor) - runBlocking(Dispatchers.Default) { + runTest { launch { - kotlin.runCatching { - repeat(3) { - requestChannel.send( - HelloRequest.newBuilder() - .setName(it.toString()) - .build() - ) - } - requestChannel.close(expectedException) + repeat(3) { + requestChannel.send( + HelloRequest.newBuilder() + .setName(it.toString()) + .build() + ) } + requestChannel.close(expectedException) } assertFailsWithStatus(Status.CANCELLED,"CANCELLED: $expectedCancelMessage"){ @@ -363,6 +368,8 @@ class ClientCallClientStreamingTests { } } + callState.client.cancelled.assertBlocking { "Client must be cancelled" } + verify { rpcSpy.call.cancel(expectedCancelMessage, matchThrowable(expectedException)) } assert(requestChannel.isClosedForSend) { "Request channel should be closed for send" } assertExEquals(expectedException, response.getCompletionExceptionOrNull()?.cause) @@ -381,7 +388,7 @@ class ClientCallClientStreamingTests { val (requestChannel, response) = stub .clientCallClientStreaming(methodDescriptor) - runBlocking(Dispatchers.Default) { + runTest { requestChannel.send( HelloRequest.newBuilder() .setName(0.toString()) @@ -394,6 +401,8 @@ class ClientCallClientStreamingTests { } } + callState.client.cancelled.assertBlocking { "Client must be cancelled" } + verify { rpcSpy.call.cancel(expectedCancelMessage, matchThrowable(expectedException)) } assert(requestChannel.isClosedForSend) { "Request channel should be closed for send" } assertExEquals(expectedException, response.getCompletionExceptionOrNull()?.cause) diff --git a/kroto-plus-coroutines/src/test/kotlin/com/github/marcoferrer/krotoplus/coroutines/client/ClientCallServerStreamingTests.kt b/kroto-plus-coroutines/src/test/kotlin/com/github/marcoferrer/krotoplus/coroutines/client/ClientCallServerStreamingTests.kt index 132ea02..d7a68b0 100644 --- a/kroto-plus-coroutines/src/test/kotlin/com/github/marcoferrer/krotoplus/coroutines/client/ClientCallServerStreamingTests.kt +++ b/kroto-plus-coroutines/src/test/kotlin/com/github/marcoferrer/krotoplus/coroutines/client/ClientCallServerStreamingTests.kt @@ -37,10 +37,8 @@ import io.grpc.examples.helloworld.GreeterGrpc import io.grpc.examples.helloworld.HelloReply import io.grpc.examples.helloworld.HelloRequest import io.grpc.stub.StreamObserver -import io.mockk.spyk import io.mockk.verify import kotlinx.coroutines.CancellationException -import kotlinx.coroutines.CompletableDeferred import kotlinx.coroutines.CoroutineStart import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.Job @@ -386,7 +384,7 @@ class ClientCallServerStreamingTests : val result = runTest { delay(100) repeat(3) { - verify(exactly = it + 2) { rpcSpy.call.request(1) } + verify(exactly = it + 1) { rpcSpy.call.request(1) } assertEquals("Request#$it:${expectedRequest.name}", responseChannel.receive().message) delay(10) } @@ -429,7 +427,7 @@ class ClientCallServerStreamingTests : val result = runTest { delay(300) repeat(4) { - verify(exactly = it + 2) { rpcSpy.call.request(1) } + verify(exactly = it + 1) { rpcSpy.call.request(1) } consumedMessages += responseChannel.receive().message delay(10) } diff --git a/kroto-plus-coroutines/src/test/kotlin/com/github/marcoferrer/krotoplus/coroutines/integration/BidiStreamingTests.kt b/kroto-plus-coroutines/src/test/kotlin/com/github/marcoferrer/krotoplus/coroutines/integration/BidiStreamingTests.kt index b6a891b..d12710f 100644 --- a/kroto-plus-coroutines/src/test/kotlin/com/github/marcoferrer/krotoplus/coroutines/integration/BidiStreamingTests.kt +++ b/kroto-plus-coroutines/src/test/kotlin/com/github/marcoferrer/krotoplus/coroutines/integration/BidiStreamingTests.kt @@ -16,45 +16,46 @@ package com.github.marcoferrer.krotoplus.coroutines.integration +import com.github.marcoferrer.krotoplus.coroutines.RpcCallTest +import com.github.marcoferrer.krotoplus.coroutines.server.MESSAGE_SERVER_CANCELLED_CALL +import com.github.marcoferrer.krotoplus.coroutines.utils.CALL_TRACE_ENABLED +import com.github.marcoferrer.krotoplus.coroutines.utils.assertFailsWithCancellation import com.github.marcoferrer.krotoplus.coroutines.utils.assertFailsWithStatus +import com.github.marcoferrer.krotoplus.coroutines.utils.suspendForever import com.github.marcoferrer.krotoplus.coroutines.withCoroutineContext import io.grpc.Status import io.grpc.examples.helloworld.GreeterCoroutineGrpc import io.grpc.examples.helloworld.HelloReply import io.grpc.examples.helloworld.HelloRequest import io.grpc.examples.helloworld.send -import io.grpc.testing.GrpcServerRule import io.mockk.coVerify import io.mockk.spyk -import kotlinx.coroutines.CoroutineStart +import kotlinx.coroutines.CompletableDeferred import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.Job import kotlinx.coroutines.channels.ReceiveChannel import kotlinx.coroutines.channels.SendChannel import kotlinx.coroutines.channels.consumeEach import kotlinx.coroutines.channels.toList -import kotlinx.coroutines.delay +import kotlinx.coroutines.ensureActive +import kotlinx.coroutines.flow.collectIndexed +import kotlinx.coroutines.flow.consumeAsFlow import kotlinx.coroutines.launch -import kotlinx.coroutines.runBlocking -import kotlinx.coroutines.withTimeout -import org.junit.Rule import org.junit.Test +import java.util.concurrent.atomic.AtomicInteger import kotlin.coroutines.CoroutineContext import kotlin.test.assertEquals +import kotlin.test.assertFalse -class BidiStreamingTests { - - @[Rule JvmField] - var grpcServerRule = GrpcServerRule().directExecutor() - - @[Rule JvmField] - var nonDirectGrpcServerRule = GrpcServerRule() +class BidiStreamingTests : RpcCallTest(GreeterCoroutineGrpc.sayHelloStreamingMethod){ @Test fun `Bidi streaming rendezvous impl completes successfully`() { - nonDirectGrpcServerRule.serviceRegistry.addService(object : GreeterCoroutineGrpc.GreeterImplBase(){ + val serverJob = Job() + registerService(object : GreeterCoroutineGrpc.GreeterImplBase(){ override val initialContext: CoroutineContext - get() = Dispatchers.Default + get() = serverJob + Dispatchers.Default override suspend fun sayHelloStreaming( requestChannel: ReceiveChannel, @@ -68,9 +69,9 @@ class BidiStreamingTests { responseChannel.close() } }) - runBlocking { - + val results = runTest { val stub = GreeterCoroutineGrpc.newStub(nonDirectGrpcServerRule.channel) + .withInterceptors(callState) .withCoroutineContext() val (requestChannel, responseChannel) = stub.sayHelloStreaming() @@ -82,55 +83,117 @@ class BidiStreamingTests { requestChannel.close() } - val results = responseChannel.toList() - assertEquals(9, results.size) - - val expected = "name 0|name 0|name 0" + - "|name 1|name 1|name 1" + - "|name 2|name 2|name 2" - assertEquals( - expected, - results.joinToString(separator = "|") { it.message } - ) + responseChannel.toList() } + + assertEquals(9, results.size) + assertEquals( + "name 0|name 0|name 0" + + "|name 1|name 1|name 1" + + "|name 2|name 2|name 2", + results.joinToString(separator = "|") { it.message } + ) + assertFalse(serverJob.isCancelled, "Server job must not be cancelled") } @Test fun `Client cancellation cancels server rpc scope`() { - nonDirectGrpcServerRule.serviceRegistry.addService(object : GreeterCoroutineGrpc.GreeterImplBase(){ - + val serverJob = Job() + val hasExecuted = CompletableDeferred() + registerService(object : GreeterCoroutineGrpc.GreeterImplBase(){ override val initialContext: CoroutineContext - get() = Dispatchers.Default + get() = serverJob + Dispatchers.Default override suspend fun sayHelloStreaming( requestChannel: ReceiveChannel, responseChannel: SendChannel ) { - delay(10000L) + hasExecuted.complete(Unit) + suspendForever("Server") } }) - runBlocking(Dispatchers.Default) { - withTimeout(5000L) { - val stub = GreeterCoroutineGrpc.newStub(nonDirectGrpcServerRule.channel) - .withCoroutineContext() - - val (requestChannel, responseChannel) = stub.sayHelloStreaming() - - val reqChanSpy = spyk(requestChannel) - val reqJob = launch(Dispatchers.Default, start = CoroutineStart.UNDISPATCHED) { - assertFailsWithStatus(Status.CANCELLED) { - repeat(6) { - reqChanSpy.send { name = "name $it" } - } + lateinit var reqChanSpy: SendChannel + runTest(20_000) { + val stub = GreeterCoroutineGrpc.newStub(nonDirectGrpcServerRule.channel) + .withInterceptors(callState) + .withCoroutineContext() + + val (requestChannel, responseChannel) = stub.sayHelloStreaming() + + reqChanSpy = spyk(requestChannel) + launch(Dispatchers.Default) { + callState.server.wasReady.await() + assertFailsWithStatus(Status.CANCELLED) { + repeat(6) { + reqChanSpy.send { name = "name $it" } } } + } + + callState.server.wasReady.await() + responseChannel.cancel() + } + + callState.server.closed.assertBlocking { "Server must be closed" } + + runTest { hasExecuted.await() } + + assertFailsWithCancellation(message = "CANCELLED: $MESSAGE_SERVER_CANCELLED_CALL") { + runTest { serverJob.ensureActive() } + } + assert(serverJob.isCompleted){ "Server job must be completed" } + assert(serverJob.isCancelled){ "Server job must be cancelled" } + coVerify(atMost = 2) { reqChanSpy.send(any()) } + assert(reqChanSpy.isClosedForSend) { "Request channel should be closed after response channel is closed" } + } + + @Test + fun `High volume call succeeds`() { + CALL_TRACE_ENABLED = false + registerService(object : GreeterCoroutineGrpc.GreeterImplBase() { + override val initialContext: CoroutineContext = Dispatchers.Default + override suspend fun sayHelloStreaming( + requestChannel: ReceiveChannel, + responseChannel: SendChannel + ) { + requestChannel.consumeAsFlow().collectIndexed { index, value -> + responseChannel.send(HelloReply.newBuilder().setMessage(value.name).build()) + } + responseChannel.close() + } + }) + val stub = GreeterCoroutineGrpc.newStub(nonDirectGrpcServerRule.channel) + .withInterceptors(callState) + + val (requestChannel, responseChannel) = stub.sayHelloStreaming() - responseChannel.cancel() - reqJob.join() + val numMessages = 100000 + val receivedCount = AtomicInteger() + runTest(timeout = 60_000 * 2) { + val req = HelloRequest.newBuilder() + .setName("test").build() - coVerify(exactly = 2) { reqChanSpy.send(any()) } - assert(reqChanSpy.isClosedForSend) { "Request channel should be closed after response channel is closed" } + launch { + repeat(numMessages) { +// if(it % 10_000 == 0) println("Sent: $it") + requestChannel.send(req) + } + requestChannel.close() } + + launch { + repeat(numMessages) { +// if(it % 10_000 == 0) println("Received: $it") + responseChannel.receive() + receivedCount.incrementAndGet() + } + } + + callState.awaitClose(timeout = 60_000) } + + assert(requestChannel.isClosedForSend) { "Request channel should be closed for send" } + assert(responseChannel.isClosedForReceive) { "Response channel should be closed for receive" } + assertEquals(numMessages, receivedCount.get(), "Must response count must equal request count") } } \ No newline at end of file diff --git a/kroto-plus-coroutines/src/test/kotlin/com/github/marcoferrer/krotoplus/coroutines/integration/ClientStreamingBackPressureTests.kt b/kroto-plus-coroutines/src/test/kotlin/com/github/marcoferrer/krotoplus/coroutines/integration/ClientStreamingBackPressureTests.kt index f36e756..dc06d60 100644 --- a/kroto-plus-coroutines/src/test/kotlin/com/github/marcoferrer/krotoplus/coroutines/integration/ClientStreamingBackPressureTests.kt +++ b/kroto-plus-coroutines/src/test/kotlin/com/github/marcoferrer/krotoplus/coroutines/integration/ClientStreamingBackPressureTests.kt @@ -20,19 +20,14 @@ import com.github.marcoferrer.krotoplus.coroutines.RpcCallTest import com.github.marcoferrer.krotoplus.coroutines.client.clientCallClientStreaming import com.github.marcoferrer.krotoplus.coroutines.utils.assertExEquals import com.github.marcoferrer.krotoplus.coroutines.utils.assertFails -import com.github.marcoferrer.krotoplus.coroutines.utils.assertFailsWithStatus2 +import com.github.marcoferrer.krotoplus.coroutines.utils.assertFailsWithStatus import com.github.marcoferrer.krotoplus.coroutines.utils.invoke import com.github.marcoferrer.krotoplus.coroutines.utils.matchThrowable +import com.github.marcoferrer.krotoplus.coroutines.utils.suspendForever import com.github.marcoferrer.krotoplus.coroutines.withCoroutineContext -import io.grpc.CallOptions -import io.grpc.Channel -import io.grpc.ClientCall -import io.grpc.ClientInterceptor -import io.grpc.MethodDescriptor import io.grpc.ServerInterceptors import io.grpc.Status import io.grpc.examples.helloworld.GreeterCoroutineGrpc -import io.grpc.examples.helloworld.GreeterGrpc import io.grpc.examples.helloworld.HelloReply import io.grpc.examples.helloworld.HelloRequest import io.mockk.spyk @@ -44,7 +39,6 @@ import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.Job import kotlinx.coroutines.cancel import kotlinx.coroutines.channels.ReceiveChannel -import kotlinx.coroutines.delay import kotlinx.coroutines.flow.consumeAsFlow import kotlinx.coroutines.flow.map import kotlinx.coroutines.flow.toList @@ -85,8 +79,7 @@ class ClientStreamingBackPressureTests : setupUpServerHandler { requestChannel -> deferredServerChannel.complete(spyk(requestChannel)) - delay(Long.MAX_VALUE) - HelloReply.getDefaultInstance() + suspendForever() } val rpcSpy = RpcSpy() @@ -94,8 +87,7 @@ class ClientStreamingBackPressureTests : val requestCount = AtomicInteger() assertFails { - runBlocking { - + runTest { val (clientRequestChannel, _) = stub .withCoroutineContext(coroutineContext + Dispatchers.Default) .clientCallClientStreaming(methodDescriptor) @@ -112,17 +104,18 @@ class ClientStreamingBackPressureTests : } val serverRequestChannel = deferredServerChannel.await() + callState.server.wasReady.await() + callState.client.onReady.await() repeat(3){ - delay(10L) - assertEquals(it + 1, requestCount.get()) serverRequestChannel.receive() + assertEquals(it + 1, requestCount.get()) } cancel() } } - verify(exactly = 4) { rpcSpy.call.sendMessage(any()) } + verify(atMost = 4) { rpcSpy.call.sendMessage(any()) } } @Test @@ -172,8 +165,8 @@ class ClientStreamingBackPressureTests : val job = coroutineContext[Job]!! job.invokeOnCompletion { serverJob.complete(job) } deferredServerChannel.complete(spyk(requestChannel)) - delay(Long.MAX_VALUE) - HelloReply.getDefaultInstance() + requestChannel.receive() + suspendForever() } val rpcSpy = RpcSpy(nonDirectGrpcServerRule.channel) @@ -192,13 +185,12 @@ class ClientStreamingBackPressureTests : ) requestChannel.close(expectedException) - assertFailsWithStatus2(Status.CANCELLED){ + assertFailsWithStatus(Status.CANCELLED){ response.await() } } callState { - blockUntilCancellation() client.closed.assertBlocking { "Client must be closed" } } diff --git a/kroto-plus-coroutines/src/test/kotlin/com/github/marcoferrer/krotoplus/coroutines/integration/ServerStreamingTests.kt b/kroto-plus-coroutines/src/test/kotlin/com/github/marcoferrer/krotoplus/coroutines/integration/ServerStreamingTests.kt index 0033f05..d25f2b2 100644 --- a/kroto-plus-coroutines/src/test/kotlin/com/github/marcoferrer/krotoplus/coroutines/integration/ServerStreamingTests.kt +++ b/kroto-plus-coroutines/src/test/kotlin/com/github/marcoferrer/krotoplus/coroutines/integration/ServerStreamingTests.kt @@ -17,7 +17,7 @@ package com.github.marcoferrer.krotoplus.coroutines.integration import com.github.marcoferrer.krotoplus.coroutines.RpcCallTest -import com.github.marcoferrer.krotoplus.coroutines.utils.assertFailsWithStatus2 +import com.github.marcoferrer.krotoplus.coroutines.utils.assertFailsWithStatus import io.grpc.ServerInterceptors import io.grpc.Status import io.grpc.examples.helloworld.GreeterCoroutineGrpc @@ -94,7 +94,7 @@ class ServerStreamingTests : result += responseChannel.receive() } phaser.arrive() - assertFailsWithStatus2(Status.INVALID_ARGUMENT) { + assertFailsWithStatus(Status.INVALID_ARGUMENT) { responseChannel.receive() } } diff --git a/kroto-plus-coroutines/src/test/kotlin/com/github/marcoferrer/krotoplus/coroutines/integration/UnaryTests.kt b/kroto-plus-coroutines/src/test/kotlin/com/github/marcoferrer/krotoplus/coroutines/integration/UnaryTests.kt new file mode 100644 index 0000000..1846f04 --- /dev/null +++ b/kroto-plus-coroutines/src/test/kotlin/com/github/marcoferrer/krotoplus/coroutines/integration/UnaryTests.kt @@ -0,0 +1,60 @@ +/* + * Copyright 2019 Kroto+ Contributors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.github.marcoferrer.krotoplus.coroutines.integration + +import com.github.marcoferrer.krotoplus.coroutines.RpcCallTest +import io.grpc.examples.helloworld.GreeterCoroutineGrpc +import io.grpc.examples.helloworld.HelloReply +import io.grpc.examples.helloworld.HelloRequest +import kotlinx.coroutines.CompletableDeferred +import org.junit.Test +import kotlin.test.assertEquals + + +class UnaryTests : RpcCallTest(GreeterCoroutineGrpc.sayHelloMethod) { + + private val request = HelloRequest.newBuilder().setName("request").build() + private val response = HelloReply.newBuilder().setMessage("reply").build() + + private fun setupServerHandlerNoop(){ + setupServerHandler { + @Suppress("IMPLICIT_NOTHING_AS_TYPE_PARAMETER") + CompletableDeferred().await() + } + } + + private fun setupServerHandler(block: suspend (request: HelloRequest) -> HelloReply){ + grpcServerRule.serviceRegistry.addService(object: GreeterCoroutineGrpc.GreeterImplBase(){ + override suspend fun sayHello(request: HelloRequest): HelloReply = block(request) + }) + } + + @Test + fun `Call succeeds on server response`(){ + val rpcSpy = RpcSpy() + setupServerHandler { request -> + HelloReply.newBuilder().setMessage("${request.name}:reply").build() + } + + val result = runTest { + rpcSpy.coStub.sayHello(request) + } + + assertEquals("request:reply", result.message) + } + +} \ No newline at end of file diff --git a/kroto-plus-coroutines/src/test/kotlin/com/github/marcoferrer/krotoplus/coroutines/server/ServerCallBidiStreamingTests.kt b/kroto-plus-coroutines/src/test/kotlin/com/github/marcoferrer/krotoplus/coroutines/server/ServerCallBidiStreamingTests.kt index 2ebf8dc..8e0b7e4 100644 --- a/kroto-plus-coroutines/src/test/kotlin/com/github/marcoferrer/krotoplus/coroutines/server/ServerCallBidiStreamingTests.kt +++ b/kroto-plus-coroutines/src/test/kotlin/com/github/marcoferrer/krotoplus/coroutines/server/ServerCallBidiStreamingTests.kt @@ -17,21 +17,18 @@ package com.github.marcoferrer.krotoplus.coroutines.server -import com.github.marcoferrer.krotoplus.coroutines.utils.CancellingClientInterceptor +import com.github.marcoferrer.krotoplus.coroutines.RpcCallTest import com.github.marcoferrer.krotoplus.coroutines.utils.ServerSpy import com.github.marcoferrer.krotoplus.coroutines.utils.matchStatus import com.github.marcoferrer.krotoplus.coroutines.utils.serverRpcSpy +import com.github.marcoferrer.krotoplus.coroutines.utils.suspendForever import com.github.marcoferrer.krotoplus.coroutines.withCoroutineContext -import io.grpc.CallOptions -import io.grpc.ClientCall import io.grpc.Status import io.grpc.examples.helloworld.GreeterCoroutineGrpc import io.grpc.examples.helloworld.GreeterGrpc import io.grpc.examples.helloworld.HelloReply import io.grpc.examples.helloworld.HelloRequest -import io.grpc.stub.ClientCalls import io.grpc.stub.StreamObserver -import io.grpc.testing.GrpcServerRule import io.mockk.coVerify import io.mockk.spyk import io.mockk.verify @@ -51,40 +48,21 @@ import kotlinx.coroutines.flow.consumeAsFlow import kotlinx.coroutines.launch import kotlinx.coroutines.runBlocking import kotlinx.coroutines.yield -import org.junit.Rule import org.junit.Test import java.util.concurrent.atomic.AtomicBoolean import kotlin.coroutines.CoroutineContext import kotlin.coroutines.coroutineContext import kotlin.test.assertEquals -class ServerCallBidiStreamingTests { +class ServerCallBidiStreamingTests : + RpcCallTest(GreeterGrpc.getSayHelloStreamingMethod()){ - @[Rule JvmField] - var grpcServerRule = GrpcServerRule().directExecutor() - - @[Rule JvmField] - var nonDirectGrpcServerRule = GrpcServerRule() - - - // @[Rule JvmField] - // public val timeout = CoroutinesTimeout.seconds(COROUTINE_TEST_TIMEOUT) - - private val methodDescriptor = GreeterGrpc.getSayHelloStreamingMethod() - private val expectedResponse = HelloReply.newBuilder().setMessage("reply").build() private val responseObserver = spyk>(object : StreamObserver { override fun onNext(value: HelloReply?) {} override fun onError(t: Throwable?) {} override fun onCompleted() {} }) - private fun newCall(): Pair, StreamObserver> { - val call = grpcServerRule.channel - .newCall(methodDescriptor, CallOptions.DEFAULT) - - return call to ClientCalls.asyncBidiStreamingCall(call, responseObserver) - } - private fun StreamObserver.sendRequests(qty: Int) { repeat(qty) { onNext(HelloRequest.newBuilder().setName(it.toString()).build()) @@ -96,7 +74,7 @@ class ServerCallBidiStreamingTests { fun `Server responds successfully on rendezvous requests`() { lateinit var reqChannel: ReceiveChannel lateinit var respChannel: SendChannel - grpcServerRule.serviceRegistry.addService(object : GreeterCoroutineGrpc.GreeterImplBase() { + registerService(object : GreeterCoroutineGrpc.GreeterImplBase() { override val initialContext: CoroutineContext = Dispatchers.Unconfined override suspend fun sayHelloStreaming( requestChannel: ReceiveChannel, @@ -113,6 +91,7 @@ class ServerCallBidiStreamingTests { }) val requestObserver = GreeterGrpc.newStub(grpcServerRule.channel) + .withInterceptors(callState) .sayHelloStreaming(responseObserver) requestObserver.sendRequests(3) @@ -131,7 +110,7 @@ class ServerCallBidiStreamingTests { fun `Server responds successfully on uneven requests to response`() { lateinit var reqChannel: ReceiveChannel lateinit var respChannel: SendChannel - grpcServerRule.serviceRegistry.addService(object : GreeterCoroutineGrpc.GreeterImplBase() { + registerService(object : GreeterCoroutineGrpc.GreeterImplBase() { override val initialContext: CoroutineContext = Dispatchers.Unconfined override suspend fun sayHelloStreaming( requestChannel: ReceiveChannel, @@ -156,6 +135,7 @@ class ServerCallBidiStreamingTests { }) val requestObserver = GreeterGrpc.newStub(grpcServerRule.channel) + .withInterceptors(callState) .sayHelloStreaming(responseObserver) requestObserver.sendRequests(9) @@ -175,7 +155,7 @@ class ServerCallBidiStreamingTests { fun `Server responds with error when exception thrown`() { lateinit var reqChannel: ReceiveChannel lateinit var respChannel: SendChannel - grpcServerRule.serviceRegistry.addService(object : GreeterCoroutineGrpc.GreeterImplBase() { + registerService(object : GreeterCoroutineGrpc.GreeterImplBase() { override val initialContext: CoroutineContext = Dispatchers.Unconfined override suspend fun sayHelloStreaming( requestChannel: ReceiveChannel, @@ -190,6 +170,7 @@ class ServerCallBidiStreamingTests { }) val requestObserver = GreeterGrpc.newStub(grpcServerRule.channel) + .withInterceptors(callState) .sayHelloStreaming(responseObserver) requestObserver.sendRequests(3) @@ -204,7 +185,7 @@ class ServerCallBidiStreamingTests { fun `Server responds with error when response channel closed`() { lateinit var reqChannel: ReceiveChannel lateinit var respChannel: SendChannel - grpcServerRule.serviceRegistry.addService(object : GreeterCoroutineGrpc.GreeterImplBase() { + registerService(object : GreeterCoroutineGrpc.GreeterImplBase() { override val initialContext: CoroutineContext = Dispatchers.Unconfined override suspend fun sayHelloStreaming( requestChannel: ReceiveChannel, @@ -219,9 +200,13 @@ class ServerCallBidiStreamingTests { }) val requestObserver = GreeterGrpc.newStub(grpcServerRule.channel) + .withInterceptors(callState) .sayHelloStreaming(responseObserver) - requestObserver.sendRequests(3) + requestObserver.sendRequests(4) + + callState.blockUntilClosed() + verify(exactly = 1) { responseObserver.onError(matchStatus(Status.INVALID_ARGUMENT)) } verify(exactly = 0) { responseObserver.onNext(any()) } verify(exactly = 0) { responseObserver.onCompleted() } @@ -233,7 +218,7 @@ class ServerCallBidiStreamingTests { fun `Server responds with completion even when no responses are sent`() { lateinit var reqChannel: ReceiveChannel lateinit var respChannel: SendChannel - grpcServerRule.serviceRegistry.addService(object : GreeterCoroutineGrpc.GreeterImplBase() { + registerService(object : GreeterCoroutineGrpc.GreeterImplBase() { override val initialContext: CoroutineContext = Dispatchers.Unconfined override suspend fun sayHelloStreaming( requestChannel: ReceiveChannel, @@ -249,6 +234,7 @@ class ServerCallBidiStreamingTests { }) val requestObserver = GreeterGrpc.newStub(grpcServerRule.channel) + .withInterceptors(callState) .sayHelloStreaming(responseObserver) requestObserver.sendRequests(3) @@ -263,7 +249,7 @@ class ServerCallBidiStreamingTests { fun `Server responds with cancellation when scope cancelled normally`(){ lateinit var reqChannel: ReceiveChannel lateinit var respChannel: SendChannel - grpcServerRule.serviceRegistry.addService(object : GreeterCoroutineGrpc.GreeterImplBase(){ + registerService(object : GreeterCoroutineGrpc.GreeterImplBase(){ override val initialContext: CoroutineContext = Dispatchers.Unconfined override suspend fun sayHelloStreaming( requestChannel: ReceiveChannel, @@ -285,9 +271,13 @@ class ServerCallBidiStreamingTests { }) val requestObserver = GreeterGrpc.newStub(grpcServerRule.channel) + .withInterceptors(callState) .sayHelloStreaming(responseObserver) requestObserver.sendRequests(3) + + callState.blockUntilClosed() + callState.server.completed.assertBlocking { "Server must complete" } verify(exactly = 1) { responseObserver.onError(matchStatus(Status.CANCELLED)) } verify(exactly = 0) { responseObserver.onNext(any()) } verify(exactly = 0) { responseObserver.onCompleted() } @@ -299,8 +289,8 @@ class ServerCallBidiStreamingTests { fun `Server responds with error when scope cancelled exceptionally`(){ lateinit var reqChannel: ReceiveChannel lateinit var respChannel: SendChannel - grpcServerRule.serviceRegistry.addService(object : GreeterCoroutineGrpc.GreeterImplBase(){ - override val initialContext: CoroutineContext = Dispatchers.Unconfined + registerService(object : GreeterCoroutineGrpc.GreeterImplBase(){ + override val initialContext: CoroutineContext = Dispatchers.Default override suspend fun sayHelloStreaming( requestChannel: ReceiveChannel, responseChannel: SendChannel @@ -322,9 +312,13 @@ class ServerCallBidiStreamingTests { }) val requestObserver = GreeterGrpc.newStub(grpcServerRule.channel) + .withInterceptors(callState) .sayHelloStreaming(responseObserver) requestObserver.sendRequests(3) + + callState.blockUntilClosed() + verify(exactly = 1) { responseObserver.onError(matchStatus(Status.UNKNOWN)) } verify(exactly = 0) { responseObserver.onNext(any()) } verify(exactly = 0) { responseObserver.onCompleted() } @@ -335,11 +329,12 @@ class ServerCallBidiStreamingTests { @Test fun `Server is cancelled when client sends cancellation`() { + val serverJob = Job() lateinit var serverSpy: ServerSpy lateinit var reqChannel: ReceiveChannel lateinit var respChannel: SendChannel - grpcServerRule.serviceRegistry.addService(object : GreeterCoroutineGrpc.GreeterImplBase() { - override val initialContext: CoroutineContext = Dispatchers.Unconfined + registerService(object : GreeterCoroutineGrpc.GreeterImplBase() { + override val initialContext: CoroutineContext = serverJob + Dispatchers.Default override suspend fun sayHelloStreaming( requestChannel: ReceiveChannel, responseChannel: SendChannel @@ -347,16 +342,20 @@ class ServerCallBidiStreamingTests { reqChannel = requestChannel respChannel = responseChannel serverSpy = serverRpcSpy(coroutineContext) - delay(300000L) + suspendForever() } }) - val (call, requestObserver) = newCall() + val rpcSpy = RpcSpy() + val requestObserver = rpcSpy.stub.sayHelloStreaming(responseObserver) requestObserver.sendRequests(3) - call.cancel("test",null) + rpcSpy.call.cancel("test",null) + + runBlocking { serverJob.join() } + callState.blockUntilClosed() assert(serverSpy.job?.isCancelled == true) - verify(exactly = 1) { responseObserver.onError(matchStatus(Status.CANCELLED,"CANCELLED")) } - assertEquals("Job was cancelled",serverSpy.error?.message) + verify(exactly = 1) { responseObserver.onError(matchStatus(Status.CANCELLED)) } + assertEquals("CANCELLED: $MESSAGE_SERVER_CANCELLED_CALL",serverSpy.error?.message) assert(reqChannel.isClosedForReceive) { "Request channel should be closed" } assert(respChannel.isClosedForSend) { "Response channel should be closed" } } @@ -364,12 +363,13 @@ class ServerCallBidiStreamingTests { @Test fun `Server is cancelled when client sends error`() { + val serverJob = Job() lateinit var serverSpy: ServerSpy lateinit var reqChannel: ReceiveChannel lateinit var respChannel: SendChannel var requestCount = 0 - grpcServerRule.serviceRegistry.addService(object : GreeterCoroutineGrpc.GreeterImplBase() { - override val initialContext: CoroutineContext = Dispatchers.Unconfined + registerService(object : GreeterCoroutineGrpc.GreeterImplBase() { + override val initialContext: CoroutineContext = serverJob + Dispatchers.Unconfined override suspend fun sayHelloStreaming( requestChannel: ReceiveChannel, responseChannel: SendChannel @@ -380,24 +380,27 @@ class ServerCallBidiStreamingTests { reqChannel.consumeEach { requestCount++ } - - delay(300000L) + suspendForever() } }) - val (_, requestObserver) = newCall() + val rpcSpy = RpcSpy() + val requestObserver = rpcSpy.stub.sayHelloStreaming(responseObserver) requestObserver.apply { onNext(HelloRequest.getDefaultInstance()) onNext(HelloRequest.getDefaultInstance()) onError(Status.DATA_LOSS.asRuntimeException()) } + + runBlocking { serverJob.join() } + assert(serverSpy.job?.isCancelled == true){ "Server job should be cancelled" } assertEquals(2,requestCount, "Server should receive two requests") - assertEquals("Job was cancelled",serverSpy.error?.message) + assertEquals("CANCELLED: $MESSAGE_SERVER_CANCELLED_CALL",serverSpy.error?.message) assert(reqChannel.isClosedForReceive) { "Request channel should be closed" } assert(respChannel.isClosedForSend) { "Response channel should be closed" } verify(exactly = 1) { - responseObserver.onError(matchStatus(Status.CANCELLED,"CANCELLED")) + responseObserver.onError(matchStatus(Status.CANCELLED)) } } @@ -405,8 +408,9 @@ class ServerCallBidiStreamingTests { fun `Server method is at least invoked before being cancelled`(){ val deferredRespChannel = CompletableDeferred>() val deferredCtx = CompletableDeferred() + val rpcSpy = RpcSpy() - nonDirectGrpcServerRule.serviceRegistry.addService(object : GreeterCoroutineGrpc.GreeterImplBase() { + registerService(object : GreeterCoroutineGrpc.GreeterImplBase() { override val initialContext: CoroutineContext = Dispatchers.Default override suspend fun sayHelloStreaming( requestChannel: ReceiveChannel, @@ -430,36 +434,32 @@ class ServerCallBidiStreamingTests { } }) - runBlocking { - val respObserver = spyk(object: StreamObserver{ - val completed = AtomicBoolean() - override fun onNext(value: HelloReply?) {} - override fun onError(t: Throwable?) { completed.set(true) } - override fun onCompleted() { completed.set(true) } - }) - - val stub = GreeterGrpc.newStub(nonDirectGrpcServerRule.channel) - .withInterceptors(CancellingClientInterceptor) - .withCoroutineContext() + val respObserver = spyk(object: StreamObserver{ + val completed = AtomicBoolean() + override fun onNext(value: HelloReply?) {} + override fun onError(t: Throwable?) { completed.set(true) } + override fun onCompleted() { completed.set(true) } + }) + runTest { // Start the call - val reqObserver = stub.sayHelloStreaming(respObserver) - - // Wait for the server method to be invoked - val serverCtx = deferredCtx.await() + val reqObserver = rpcSpy.stub + .withCoroutineContext() + .sayHelloStreaming(respObserver) - // At this point the server method is suspended. We can send the first message. reqObserver.onNext(HelloRequest.getDefaultInstance()) - - // Once we call `onCompleted` the server scope will be canceled - // because of the CancellingClientInterceptor reqObserver.onCompleted() - // We wait for the server scope to complete before proceeding with assertions - serverCtx[Job]!!.join() + callState.server.intercepted.assert { "Server must be intercepted" } + + rpcSpy.call.cancel("test",null) + callState.server.closed.assert { "Server must complete" } + } + + runBlocking { + val serverCtx = deferredCtx.await() val respChannel = deferredRespChannel.await() - while(!respObserver.completed.get()){} assert(respChannel.isClosedForSend){ "Abandoned response channel should be closed" } verify(exactly = 1) { respObserver.onError(matchStatus(Status.CANCELLED)) } diff --git a/kroto-plus-coroutines/src/test/kotlin/com/github/marcoferrer/krotoplus/coroutines/server/ServerCallClientStreamingTests.kt b/kroto-plus-coroutines/src/test/kotlin/com/github/marcoferrer/krotoplus/coroutines/server/ServerCallClientStreamingTests.kt index db96cdb..c2d00ca 100644 --- a/kroto-plus-coroutines/src/test/kotlin/com/github/marcoferrer/krotoplus/coroutines/server/ServerCallClientStreamingTests.kt +++ b/kroto-plus-coroutines/src/test/kotlin/com/github/marcoferrer/krotoplus/coroutines/server/ServerCallClientStreamingTests.kt @@ -16,11 +16,11 @@ package com.github.marcoferrer.krotoplus.coroutines.server -import com.github.marcoferrer.krotoplus.coroutines.utils.COROUTINE_TEST_TIMEOUT import com.github.marcoferrer.krotoplus.coroutines.utils.CancellingClientInterceptor import com.github.marcoferrer.krotoplus.coroutines.utils.ServerSpy import com.github.marcoferrer.krotoplus.coroutines.utils.matchStatus import com.github.marcoferrer.krotoplus.coroutines.utils.serverRpcSpy +import com.github.marcoferrer.krotoplus.coroutines.utils.suspendForever import io.grpc.CallOptions import io.grpc.ClientCall import io.grpc.Status @@ -210,8 +210,7 @@ class ServerCallClientStreamingTests { override suspend fun sayHelloClientStreaming(requestChannel: ReceiveChannel): HelloReply { reqChannel = requestChannel serverSpy = serverRpcSpy(coroutineContext) - delay(300000L) - return expectedResponse + suspendForever() } }) @@ -239,8 +238,7 @@ class ServerCallClientStreamingTests { requestCount++ } - delay(300000L) - return expectedResponse + suspendForever() } }) diff --git a/kroto-plus-coroutines/src/test/kotlin/com/github/marcoferrer/krotoplus/coroutines/server/ServerCallServerStreamingTests.kt b/kroto-plus-coroutines/src/test/kotlin/com/github/marcoferrer/krotoplus/coroutines/server/ServerCallServerStreamingTests.kt index fafa529..1f8d9cd 100644 --- a/kroto-plus-coroutines/src/test/kotlin/com/github/marcoferrer/krotoplus/coroutines/server/ServerCallServerStreamingTests.kt +++ b/kroto-plus-coroutines/src/test/kotlin/com/github/marcoferrer/krotoplus/coroutines/server/ServerCallServerStreamingTests.kt @@ -17,14 +17,17 @@ package com.github.marcoferrer.krotoplus.coroutines.server -import com.github.marcoferrer.krotoplus.coroutines.utils.COROUTINE_TEST_TIMEOUT +import com.github.marcoferrer.krotoplus.coroutines.RpcCallTest import com.github.marcoferrer.krotoplus.coroutines.utils.CancellingClientInterceptor import com.github.marcoferrer.krotoplus.coroutines.utils.ServerSpy import com.github.marcoferrer.krotoplus.coroutines.utils.assertFailsWithStatus import com.github.marcoferrer.krotoplus.coroutines.utils.matchStatus import com.github.marcoferrer.krotoplus.coroutines.utils.serverRpcSpy +import com.github.marcoferrer.krotoplus.coroutines.utils.suspendForever +import com.github.marcoferrer.krotoplus.coroutines.utils.toDebugString import io.grpc.CallOptions import io.grpc.ClientCall +import io.grpc.ClientInterceptors import io.grpc.Status import io.grpc.examples.helloworld.GreeterCoroutineGrpc import io.grpc.examples.helloworld.GreeterGrpc @@ -32,8 +35,6 @@ import io.grpc.examples.helloworld.HelloReply import io.grpc.examples.helloworld.HelloRequest import io.grpc.stub.ClientCalls import io.grpc.stub.StreamObserver -import io.grpc.testing.GrpcServerRule -import io.mockk.coVerify import io.mockk.spyk import io.mockk.verify import io.mockk.verifyOrder @@ -48,37 +49,31 @@ import kotlinx.coroutines.delay import kotlinx.coroutines.launch import kotlinx.coroutines.runBlocking import kotlinx.coroutines.yield -import org.junit.Rule import org.junit.Test import kotlin.coroutines.CoroutineContext import kotlin.coroutines.coroutineContext import kotlin.test.assertEquals -class ServerCallServerStreamingTests { +class ServerCallServerStreamingTests : + RpcCallTest(GreeterGrpc.getSayHelloServerStreamingMethod()) { - @[Rule JvmField] - var grpcServerRule = GrpcServerRule().directExecutor() - - // @[Rule JvmField] - // public val timeout = CoroutinesTimeout.seconds(COROUTINE_TEST_TIMEOUT) - - private val methodDescriptor = GreeterGrpc.getSayHelloServerStreamingMethod() private val request = HelloRequest.newBuilder().setName("abc").build() private val expectedResponse = HelloReply.newBuilder().setMessage("reply").build() private val responseObserver = spyk>(object: StreamObserver{ override fun onNext(value: HelloReply?) { -// println("client:onNext:$value") + println("Client: onNext(${value.toString().trim()})") } - override fun onError(t: Throwable?) { -// println("client:onError:$t") + override fun onError(t: Throwable) { + println("Client: onError(${t.toDebugString()})") } override fun onCompleted() { -// println("client:onComplete") + println("Client: onComplete()") } }) private fun newCall(): ClientCall { - val call = grpcServerRule.channel + val call = ClientInterceptors + .intercept(grpcServerRule.channel, callState) .newCall(methodDescriptor, CallOptions.DEFAULT) ClientCalls.asyncServerStreamingCall(call, request, responseObserver) @@ -100,6 +95,7 @@ class ServerCallServerStreamingTests { }) GreeterGrpc.newStub(grpcServerRule.channel) + .withInterceptors(callState) .sayHelloServerStreaming(request,responseObserver) verifyOrder { @@ -126,6 +122,7 @@ class ServerCallServerStreamingTests { }) GreeterGrpc.newStub(grpcServerRule.channel) + .withInterceptors(callState) .sayHelloServerStreaming(request,responseObserver) verify(exactly = 1) { responseObserver.onError(matchStatus(Status.INVALID_ARGUMENT)) } @@ -141,27 +138,30 @@ class ServerCallServerStreamingTests { grpcServerRule.serviceRegistry.addService(object : GreeterCoroutineGrpc.GreeterImplBase(){ // We're using `Dispatchers.Unconfined` so that we can make sure the response was returned // before verifying the result. - override val initialContext: CoroutineContext = Dispatchers.Unconfined + override val initialContext: CoroutineContext = Dispatchers.Default override suspend fun sayHelloServerStreaming( request: HelloRequest, responseChannel: SendChannel ) { deferredRespChannel.complete(responseChannel) coroutineScope { - launch { - delay(5L) - } - cancel() - repeat(3){ - responseChannel.send(expectedResponse) + val parentScope = this + launch(start = CoroutineStart.UNDISPATCHED) { + parentScope.cancel() + delay(10L) + repeat(3){ + responseChannel.send(expectedResponse) + } } } } }) GreeterGrpc.newStub(grpcServerRule.channel) + .withInterceptors(callState) .sayHelloServerStreaming(request,responseObserver) + callState.client.closed.assertBlocking{ "Client call must be closed" } verify(exactly = 1) { responseObserver.onError(matchStatus(Status.CANCELLED)) } verify(exactly = 0) { responseObserver.onNext(any()) } verify(exactly = 0) { responseObserver.onCompleted() } @@ -170,32 +170,42 @@ class ServerCallServerStreamingTests { } } + fun setupServerHandler( + context: CoroutineContext = Dispatchers.Default, + block: suspend (request: HelloRequest, responseChannel: SendChannel) -> Unit + ){ + registerService(object : GreeterCoroutineGrpc.GreeterImplBase(){ + override val initialContext: CoroutineContext = context + override suspend fun sayHelloServerStreaming( + request: HelloRequest, + responseChannel: SendChannel + ) = block(request, responseChannel) + }) + } + @Test fun `Server responds with error when scope cancelled exceptionally`(){ val deferredRespChannel = CompletableDeferred>() - grpcServerRule.serviceRegistry.addService(object : GreeterCoroutineGrpc.GreeterImplBase(){ - override val initialContext: CoroutineContext = Dispatchers.Unconfined - override suspend fun sayHelloServerStreaming( - request: HelloRequest, - responseChannel: SendChannel - ) { - deferredRespChannel.complete(responseChannel) - coroutineScope { - launch(start = CoroutineStart.UNDISPATCHED) { - throw Status.INVALID_ARGUMENT.asRuntimeException() - } - repeat(3){ - yield() - responseChannel.send(expectedResponse) - } + setupServerHandler(Dispatchers.Unconfined) { request, responseChannel -> + deferredRespChannel.complete(responseChannel) + coroutineScope { + launch(start = CoroutineStart.UNDISPATCHED) { + throw Status.INVALID_ARGUMENT.asRuntimeException() + } + repeat(3){ + yield() + responseChannel.send(expectedResponse) } } - }) + } GreeterGrpc.newStub(grpcServerRule.channel) + .withInterceptors(callState) .sayHelloServerStreaming(request,responseObserver) + callState.blockUntilClosed() + verify(exactly = 1) { responseObserver.onError(matchStatus(Status.INVALID_ARGUMENT)) } verify(exactly = 0) { responseObserver.onNext(any()) } verify(exactly = 0) { responseObserver.onCompleted() } @@ -217,58 +227,40 @@ class ServerCallServerStreamingTests { ) { respChannel = responseChannel serverSpy = serverRpcSpy(coroutineContext) - delay(300000L) + suspendForever() } }) val call = newCall() call.cancel("test",null) assert(serverSpy.job?.isCancelled == true) - verify(exactly = 1) { responseObserver.onError(matchStatus(Status.CANCELLED,"CANCELLED")) } - assertEquals("Job was cancelled",serverSpy.error?.message) + verify(exactly = 1) { responseObserver.onError(matchStatus(Status.CANCELLED,"CANCELLED: $MESSAGE_SERVER_CANCELLED_CALL")) } + assertEquals("CANCELLED: $MESSAGE_SERVER_CANCELLED_CALL",serverSpy.error?.message) assert(respChannel.isClosedForSend){ "Abandoned response channel should be closed"} } @Test fun `Server method is at least invoked before being cancelled`(){ val deferredRespChannel = CompletableDeferred>() - val deferredCtx = CompletableDeferred() - grpcServerRule.serviceRegistry.addService(object : GreeterCoroutineGrpc.GreeterImplBase() { - override val initialContext: CoroutineContext = Dispatchers.Default - override suspend fun sayHelloServerStreaming( - request: HelloRequest, - responseChannel: SendChannel - ) { - val respChan = spyk(responseChannel) - deferredCtx.complete(coroutineContext.apply { - get(Job)!!.invokeOnCompletion { - deferredRespChannel.complete(respChan) - } - }) - delay(10000) - yield() - repeat(3){ - respChan.send { message = "response" } - } - } - }) + val serverJob = Job() + setupServerHandler(serverJob) { _, responseChannel -> + deferredRespChannel.complete(responseChannel) + suspendForever("Server") + } - val stub = GreeterGrpc.newBlockingStub(grpcServerRule.channel) - .withInterceptors(CancellingClientInterceptor) + val rpcSpy = RpcSpy(useDirectExecutor = false) + val stub = rpcSpy.blkStub.withInterceptors(CancellingClientInterceptor) assertFailsWithStatus(Status.CANCELLED,"CANCELLED: test"){ val iter = stub.sayHelloServerStreaming(HelloRequest.getDefaultInstance()) while(iter.hasNext()){} } - runBlocking { - val respChannel = deferredRespChannel.await() - assert(respChannel.isClosedForSend){ "Abandoned response channel should be closed" } - coVerify(exactly = 0) { respChannel.send(any()) } - - val serverCtx = deferredCtx.await() - assert(serverCtx[Job]!!.isCompleted){ "Server job should be completed" } - assert(serverCtx[Job]!!.isCancelled){ "Server job should be cancelled" } + runTest { + serverJob.join() + assert(deferredRespChannel.await().isClosedForSend){ "Abandoned response channel should be closed" } + assert(serverJob.isCompleted){ "Server job should be completed" } + assert(serverJob.isCancelled){ "Server job should be cancelled" } } } } \ No newline at end of file diff --git a/kroto-plus-coroutines/src/test/kotlin/com/github/marcoferrer/krotoplus/coroutines/server/ServerCallUnaryTests.kt b/kroto-plus-coroutines/src/test/kotlin/com/github/marcoferrer/krotoplus/coroutines/server/ServerCallUnaryTests.kt index 59d4bf5..20cd406 100644 --- a/kroto-plus-coroutines/src/test/kotlin/com/github/marcoferrer/krotoplus/coroutines/server/ServerCallUnaryTests.kt +++ b/kroto-plus-coroutines/src/test/kotlin/com/github/marcoferrer/krotoplus/coroutines/server/ServerCallUnaryTests.kt @@ -20,6 +20,7 @@ import com.github.marcoferrer.krotoplus.coroutines.utils.ServerSpy import com.github.marcoferrer.krotoplus.coroutines.utils.assertFailsWithStatus import com.github.marcoferrer.krotoplus.coroutines.utils.matchStatus import com.github.marcoferrer.krotoplus.coroutines.utils.serverRpcSpy +import com.github.marcoferrer.krotoplus.coroutines.utils.suspendForever import io.grpc.CallOptions import io.grpc.Channel import io.grpc.ClientCall @@ -149,13 +150,13 @@ class ServerCallUnaryTests { override val initialContext: CoroutineContext = Dispatchers.Default override suspend fun sayHello(request: HelloRequest): HelloReply { serverSpy = serverRpcSpy(coroutineContext) - delay(300000L) - return expectedResponse + suspendForever() } }) val call = newCall() call.cancel("test",null) + assert(serverSpy.job!!.isCancelled){ "Server job must be cancelled" } verify(exactly = 1) { responseObserver.onError(matchStatus(Status.CANCELLED)) diff --git a/kroto-plus-coroutines/src/test/kotlin/com/github/marcoferrer/krotoplus/coroutines/utils/Assertions.kt b/kroto-plus-coroutines/src/test/kotlin/com/github/marcoferrer/krotoplus/coroutines/utils/Assertions.kt index 7694db0..1ff2460 100644 --- a/kroto-plus-coroutines/src/test/kotlin/com/github/marcoferrer/krotoplus/coroutines/utils/Assertions.kt +++ b/kroto-plus-coroutines/src/test/kotlin/com/github/marcoferrer/krotoplus/coroutines/utils/Assertions.kt @@ -18,23 +18,24 @@ package com.github.marcoferrer.krotoplus.coroutines.utils import io.grpc.Status import io.grpc.StatusRuntimeException +import kotlinx.coroutines.CancellationException import kotlin.test.assertEquals +import kotlin.test.assertTrue import kotlin.test.fail -inline fun assertFailsWithStatus2( - status: Status, - message: String? = null, - block: () -> Unit -){ +inline fun assertFailsWithCancellation(cause: Throwable? = null, message: String? = null, block: () -> Unit){ try{ block() - fail("Block did not fail") + fail("Cancellation exception was not thrown") }catch (e: Throwable){ - println("assertFailsWithStatus(${e.javaClass}, message: ${e.message})") - assertEquals(StatusRuntimeException::class.java.canonicalName, e.javaClass.canonicalName) - require(e is StatusRuntimeException) - message?.let { assertEquals(it,e.message) } - assertEquals(status.code, e.status.code) + if(e is AssertionError) throw e + assertTrue( + e is CancellationException, + "Expected: CancellationException, Actual: ${e.javaClass.canonicalName}" + ) + message?.let { assertEquals(it, e.message) } + cause?.let { assertExEquals(it, e.cause) } + cause?.cause?.let { assertExEquals(it, e.cause?.cause) } } } @@ -45,19 +46,19 @@ inline fun assertFailsWithStatus( ){ try{ block() - fail("Block did not fail") - }catch (e: StatusRuntimeException){ -// TODO: Fix this in separate PR -// }catch (e: Throwable){ -// assertEquals(StatusRuntimeException::class.java.canonicalName, e.javaClass.canonicalName) -// require(e is StatusRuntimeException) -// println("assertFailsWithStatus(${e.javaClass}, message: ${e.message})") + fail("Expected StatusRuntimeException: $status, but none was thrown") + }catch (e: Throwable){ + if(e is AssertionError) throw e + assertTrue( + e is StatusRuntimeException, + "Expected: StatusRuntimeException, Actual: ${e.javaClass.canonicalName}, with Cause: ${e.cause?.javaClass}" + ) message?.let { assertEquals(it,e.message) } assertEquals(status.code, e.status.code) } } -//Default `assertFailsWith` isn't inline and doesnt support coroutines +// Default `assertFailsWith` isn't inline and doesnt support coroutines inline fun assertFails(message: String? = null, block: ()-> Unit){ try { block() diff --git a/kroto-plus-coroutines/src/test/kotlin/com/github/marcoferrer/krotoplus/coroutines/utils/CallUtils.kt b/kroto-plus-coroutines/src/test/kotlin/com/github/marcoferrer/krotoplus/coroutines/utils/CallUtils.kt index ca67b98..ccad882 100644 --- a/kroto-plus-coroutines/src/test/kotlin/com/github/marcoferrer/krotoplus/coroutines/utils/CallUtils.kt +++ b/kroto-plus-coroutines/src/test/kotlin/com/github/marcoferrer/krotoplus/coroutines/utils/CallUtils.kt @@ -35,6 +35,7 @@ import io.mockk.spyk import kotlinx.coroutines.CancellationException import kotlinx.coroutines.CompletableDeferred import kotlinx.coroutines.Job +import kotlinx.coroutines.suspendCancellableCoroutine object CancellingClientInterceptor : ClientInterceptor { override fun interceptCall( @@ -42,8 +43,8 @@ object CancellingClientInterceptor : ClientInterceptor { callOptions: CallOptions?, next: Channel ): ClientCall { - val _call = next.newCall(method,callOptions) - return object : SimpleForwardingClientCall(_call){ + val call = next.newCall(method,callOptions) + return object : SimpleForwardingClientCall(call){ override fun halfClose() { super.halfClose() // Cancel call after we've verified @@ -51,7 +52,6 @@ object CancellingClientInterceptor : ClientInterceptor { cancel("test",null) } } - } } @@ -59,6 +59,7 @@ object CancellingClientInterceptor : ClientInterceptor { class ClientState( val intercepted: CompletableDeferred = CompletableDeferred(), val started: CompletableDeferred = CompletableDeferred(), + val onReady: CompletableDeferred = CompletableDeferred(), val halfClosed: CompletableDeferred = CompletableDeferred(), val closed: CompletableDeferred = CompletableDeferred(), val cancelled: CompletableDeferred = CompletableDeferred() @@ -68,6 +69,7 @@ class ClientState( return "\tClientState(\n" + "\t\tintercepted=${intercepted.stateToString()}, \n" + "\t\tstarted=${started.stateToString()},\n" + + "\t\tonReady=${started.stateToString()},\n" + "\t\thalfClosed=${halfClosed.stateToString()},\n" + "\t\tclosed=${closed.stateToString()},\n" + "\t\tcancelled=${cancelled.stateToString()}\n" + @@ -147,11 +149,17 @@ class ClientStateInterceptor(val state: ClientState) : ClientInterceptor { } override fun start(responseListener: Listener, headers: Metadata) { - println("Client: Call start()") + log("Client: Call start()") super.start(object : SimpleForwardingClientCallListener(responseListener){ + override fun onReady() { + log("Client: Call Listener onReady()") + super.onReady() + state.onReady.complete() + } + override fun onClose(status: Status?, trailers: Metadata?) { - println("Client: Call Listener onClose(${status?.toDebugString()})") + log("Client: Call Listener onClose(${status?.toDebugString()})") super.onClose(status, trailers) state.closed.complete() } @@ -161,13 +169,13 @@ class ClientStateInterceptor(val state: ClientState) : ClientInterceptor { } override fun halfClose() { - println("Client: Call halfClose()") + log("Client: Call halfClose()") super.halfClose() state.halfClosed.complete() } override fun cancel(message: String?, cause: Throwable?) { - println("Client: Call cancel(message=$message, cause=${cause?.toDebugString()})") + log("Client: Call cancel(message=$message, cause=${cause?.toDebugString()})") super.cancel(message, cause) state.cancelled.complete() } @@ -185,7 +193,7 @@ class ServerStateInterceptor(val state: ServerState) : ServerInterceptor { val interceptedCall = object : SimpleForwardingServerCall(call){ override fun close(status: Status?, trailers: Metadata?) { - println("Server: Call Close, ${status?.toDebugString()}") + log("Server: Call Close, ${status?.toDebugString()}") super.close(status, trailers) state.closed.complete() } @@ -197,25 +205,25 @@ class ServerStateInterceptor(val state: ServerState) : ServerInterceptor { } override fun onReady() { - println("Server: Call Listener onReady()") + log("Server: Call Listener onReady()") super.onReady() state.wasReady.complete() } override fun onHalfClose() { - println("Server: Call Listener onHalfClose()") + log("Server: Call Listener onHalfClose()") super.onHalfClose() state.halfClosed.complete() } override fun onComplete() { - println("Server: Call Listener onComplete()") + log("Server: Call Listener onComplete()") super.onComplete() state.completed.complete() } override fun onCancel() { - println("Server: Call Listener onCancel()") + log("Server: Call Listener onCancel()") super.onCancel() state.cancelled.complete() } @@ -223,10 +231,10 @@ class ServerStateInterceptor(val state: ServerState) : ServerInterceptor { } } -private fun Throwable.toDebugString(): String = +fun Throwable.toDebugString(): String = "(${this.javaClass.canonicalName}, ${this.message})" -private fun Status.toDebugString(): String = +fun Status.toDebugString(): String = "Status{code=$code, description=$description, cause=${cause?.toDebugString()}}" @@ -246,3 +254,13 @@ fun newCancellingInterceptor(useNormalCancellation: Boolean) = object : ClientIn } } +suspend fun suspendForever(target: String=""): Nothing = suspendCancellableCoroutine { + it.invokeOnCancellation { log("$target was cancelled") } +} + + +var CALL_TRACE_ENABLED = true +// Temporary log util +fun log(message: String){ + if(CALL_TRACE_ENABLED) println(message) +} diff --git a/protoc-gen-kroto-plus/src/main/kotlin/com/github/marcoferrer/krotoplus/generators/GrpcCoroutinesGenerator.kt b/protoc-gen-kroto-plus/src/main/kotlin/com/github/marcoferrer/krotoplus/generators/GrpcCoroutinesGenerator.kt index 1dfc4ed..46951da 100644 --- a/protoc-gen-kroto-plus/src/main/kotlin/com/github/marcoferrer/krotoplus/generators/GrpcCoroutinesGenerator.kt +++ b/protoc-gen-kroto-plus/src/main/kotlin/com/github/marcoferrer/krotoplus/generators/GrpcCoroutinesGenerator.kt @@ -17,8 +17,10 @@ package com.github.marcoferrer.krotoplus.generators import com.github.marcoferrer.krotoplus.generators.Generator.Companion.AutoGenerationDisclaimer +import com.github.marcoferrer.krotoplus.generators.builders.GrpcMethodHandlerBuilder import com.github.marcoferrer.krotoplus.generators.builders.GrpcServiceBaseImplBuilder import com.github.marcoferrer.krotoplus.generators.builders.GrpcStubBuilder +import com.github.marcoferrer.krotoplus.generators.builders.methodDefinitionPropName import com.github.marcoferrer.krotoplus.generators.builders.outerObjectName import com.github.marcoferrer.krotoplus.generators.builders.stubClassName import com.github.marcoferrer.krotoplus.proto.ProtoMessage @@ -48,7 +50,7 @@ object GrpcCoroutinesGenerator : Generator { get() = context.config.grpcCoroutinesList.isNotEmpty() private val stubBuilder = GrpcStubBuilder(context) - + private val methodHandlerBuilder = GrpcMethodHandlerBuilder(context) private val serviceBaseImplBuilder = GrpcServiceBaseImplBuilder(context) override fun invoke(): PluginProtos.CodeGeneratorResponse { @@ -111,6 +113,8 @@ object GrpcCoroutinesGenerator : Generator { ) .addType(stubBuilder.buildStub(this)) .addType(serviceBaseImplBuilder.build(this)) + .addProperties(methodHandlerBuilder.buildMethodIdConsts(this)) + .addType(methodHandlerBuilder.buildMethodHandlersTypeSpec(this)) .addProperty( PropertySpec.builder("SERVICE_NAME", String::class.asClassName()) .addModifiers(KModifier.CONST) @@ -120,9 +124,10 @@ object GrpcCoroutinesGenerator : Generator { .addProperties(buildMethodDefinitionProps()) .build() + private fun ProtoService.buildMethodDefinitionProps(): List = methodDefinitions.map { method -> - val propName = "${method.descriptorProto.name.toUpperCamelCase().decapitalize()}Method" + val propName = method.methodDefinitionPropName val propTypeName = CommonClassNames.grpcMethodDescriptor .parameterizedBy(method.requestClassName, method.responseClassName) val propGetter = FunSpec.getterBuilder() diff --git a/protoc-gen-kroto-plus/src/main/kotlin/com/github/marcoferrer/krotoplus/generators/builders/GrpcBuilderExts.kt b/protoc-gen-kroto-plus/src/main/kotlin/com/github/marcoferrer/krotoplus/generators/builders/GrpcBuilderExts.kt index ca22ed2..d1b8dd9 100644 --- a/protoc-gen-kroto-plus/src/main/kotlin/com/github/marcoferrer/krotoplus/generators/builders/GrpcBuilderExts.kt +++ b/protoc-gen-kroto-plus/src/main/kotlin/com/github/marcoferrer/krotoplus/generators/builders/GrpcBuilderExts.kt @@ -52,6 +52,9 @@ internal val ProtoService.stubClassName: ClassName internal val ProtoMethod.idPropertyName: String get() = "METHODID_${name.toUpperSnakeCase()}" +internal val ProtoMethod.methodDefinitionPropName: String + get() = "${descriptorProto.name.toUpperCamelCase().decapitalize()}Method" + internal fun FunSpec.Builder.addResponseObserverParameter(responseClassName: ClassName): FunSpec.Builder = apply { addParameter("responseObserver", CommonClassNames.streamObserver.parameterizedBy(responseClassName)) } diff --git a/protoc-gen-kroto-plus/src/main/kotlin/com/github/marcoferrer/krotoplus/generators/builders/GrpcMethodHandlerBuilder.kt b/protoc-gen-kroto-plus/src/main/kotlin/com/github/marcoferrer/krotoplus/generators/builders/GrpcMethodHandlerBuilder.kt index c65375c..2de7512 100644 --- a/protoc-gen-kroto-plus/src/main/kotlin/com/github/marcoferrer/krotoplus/generators/builders/GrpcMethodHandlerBuilder.kt +++ b/protoc-gen-kroto-plus/src/main/kotlin/com/github/marcoferrer/krotoplus/generators/builders/GrpcMethodHandlerBuilder.kt @@ -17,9 +17,13 @@ package com.github.marcoferrer.krotoplus.generators.builders import com.github.marcoferrer.krotoplus.generators.GeneratorContext +import com.github.marcoferrer.krotoplus.proto.ProtoMethod import com.github.marcoferrer.krotoplus.proto.ProtoService import com.github.marcoferrer.krotoplus.utils.CommonClassNames +import com.github.marcoferrer.krotoplus.utils.CommonPackages +import com.squareup.kotlinpoet.AnnotationSpec import com.squareup.kotlinpoet.ClassName +import com.squareup.kotlinpoet.CodeBlock import com.squareup.kotlinpoet.FunSpec import com.squareup.kotlinpoet.KModifier import com.squareup.kotlinpoet.ParameterSpec @@ -27,6 +31,7 @@ import com.squareup.kotlinpoet.ParameterizedTypeName.Companion.parameterizedBy import com.squareup.kotlinpoet.PropertySpec import com.squareup.kotlinpoet.TypeSpec import com.squareup.kotlinpoet.TypeVariableName +import com.squareup.kotlinpoet.UNIT import com.squareup.kotlinpoet.asClassName /** @@ -34,73 +39,232 @@ import com.squareup.kotlinpoet.asClassName */ class GrpcMethodHandlerBuilder(val context: GeneratorContext) { + private val reqTypeVarName = TypeVariableName("Req") + private val respTypeVarName = TypeVariableName("Resp") + private val suppressUncheckedAnnotation = AnnotationSpec.builder(Suppress::class.asClassName()) + .addMember("\"UNCHECKED_CAST\"") + .build() + + private val requestParameter = ParameterSpec.builder("request", reqTypeVarName).build() + + private val requestChannelParameter = ParameterSpec.builder( + "requestChannel", CommonClassNames.receiveChannel.parameterizedBy(reqTypeVarName)).build() + + private val responseChannelParameter = ParameterSpec.builder( + "responseChannel", CommonClassNames.sendChannel.parameterizedBy(respTypeVarName)).build() + + object CallHandlerClassNames { + val unary = ClassName(CommonPackages.krotoCoroutineLib+".server","UnaryMethod") + val serverStreaming = ClassName(CommonPackages.krotoCoroutineLib+".server","ServerStreamingMethod") + val clientStreaming = ClassName(CommonPackages.krotoCoroutineLib+".server","ClientStreamingMethod") + val bidiStreaming = ClassName(CommonPackages.krotoCoroutineLib+".server","BidiStreamingMethod") + } + fun buildMethodIdConsts(protoService: ProtoService): List = with(protoService) { - methodDefinitions.mapIndexed { index, method -> + methodDefinitions.map { method -> PropertySpec.builder( method.idPropertyName, Int::class.asClassName() ) .addModifiers(KModifier.PRIVATE, KModifier.CONST) - .initializer("%L", index) + .initializer("%L", method.index) .build() } } - fun buildMethodHandler(protoService: ProtoService): TypeSpec = with(protoService) { - val reqTypeVarName = TypeVariableName("Req") - val respTypeVarName = TypeVariableName("Resp") + private val ProtoService.serviceImplClassName: ClassName + get() = ClassName(protoFile.javaPackage, outerObjectName, baseImplName) + + private fun baseInvokeImplBuilder(): FunSpec.Builder = + FunSpec.builder("invoke") + .addAnnotation(suppressUncheckedAnnotation) + .addModifiers(KModifier.OVERRIDE, KModifier.SUSPEND, KModifier.OPERATOR) + + private fun buildUnaryInvokeImpl(methods: List): FunSpec = + baseInvokeImplBuilder() + .addParameter(requestParameter) + .returns(respTypeVarName) + .addCode(CodeBlock.builder() + .beginControlFlow("return when(methodId)") + .apply { + methods.forEach { method -> + addStatement( + "%N -> serviceImpl.%N(request as %T) as %T", + method.idPropertyName, + method.functionName, + method.requestClassName, + respTypeVarName + ) + addStatement("\n") + } + } + .addStatement("else -> throw %T()", AssertionError::class.asClassName()) + .endControlFlow() + .build()) + .build() + + private fun buildClientStreamingInvokeImpl(methods: List): FunSpec = + baseInvokeImplBuilder() + .addParameter(requestChannelParameter) + .returns(respTypeVarName) + .addCode(CodeBlock.builder() + .beginControlFlow("return when(methodId)") + .apply { + methods.forEach { method -> + addStatement( + "%N -> serviceImpl.%N(requestChannel as %T) as %T", + method.idPropertyName, + method.functionName, + CommonClassNames.receiveChannel.parameterizedBy(method.requestClassName), + respTypeVarName + ) + addStatement("\n") + } + } + .addStatement("else -> throw %T()", AssertionError::class.asClassName()) + .endControlFlow() + .build()) + .build() - val serviceImplClassName = ClassName(protoFile.javaPackage, outerObjectName, baseImplName) + private fun buildServerStreamingInvokeImpl(methods: List): FunSpec = + baseInvokeImplBuilder() + .addParameter(requestParameter) + .addParameter(responseChannelParameter) + .returns(UNIT) + .addCode(CodeBlock.builder() + .beginControlFlow("return when(methodId)") + .apply { + methods.forEach { method -> + add(CodeBlock.of( + """ + |%N -> serviceImpl.%N( + | request as %T, + | responseChannel as %T + |) + """.trimMargin(), + method.idPropertyName, + method.functionName, + method.requestClassName, + CommonClassNames.sendChannel.parameterizedBy(method.responseClassName) + )) + addStatement("\n") + } + } + .addStatement("else -> throw %T()", AssertionError::class.asClassName()) + .endControlFlow() + .build()) + .build() + + private fun buildBidiStreamingInvokeImpl(methods: List): FunSpec = + baseInvokeImplBuilder() + .addParameter(requestChannelParameter) + .addParameter(responseChannelParameter) + .returns(UNIT) + .addCode(CodeBlock.builder() + .beginControlFlow("return when(methodId)") + .apply { + methods.forEach { method -> + add(CodeBlock.of( + """ + |%N -> serviceImpl.%N( + | requestChannel as %T, + | responseChannel as %T + |) + """.trimMargin(), + method.idPropertyName, + method.functionName, + CommonClassNames.receiveChannel.parameterizedBy(method.requestClassName), + CommonClassNames.sendChannel.parameterizedBy(method.responseClassName) + )) + addStatement("\n") + } + } + .addStatement("else -> throw %T()", AssertionError::class.asClassName()) + .endControlFlow() + .build()) + .build() - TypeSpec.classBuilder("MessageHandler") + fun buildMethodHandlersTypeSpec(protoService: ProtoService): TypeSpec = with(protoService) { + + val unaryMethods = methodDefinitions.filter { it.isUnary } + val clientStreamingMethods = methodDefinitions.filter { it.isClientStream } + val serverStreamingMethods = methodDefinitions.filter { it.isServerStream } + val bidiStreamingMethods = methodDefinitions.filter { it.isBidi } + + TypeSpec.classBuilder("MethodHandlers") .addModifiers(KModifier.PRIVATE) .addTypeVariable(reqTypeVarName) .addTypeVariable(respTypeVarName) - .addSuperinterface( - CommonClassNames.GrpcServerCallHandler.unary - .parameterizedBy(reqTypeVarName, respTypeVarName) - ) - .addSuperinterface( - CommonClassNames.GrpcServerCallHandler.serverStreaming - .parameterizedBy(reqTypeVarName, respTypeVarName) - ) - .addSuperinterface( - CommonClassNames.GrpcServerCallHandler.clientStreaming - .parameterizedBy(reqTypeVarName, respTypeVarName) - ) - .addSuperinterface( - CommonClassNames.GrpcServerCallHandler.bidiStreaming - .parameterizedBy(reqTypeVarName, respTypeVarName) - ) - .primaryConstructor( - FunSpec.constructorBuilder() - .addParameter( - ParameterSpec.builder("serviceImpl", serviceImplClassName) - .addModifiers(KModifier.PRIVATE) - .addAnnotation(JvmField::class.java.asClassName()) - .build() + .apply { + if(unaryMethods.isNotEmpty()){ + addSuperinterface( + CallHandlerClassNames.unary + .parameterizedBy(reqTypeVarName, respTypeVarName) + ) + addFunction(buildUnaryInvokeImpl(unaryMethods)) + } + } + .apply { + if(clientStreamingMethods.isNotEmpty()){ + addSuperinterface( + CallHandlerClassNames.clientStreaming + .parameterizedBy(reqTypeVarName, respTypeVarName) ) - .addParameter( - ParameterSpec.builder("methodId", Int::class.asClassName()) - .addModifiers(KModifier.PRIVATE) - .addAnnotation(JvmField::class.java.asClassName()) - .build() + addFunction(buildClientStreamingInvokeImpl(clientStreamingMethods)) + } + } + .apply { + if(serverStreamingMethods.isNotEmpty()){ + addSuperinterface( + CallHandlerClassNames.serverStreaming + .parameterizedBy(reqTypeVarName, respTypeVarName) ) + addFunction(buildServerStreamingInvokeImpl(serverStreamingMethods)) + } + } + .apply { + if(bidiStreamingMethods.isNotEmpty()){ + addSuperinterface( + CallHandlerClassNames.bidiStreaming + .parameterizedBy(reqTypeVarName, respTypeVarName) + ) + addFunction(buildBidiStreamingInvokeImpl(bidiStreamingMethods)) + } + } + .addConstructorAndProps(protoService) + .build() + } + + private fun TypeSpec.Builder.addConstructorAndProps(protoService: ProtoService): TypeSpec.Builder { + + return primaryConstructor( + FunSpec.constructorBuilder() + .addParameter( + ParameterSpec.builder("serviceImpl", protoService.serviceImplClassName) + .build() + ) + .addParameter( + ParameterSpec.builder("methodId", Int::class.asClassName()) + .build() + ) + .build() + ) + .addProperty( + PropertySpec + .builder("serviceImpl", protoService.serviceImplClassName) + .addModifiers(KModifier.PRIVATE) + .addAnnotation(JvmField::class.java.asClassName()) + .initializer("serviceImpl") .build() ) - .addFunction( - FunSpec.builder("invoke") - .addParameter("request", reqTypeVarName) - .addParameter( - "responseObserver", - CommonClassNames.streamObserver.parameterizedBy(respTypeVarName) - ) - .apply { - // TODO: Build invoke impl for unary / server streaming - // TODO: Build invoke impl for client / bidi streaming - } + .addProperty( + PropertySpec + .builder("methodId", Int::class.asClassName()) + .addModifiers(KModifier.PRIVATE) + .addAnnotation(JvmField::class.java.asClassName()) + .initializer("methodId") .build() ) - .build() } + } \ No newline at end of file diff --git a/protoc-gen-kroto-plus/src/main/kotlin/com/github/marcoferrer/krotoplus/generators/builders/GrpcServiceBaseImplBuilder.kt b/protoc-gen-kroto-plus/src/main/kotlin/com/github/marcoferrer/krotoplus/generators/builders/GrpcServiceBaseImplBuilder.kt index d20ec40..d6c8928 100644 --- a/protoc-gen-kroto-plus/src/main/kotlin/com/github/marcoferrer/krotoplus/generators/builders/GrpcServiceBaseImplBuilder.kt +++ b/protoc-gen-kroto-plus/src/main/kotlin/com/github/marcoferrer/krotoplus/generators/builders/GrpcServiceBaseImplBuilder.kt @@ -21,8 +21,10 @@ import com.github.marcoferrer.krotoplus.generators.GeneratorContext import com.github.marcoferrer.krotoplus.proto.ProtoMethod import com.github.marcoferrer.krotoplus.proto.ProtoService import com.github.marcoferrer.krotoplus.utils.CommonClassNames +import com.github.marcoferrer.krotoplus.utils.CommonPackages import com.github.marcoferrer.krotoplus.utils.messageBuilderValueCodeBlock import com.squareup.kotlinpoet.AnnotationSpec +import com.squareup.kotlinpoet.ClassName import com.squareup.kotlinpoet.CodeBlock import com.squareup.kotlinpoet.FunSpec import com.squareup.kotlinpoet.KModifier @@ -33,32 +35,18 @@ import com.squareup.kotlinpoet.TypeSpec import com.squareup.kotlinpoet.UNIT import com.squareup.kotlinpoet.asClassName import io.grpc.MethodDescriptor +import io.grpc.ServerServiceDefinition class GrpcServiceBaseImplBuilder(val context: GeneratorContext){ fun build(protoService: ProtoService): TypeSpec = with(protoService){ - val delegateValName = "delegate" - TypeSpec.classBuilder(baseImplName) .addKdoc(attachedComments) .addModifiers(KModifier.ABSTRACT) .addSuperinterface(CommonClassNames.bindableService) .addSuperinterface(CommonClassNames.serviceScope) - .addProperty( - PropertySpec - .builder(delegateValName, serviceDelegateClassName) - .addModifiers(KModifier.PRIVATE) - .initializer("%T()", serviceDelegateClassName) - .build() - ) - .addFunction( - FunSpec.builder("bindService") - .addModifiers(KModifier.OVERRIDE) - .returns(CommonClassNames.grpcServerServiceDefinition) - .addCode("return %N.bindService()", delegateValName) - .build() - ) + .addFunction(buildBindServiceFunSpec(protoService)) .apply { for(method in methodDefinitions) when(method.type){ MethodDescriptor.MethodType.UNARY -> addFunction(buildUnaryBaseImpl(method)) @@ -69,7 +57,6 @@ class GrpcServiceBaseImplBuilder(val context: GeneratorContext){ } } .addFunctions(buildResponseLambdaOverloads()) - .addType(buildServiceBaseImplDelegate(protoService)) .build() } @@ -208,137 +195,47 @@ class GrpcServiceBaseImplBuilder(val context: GeneratorContext){ .build() } - // Service impl delegate - - private fun buildServiceBaseImplDelegate(protoService: ProtoService): TypeSpec = with(protoService) { - TypeSpec.classBuilder(serviceDelegateName) - .addModifiers(KModifier.PRIVATE, KModifier.INNER) - .superclass(serviceJavaBaseImplClassName) - .apply { - for(method in methodDefinitions) when(method.type){ - MethodDescriptor.MethodType.UNARY -> addFunction(buildUnaryDelegate(method)) - MethodDescriptor.MethodType.SERVER_STREAMING -> addFunction(buildServerStreamingDelegate(method)) - MethodDescriptor.MethodType.CLIENT_STREAMING -> addFunction(buildClientStreamingDelegate(method)) - MethodDescriptor.MethodType.BIDI_STREAMING -> addFunction(buildBidiStreamingDelegate(method)) - MethodDescriptor.MethodType.UNKNOWN -> throw IllegalStateException("Unknown method type") - } - } - .build() - } - - private fun buildUnaryDelegate(protoMethod: ProtoMethod): FunSpec = with(protoMethod){ - FunSpec.builder(functionName) + private fun buildBindServiceFunSpec(protoService: ProtoService): FunSpec = + FunSpec.builder("bindService") .addModifiers(KModifier.OVERRIDE) - .addParameter("request", requestClassName) - .addParameter( - name = "responseObserver", - type = CommonClassNames.streamObserver.parameterizedBy(responseClassName) - ) - .addCode( - CodeBlock.builder() - .addStatement( - "%T(%T.%N(),responseObserver) {", - CommonClassNames.ServerCalls.serverCallUnary, - protoService.enclosingServiceClassName, - methodDefinitionGetterName - ) - .indent() - .addStatement("%N(request)", functionName) - .unindent() - .addStatement("}") - .build() - ) + .returns(CommonClassNames.grpcServerServiceDefinition) + .addCode(addMethodsToServiceCodeBlock(protoService)) .build() - } - private fun buildServerStreamingDelegate(protoMethod: ProtoMethod): FunSpec = with(protoMethod){ - FunSpec.builder(functionName) - .addModifiers(KModifier.OVERRIDE) - .addParameter("request", requestClassName) - .addParameter( - name = "responseObserver", - type = CommonClassNames.streamObserver.parameterizedBy(responseClassName) - ) - .addCode( - CodeBlock.builder() - .addStatement( - "%T(%T.%N(),responseObserver) { responseChannel: %T ->", - CommonClassNames.ServerCalls.serverCallServerStreaming, - protoService.enclosingServiceClassName, - methodDefinitionGetterName, - CommonClassNames.sendChannel.parameterizedBy(responseClassName) - ) - .indent() - .addStatement("%N(request, responseChannel)", functionName) - .unindent() - .addStatement("}") - .build() - ) - .build() - } + private fun addMethodsToServiceCodeBlock(protoService: ProtoService): CodeBlock { + val builder = CodeBlock.builder() + .addStatement("val builder = %T", CommonClassNames.grpcServerServiceDefinition) + .indent() + .addStatement(".builder(%T.getServiceDescriptor())", protoService.enclosingServiceClassName) + + protoService.methodDefinitions.forEach { method -> + val handlerClassName = when(method.type){ + MethodDescriptor.MethodType.UNARY -> CallHandlerClassNames.unary + MethodDescriptor.MethodType.CLIENT_STREAMING -> CallHandlerClassNames.clientStreaming + MethodDescriptor.MethodType.SERVER_STREAMING -> CallHandlerClassNames.serverStreaming + MethodDescriptor.MethodType.BIDI_STREAMING -> CallHandlerClassNames.bidiStreaming + MethodDescriptor.MethodType.UNKNOWN -> throw IllegalStateException("Unknown method type") + } - private fun buildClientStreamingDelegate(protoMethod: ProtoMethod): FunSpec = with(protoMethod){ - FunSpec.builder(functionName) - .addModifiers(KModifier.OVERRIDE) - .returns(CommonClassNames.streamObserver.parameterizedBy(requestClassName)) - .addParameter( - name ="responseObserver", - type = CommonClassNames.streamObserver.parameterizedBy(responseClassName) - ) - .addCode( - CodeBlock.of( - """ - val requestObserver = %T( - %T.%N(), - responseObserver - ) { requestChannel: %T -> - - %N(requestChannel) - } - return requestObserver %L - """.trimIndent(), - CommonClassNames.ServerCalls.serverCallClientStreaming, - protoService.enclosingServiceClassName, - methodDefinitionGetterName, - CommonClassNames.receiveChannel.parameterizedBy(requestClassName), - functionName, - "\n" - ) - ) - .build() + builder + .addStatement(".addMethod(") + .indent() + .addStatement("${method.methodDefinitionPropName},") + .addStatement("%T(MethodHandlers(this, %N))", + handlerClassName, + method.idPropertyName) + .unindent() + .addStatement(")") + } + + return builder.unindent().addStatement("return builder.build()").build() } - private fun buildBidiStreamingDelegate(protoMethod: ProtoMethod): FunSpec = with(protoMethod){ - FunSpec.builder(functionName) - .addModifiers(KModifier.OVERRIDE) - .returns(CommonClassNames.streamObserver.parameterizedBy(requestClassName)) - .addParameter( - name = "responseObserver", - type = CommonClassNames.streamObserver.parameterizedBy(responseClassName) - ) - .addCode( - CodeBlock.of( - """ - val requestChannel = %T( - %T.%N(), - responseObserver - ) { requestChannel: %T, - responseChannel: %T -> - - %N(requestChannel, responseChannel) - } - return requestChannel %L - """.trimIndent(), - CommonClassNames.ServerCalls.serverCallBidiStreaming, - protoService.enclosingServiceClassName, - methodDefinitionGetterName, - CommonClassNames.receiveChannel.parameterizedBy(requestClassName), - CommonClassNames.sendChannel.parameterizedBy(responseClassName), - functionName, - "\n" - ) - ) - .build() + object CallHandlerClassNames { + val unary = ClassName(CommonPackages.krotoCoroutineLib+".server","unaryServerCallHandler") + val serverStreaming = ClassName(CommonPackages.krotoCoroutineLib+".server","serverStreamingServerCallHandler") + val clientStreaming = ClassName(CommonPackages.krotoCoroutineLib+".server","clientStreamingServerCallHandler") + val bidiStreaming = ClassName(CommonPackages.krotoCoroutineLib+".server","bidiStreamingServerCallHandler") } } \ No newline at end of file diff --git a/protoc-gen-kroto-plus/src/main/kotlin/com/github/marcoferrer/krotoplus/proto/ProtoMethod.kt b/protoc-gen-kroto-plus/src/main/kotlin/com/github/marcoferrer/krotoplus/proto/ProtoMethod.kt index db3c692..5a51480 100644 --- a/protoc-gen-kroto-plus/src/main/kotlin/com/github/marcoferrer/krotoplus/proto/ProtoMethod.kt +++ b/protoc-gen-kroto-plus/src/main/kotlin/com/github/marcoferrer/krotoplus/proto/ProtoMethod.kt @@ -25,6 +25,7 @@ import io.grpc.MethodDescriptor class ProtoMethod( override val descriptorProto: DescriptorProtos.MethodDescriptorProto, + val index: Int, val sourceLocation: Location, val protoService: ProtoService ) : Schema.DescriptorWrapper { diff --git a/protoc-gen-kroto-plus/src/main/kotlin/com/github/marcoferrer/krotoplus/proto/ProtoService.kt b/protoc-gen-kroto-plus/src/main/kotlin/com/github/marcoferrer/krotoplus/proto/ProtoService.kt index 60d8fcd..bd58b3e 100644 --- a/protoc-gen-kroto-plus/src/main/kotlin/com/github/marcoferrer/krotoplus/proto/ProtoService.kt +++ b/protoc-gen-kroto-plus/src/main/kotlin/com/github/marcoferrer/krotoplus/proto/ProtoService.kt @@ -44,7 +44,7 @@ data class ProtoService( descriptorProto.methodList.mapIndexed { methodIndex, methodDescriptor -> val sourceLocation = this@ProtoService.protoFile.descriptorProto.sourceCodeInfo.locationList .findByMethodIndex(methodIndex) - ProtoMethod(methodDescriptor, sourceLocation, this@ProtoService) + ProtoMethod(methodDescriptor, methodIndex, sourceLocation, this@ProtoService) } } diff --git a/protoc-gen-kroto-plus/src/main/kotlin/com/github/marcoferrer/krotoplus/utils/CommonNames.kt b/protoc-gen-kroto-plus/src/main/kotlin/com/github/marcoferrer/krotoplus/utils/CommonNames.kt index 7cd124a..a0e81ea 100644 --- a/protoc-gen-kroto-plus/src/main/kotlin/com/github/marcoferrer/krotoplus/utils/CommonNames.kt +++ b/protoc-gen-kroto-plus/src/main/kotlin/com/github/marcoferrer/krotoplus/utils/CommonNames.kt @@ -53,13 +53,6 @@ object CommonClassNames{ val grpcStubRpcMethod = ClassName("io.grpc.stub.annotations","RpcMethod") val grpcMethodDescriptor = ClassName("io.grpc","MethodDescriptor") - object GrpcServerCallHandler { - val unary = ClassName("io.grpc.stub","ServerCalls","UnaryMethod") - val serverStreaming = ClassName("io.grpc.stub","ServerCalls","ServerStreamingMethod") - val clientStreaming = ClassName("io.grpc.stub","ServerCalls","ClientStreamingMethod") - val bidiStreaming = ClassName("io.grpc.stub","ServerCalls","BidiStreamingMethod") - } - val streamObserver: ClassName = ClassName("io.grpc.stub", "StreamObserver") val experimentalKrotoPlusCoroutinesApi = ClassName(krotoCoroutineLib, "ExperimentalKrotoPlusCoroutinesApi")