diff --git a/spring-cloud-gateway-integration-tests/grpc/src/main/java/org/springframework/cloud/gateway/tests/grpc/GRPCApplication.java b/spring-cloud-gateway-integration-tests/grpc/src/main/java/org/springframework/cloud/gateway/tests/grpc/GRPCApplication.java
index 241a851fd3..44362c3055 100644
--- a/spring-cloud-gateway-integration-tests/grpc/src/main/java/org/springframework/cloud/gateway/tests/grpc/GRPCApplication.java
+++ b/spring-cloud-gateway-integration-tests/grpc/src/main/java/org/springframework/cloud/gateway/tests/grpc/GRPCApplication.java
@@ -119,6 +119,14 @@ public void hello(HelloRequest request, StreamObserver responseOb
HelloResponse response = HelloResponse.newBuilder().setGreeting(greeting).build();
responseObserver.onNext(response);
+
+ if ("failWithRuntimeExceptionAfterData!".equals(request.getFirstName())) {
+ StatusRuntimeException exception = Status.RESOURCE_EXHAUSTED.withDescription("Too long firstNames?")
+ .asRuntimeException();
+ responseObserver.onError(exception);
+ return;
+ }
+
responseObserver.onCompleted();
}
diff --git a/spring-cloud-gateway-integration-tests/grpc/src/test/java/org/springframework/cloud/gateway/tests/grpc/GRPCApplicationTests.java b/spring-cloud-gateway-integration-tests/grpc/src/test/java/org/springframework/cloud/gateway/tests/grpc/GRPCApplicationTests.java
index e5a6bbcb43..5d388918c6 100644
--- a/spring-cloud-gateway-integration-tests/grpc/src/test/java/org/springframework/cloud/gateway/tests/grpc/GRPCApplicationTests.java
+++ b/spring-cloud-gateway-integration-tests/grpc/src/test/java/org/springframework/cloud/gateway/tests/grpc/GRPCApplicationTests.java
@@ -34,6 +34,7 @@
import org.springframework.boot.test.web.server.LocalServerPort;
import static io.grpc.Status.FAILED_PRECONDITION;
+import static io.grpc.Status.RESOURCE_EXHAUSTED;
import static io.grpc.netty.NegotiationType.TLS;
import static org.springframework.boot.test.context.SpringBootTest.WebEnvironment;
@@ -75,15 +76,35 @@ private ManagedChannel createSecuredChannel(int port) throws SSLException {
@Test
public void gRPCUnaryCallShouldHandleRuntimeException() throws SSLException {
ManagedChannel channel = createSecuredChannel(gatewayPort);
+ boolean thrown = false;
try {
HelloServiceGrpc.newBlockingStub(channel)
.hello(HelloRequest.newBuilder().setFirstName("failWithRuntimeException!").build());
}
catch (StatusRuntimeException e) {
- Assertions.assertThat(FAILED_PRECONDITION.getCode()).isEqualTo(e.getStatus().getCode());
- Assertions.assertThat("Invalid firstName").isEqualTo(e.getStatus().getDescription());
+ thrown = true;
+ Assertions.assertThat(e.getStatus().getCode()).isEqualTo(FAILED_PRECONDITION.getCode());
+ Assertions.assertThat(e.getStatus().getDescription()).isEqualTo("Invalid firstName");
}
+ Assertions.assertThat(thrown).withFailMessage("Expected exception not thrown!").isTrue();
+ }
+
+ @Test
+ public void gRPCUnaryCallShouldHandleRuntimeException2() throws SSLException {
+ ManagedChannel channel = createSecuredChannel(gatewayPort);
+ boolean thrown = false;
+ try {
+ HelloServiceGrpc.newBlockingStub(channel)
+ .hello(HelloRequest.newBuilder().setFirstName("failWithRuntimeExceptionAfterData!").build())
+ .getGreeting();
+ }
+ catch (StatusRuntimeException e) {
+ thrown = true;
+ Assertions.assertThat(e.getStatus().getCode()).isEqualTo(RESOURCE_EXHAUSTED.getCode());
+ Assertions.assertThat(e.getStatus().getDescription()).isEqualTo("Too long firstNames?");
+ }
+ Assertions.assertThat(thrown).withFailMessage("Expected exception not thrown!").isTrue();
}
private TrustManager[] createTrustAllTrustManager() {
diff --git a/spring-cloud-gateway-server/src/main/java/org/springframework/cloud/gateway/filter/headers/GRPCResponseHeadersFilter.java b/spring-cloud-gateway-server/src/main/java/org/springframework/cloud/gateway/filter/headers/GRPCResponseHeadersFilter.java
index 25a45d29ae..0899907076 100644
--- a/spring-cloud-gateway-server/src/main/java/org/springframework/cloud/gateway/filter/headers/GRPCResponseHeadersFilter.java
+++ b/spring-cloud-gateway-server/src/main/java/org/springframework/cloud/gateway/filter/headers/GRPCResponseHeadersFilter.java
@@ -16,6 +16,7 @@
package org.springframework.cloud.gateway.filter.headers;
+import reactor.netty.http.client.HttpClientResponse;
import reactor.netty.http.server.HttpServerResponse;
import org.springframework.core.Ordered;
@@ -26,6 +27,8 @@
import org.springframework.util.StringUtils;
import org.springframework.web.server.ServerWebExchange;
+import static org.springframework.cloud.gateway.support.ServerWebExchangeUtils.CLIENT_RESPONSE_ATTR;
+
/**
* @author Alberto C. RĂos
*/
@@ -37,45 +40,62 @@ public class GRPCResponseHeadersFilter implements HttpHeadersFilter, Ordered {
@Override
public HttpHeaders filter(HttpHeaders headers, ServerWebExchange exchange) {
- ServerHttpResponse response = exchange.getResponse();
- HttpHeaders responseHeaders = response.getHeaders();
if (isGRPC(exchange)) {
- String trailerHeaderValue = GRPC_STATUS_HEADER + "," + GRPC_MESSAGE_HEADER;
- String originalTrailerHeaderValue = responseHeaders.getFirst(HttpHeaders.TRAILER);
- if (originalTrailerHeaderValue != null) {
- trailerHeaderValue += "," + originalTrailerHeaderValue;
- }
- responseHeaders.set(HttpHeaders.TRAILER, trailerHeaderValue);
+ ServerHttpResponse response = exchange.getResponse();
+ HttpHeaders responseHeaders = response.getHeaders();
- while (response instanceof ServerHttpResponseDecorator) {
- response = ((ServerHttpResponseDecorator) response).getDelegate();
+ if (headers.containsKey(GRPC_STATUS_HEADER)) {
+ if (!"0".equals(headers.getFirst(GRPC_STATUS_HEADER))) {
+ response.setComplete(); // avoid empty DATA frame
+ }
}
- if (response instanceof AbstractServerHttpResponse) {
- String grpcStatus = getGrpcStatus(headers);
- String grpcMessage = getGrpcMessage(headers);
- ((HttpServerResponse) ((AbstractServerHttpResponse) response).getNativeResponse()).trailerHeaders(h -> {
- h.set(GRPC_STATUS_HEADER, grpcStatus);
- h.set(GRPC_MESSAGE_HEADER, grpcMessage);
+
+ HttpClientResponse nettyInResponse = exchange.getAttribute(CLIENT_RESPONSE_ATTR);
+ if (nettyInResponse != null) {
+ nettyInResponse.trailerHeaders().subscribe(entries -> {
+ if (entries.contains(GRPC_STATUS_HEADER)) {
+ addTrailingHeader(entries, response, responseHeaders);
+ }
});
}
-
}
+
return headers;
}
- private boolean isGRPC(ServerWebExchange exchange) {
- String contentTypeValue = exchange.getRequest().getHeaders().getFirst(HttpHeaders.CONTENT_TYPE);
- return StringUtils.startsWithIgnoreCase(contentTypeValue, "application/grpc");
+ private void addTrailingHeader(io.netty.handler.codec.http.HttpHeaders sourceHeaders, ServerHttpResponse response,
+ HttpHeaders responseHeaders) {
+ String trailerHeaderValue = GRPC_STATUS_HEADER + "," + GRPC_MESSAGE_HEADER;
+ String originalTrailerHeaderValue = responseHeaders.getFirst(HttpHeaders.TRAILER);
+ if (originalTrailerHeaderValue != null) {
+ trailerHeaderValue += "," + originalTrailerHeaderValue;
+ }
+ responseHeaders.set(HttpHeaders.TRAILER, trailerHeaderValue);
+
+ HttpServerResponse nettyOutResponse = getNettyResponse(response);
+ if (nettyOutResponse != null) {
+ String grpcStatus = sourceHeaders.get(GRPC_STATUS_HEADER, "0");
+ String grpcMessage = sourceHeaders.get(GRPC_MESSAGE_HEADER, "");
+ nettyOutResponse.trailerHeaders(h -> {
+ h.set(GRPC_STATUS_HEADER, grpcStatus);
+ h.set(GRPC_MESSAGE_HEADER, grpcMessage);
+ });
+ }
}
- private String getGrpcStatus(HttpHeaders headers) {
- final String grpcStatusValue = headers.getFirst(GRPC_STATUS_HEADER);
- return StringUtils.hasText(grpcStatusValue) ? grpcStatusValue : "0";
+ private HttpServerResponse getNettyResponse(ServerHttpResponse response) {
+ while (response instanceof ServerHttpResponseDecorator) {
+ response = ((ServerHttpResponseDecorator) response).getDelegate();
+ }
+ if (response instanceof AbstractServerHttpResponse) {
+ return ((AbstractServerHttpResponse) response).getNativeResponse();
+ }
+ return null;
}
- private String getGrpcMessage(HttpHeaders headers) {
- final String grpcStatusValue = headers.getFirst(GRPC_MESSAGE_HEADER);
- return StringUtils.hasText(grpcStatusValue) ? grpcStatusValue : "";
+ private boolean isGRPC(ServerWebExchange exchange) {
+ String contentTypeValue = exchange.getRequest().getHeaders().getFirst(HttpHeaders.CONTENT_TYPE);
+ return StringUtils.startsWithIgnoreCase(contentTypeValue, "application/grpc");
}
@Override