4747
4848import static org .apache .ratis .grpc .GrpcUtil .addMethodWithCustomMarshaller ;
4949import static org .apache .ratis .proto .grpc .RaftServerProtocolServiceGrpc .getAppendEntriesMethod ;
50+ import static org .apache .ratis .proto .grpc .RaftServerProtocolServiceGrpc .getInstallSnapshotMethod ;
5051
5152class GrpcServerProtocolService extends RaftServerProtocolServiceImplBase {
5253 public static final Logger LOG = LoggerFactory .getLogger (GrpcServerProtocolService .class );
@@ -59,10 +60,12 @@ private enum BatchLogKey implements BatchLogger.Key {
5960 static class PendingServerRequest <REQUEST > {
6061 private final AtomicReference <ReferenceCountedObject <REQUEST >> requestRef ;
6162 private final CompletableFuture <Void > future = new CompletableFuture <>();
63+ private final String requestString ;
6264
63- PendingServerRequest (ReferenceCountedObject <REQUEST > requestRef ) {
65+ PendingServerRequest (ReferenceCountedObject <REQUEST > requestRef , String requestString ) {
6466 requestRef .retain ();
6567 this .requestRef = new AtomicReference <>(requestRef );
68+ this .requestString = requestString ;
6669 }
6770
6871 REQUEST getRequest () {
@@ -71,6 +74,10 @@ REQUEST getRequest() {
7174 .orElse (null );
7275 }
7376
77+ String getRequestString () {
78+ return requestString ;
79+ }
80+
7481 CompletableFuture <Void > getFuture () {
7582 return future ;
7683 }
@@ -104,8 +111,7 @@ String getName() {
104111
105112 private String getPreviousRequestString () {
106113 return Optional .ofNullable (previousOnNext .get ())
107- .map (PendingServerRequest ::getRequest )
108- .map (this ::requestToString )
114+ .map (PendingServerRequest ::getRequestString )
109115 .orElse (null );
110116 }
111117
@@ -177,7 +183,9 @@ public void onNext(REQUEST request) {
177183 return ;
178184 }
179185
180- final PendingServerRequest <REQUEST > current = new PendingServerRequest <>(requestRef );
186+ final PendingServerRequest <REQUEST > current
187+ = new PendingServerRequest <>(requestRef , requestToString (requestRef .get ()));
188+ current .getFuture ().whenComplete ((r , e ) -> current .release ());
181189 final long callId = getCallId (current .getRequest ());
182190 final boolean isHeartbeat = isHeartbeat (current .getRequest ());
183191 final Optional <PendingServerRequest <REQUEST >> previous = Optional .ofNullable (previousOnNext .getAndSet (current ));
@@ -243,15 +251,23 @@ private void releaseLast() {
243251 private final RaftServer server ;
244252 private final boolean zeroCopyEnabled ;
245253 private final ZeroCopyMessageMarshaller <AppendEntriesRequestProto > zeroCopyRequestMarshaller ;
254+ private final ZeroCopyMessageMarshaller <InstallSnapshotRequestProto > zeroCopyInstallSnapshotMarshaller ;
246255
247256 GrpcServerProtocolService (Supplier <RaftPeerId > idSupplier , RaftServer server , boolean zeroCopyEnabled ,
248257 ZeroCopyMetrics zeroCopyMetrics ) {
249258 this .idSupplier = idSupplier ;
250259 this .server = server ;
251260 this .zeroCopyEnabled = zeroCopyEnabled ;
252261 this .zeroCopyRequestMarshaller = new ZeroCopyMessageMarshaller <>(AppendEntriesRequestProto .getDefaultInstance (),
253- zeroCopyMetrics ::onZeroCopyMessage , zeroCopyMetrics ::onNonZeroCopyMessage , zeroCopyMetrics ::onReleasedMessage );
262+ zeroCopyMetrics ::onZeroCopyAppendEntries , zeroCopyMetrics ::onNonZeroCopyMessage ,
263+ zeroCopyMetrics ::onReleasedMessage , zeroCopyMetrics .newMarshallerMetrics ());
264+ this .zeroCopyInstallSnapshotMarshaller = new ZeroCopyMessageMarshaller <>(
265+ InstallSnapshotRequestProto .getDefaultInstance (),
266+ zeroCopyMetrics ::onZeroCopyInstallSnapshot , zeroCopyMetrics ::onNonZeroCopyMessage ,
267+ zeroCopyMetrics ::onReleasedMessage , zeroCopyMetrics .newMarshallerMetrics ());
254268 zeroCopyMetrics .addUnreleased ("server_protocol" , zeroCopyRequestMarshaller ::getUnclosedCount );
269+ zeroCopyMetrics .addUnreleased ("server_protocol_install_snapshot" ,
270+ zeroCopyInstallSnapshotMarshaller ::getUnclosedCount );
255271 }
256272
257273 RaftPeerId getId () {
@@ -268,9 +284,16 @@ ServerServiceDefinition bindServiceWithZeroCopy() {
268284
269285 // Add appendEntries with zero copy marshaller.
270286 addMethodWithCustomMarshaller (orig , builder , getAppendEntriesMethod (), zeroCopyRequestMarshaller );
287+ // Add installSnapshot with zero copy marshaller for zero-copy counters/metrics.
288+ addMethodWithCustomMarshaller (orig , builder , getInstallSnapshotMethod (), zeroCopyInstallSnapshotMarshaller );
271289 // Add remaining methods as is.
290+ final String appendEntriesMethod = getAppendEntriesMethod ().getFullMethodName ();
291+ final String installSnapshotMethod = getInstallSnapshotMethod ().getFullMethodName ();
272292 orig .getMethods ().stream ().filter (
273- x -> !x .getMethodDescriptor ().getFullMethodName ().equals (getAppendEntriesMethod ().getFullMethodName ())
293+ x -> {
294+ final String methodName = x .getMethodDescriptor ().getFullMethodName ();
295+ return !methodName .equals (appendEntriesMethod ) && !methodName .equals (installSnapshotMethod );
296+ }
274297 ).forEach (
275298 builder ::addMethod
276299 );
@@ -365,6 +388,11 @@ CompletableFuture<InstallSnapshotReplyProto> process(InstallSnapshotRequestProto
365388 return CompletableFuture .completedFuture (server .installSnapshot (request ));
366389 }
367390
391+ @ Override
392+ void release (InstallSnapshotRequestProto request ) {
393+ zeroCopyInstallSnapshotMarshaller .release (request );
394+ }
395+
368396 @ Override
369397 long getCallId (InstallSnapshotRequestProto request ) {
370398 return request .getServerRequest ().getCallId ();
0 commit comments