Skip to content

Commit 149ff93

Browse files
Remove reference counting from InboundMessage and make it Releasable (#126138)
There is no actual need to reference-count InboundMessage instances. Their lifecycle is completely linear and we can simplify it away. This saves a little work directly but more importantly, it enables more eager releasing of the underlying buffers in a follow-up. --------- Co-authored-by: David Turner <[email protected]>
1 parent a72883e commit 149ff93

File tree

7 files changed

+85
-42
lines changed

7 files changed

+85
-42
lines changed

server/src/main/java/org/elasticsearch/transport/InboundAggregator.java

+2-2
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ public InboundMessage finishAggregation() throws IOException {
120120
checkBreaker(aggregated.getHeader(), aggregated.getContentLength(), breakerControl);
121121
}
122122
if (isShortCircuited()) {
123-
aggregated.decRef();
123+
aggregated.close();
124124
success = true;
125125
return new InboundMessage(aggregated.getHeader(), aggregationException);
126126
} else {
@@ -131,7 +131,7 @@ public InboundMessage finishAggregation() throws IOException {
131131
} finally {
132132
resetCurrentAggregation();
133133
if (success == false) {
134-
aggregated.decRef();
134+
aggregated.close();
135135
}
136136
}
137137
}

server/src/main/java/org/elasticsearch/transport/InboundHandler.java

+50-24
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
import org.elasticsearch.common.util.concurrent.AbstractRunnable;
2323
import org.elasticsearch.common.util.concurrent.EsExecutors;
2424
import org.elasticsearch.common.util.concurrent.ThreadContext;
25-
import org.elasticsearch.core.Releasable;
2625
import org.elasticsearch.core.Releasables;
2726
import org.elasticsearch.core.TimeValue;
2827
import org.elasticsearch.threadpool.ThreadPool;
@@ -87,21 +86,31 @@ void setSlowLogThreshold(TimeValue slowLogThreshold) {
8786
this.slowLogThresholdMs = slowLogThreshold.getMillis();
8887
}
8988

89+
/**
90+
* @param message the transport message received, guaranteed to be closed by this method if it returns without exception.
91+
* Callers must ensure that {@code message} is closed if this method throws an exception but must not release
92+
* the message themselves otherwise
93+
*/
9094
void inboundMessage(TcpChannel channel, InboundMessage message) throws Exception {
9195
final long startTime = threadPool.rawRelativeTimeInMillis();
9296
channel.getChannelStats().markAccessed(startTime);
9397
TransportLogger.logInboundMessage(channel, message);
9498

9599
if (message.isPing()) {
96-
keepAlive.receiveKeepAlive(channel);
100+
keepAlive.receiveKeepAlive(channel); // pings hold no resources, no need to close
97101
} else {
98-
messageReceived(channel, message, startTime);
102+
messageReceived(channel, /* autocloses absent exception */ message, startTime);
99103
}
100104
}
101105

102106
// Empty stream constant to avoid instantiating a new stream for empty messages.
103107
private static final StreamInput EMPTY_STREAM_INPUT = new ByteBufferStreamInput(ByteBuffer.wrap(BytesRef.EMPTY_BYTES));
104108

109+
/**
110+
* @param message the transport message received, guaranteed to be closed by this method if it returns without exception.
111+
* Callers must ensure that {@code message} is closed if this method throws an exception but must not release
112+
* the message themselves otherwise
113+
*/
105114
private void messageReceived(TcpChannel channel, InboundMessage message, long startTime) throws IOException {
106115
final InetSocketAddress remoteAddress = channel.getRemoteAddress();
107116
final Header header = message.getHeader();
@@ -115,14 +124,16 @@ private void messageReceived(TcpChannel channel, InboundMessage message, long st
115124
threadContext.setHeaders(header.getHeaders());
116125
threadContext.putTransient("_remote_address", remoteAddress);
117126
if (header.isRequest()) {
118-
handleRequest(channel, message);
127+
handleRequest(channel, /* autocloses absent exception */ message);
119128
} else {
120129
// Responses do not support short circuiting currently
121130
assert message.isShortCircuit() == false;
122131
responseHandler = findResponseHandler(header);
123132
// ignore if its null, the service logs it
124133
if (responseHandler != null) {
125-
executeResponseHandler(message, responseHandler, remoteAddress);
134+
executeResponseHandler( /* autocloses absent exception */ message, responseHandler, remoteAddress);
135+
} else {
136+
message.close();
126137
}
127138
}
128139
} finally {
@@ -135,6 +146,11 @@ private void messageReceived(TcpChannel channel, InboundMessage message, long st
135146
}
136147
}
137148

149+
/**
150+
* @param message the transport message received, guaranteed to be closed by this method if it returns without exception.
151+
* Callers must ensure that {@code message} is closed if this method throws an exception but must not release
152+
* the message themselves otherwise
153+
*/
138154
private void executeResponseHandler(
139155
InboundMessage message,
140156
TransportResponseHandler<?> responseHandler,
@@ -145,13 +161,13 @@ private void executeResponseHandler(
145161
final StreamInput streamInput = namedWriteableStream(message.openOrGetStreamInput());
146162
assert assertRemoteVersion(streamInput, header.getVersion());
147163
if (header.isError()) {
148-
handlerResponseError(streamInput, message, responseHandler);
164+
handlerResponseError(streamInput, /* autocloses */ message, responseHandler);
149165
} else {
150-
handleResponse(remoteAddress, streamInput, responseHandler, message);
166+
handleResponse(remoteAddress, streamInput, responseHandler, /* autocloses */ message);
151167
}
152168
} else {
153169
assert header.isError() == false;
154-
handleResponse(remoteAddress, EMPTY_STREAM_INPUT, responseHandler, message);
170+
handleResponse(remoteAddress, EMPTY_STREAM_INPUT, responseHandler, /* autocloses */ message);
155171
}
156172
}
157173

@@ -220,10 +236,15 @@ private void verifyResponseReadFully(Header header, TransportResponseHandler<?>
220236
}
221237
}
222238

239+
/**
240+
* @param message the transport message received, guaranteed to be closed by this method if it returns without exception.
241+
* Callers must ensure that {@code message} is closed if this method throws an exception but must not release
242+
* the message themselves otherwise
243+
*/
223244
private <T extends TransportRequest> void handleRequest(TcpChannel channel, InboundMessage message) throws IOException {
224245
final Header header = message.getHeader();
225246
if (header.isHandshake()) {
226-
handleHandshakeRequest(channel, message);
247+
handleHandshakeRequest(channel, /* autocloses */ message);
227248
return;
228249
}
229250

@@ -243,7 +264,7 @@ private <T extends TransportRequest> void handleRequest(TcpChannel channel, Inbo
243264
Releasables.assertOnce(message.takeBreakerReleaseControl())
244265
);
245266

246-
try {
267+
try (message) {
247268
messageListener.onRequestReceived(requestId, action);
248269
if (reg != null) {
249270
reg.addRequestStats(header.getNetworkMessageSize() + TcpHeader.BYTES_REQUIRED_FOR_MESSAGE_SIZE);
@@ -331,6 +352,9 @@ public void onAfter() {
331352
}
332353
}
333354

355+
/**
356+
* @param message guaranteed to get closed by this method
357+
*/
334358
private void handleHandshakeRequest(TcpChannel channel, InboundMessage message) throws IOException {
335359
var header = message.getHeader();
336360
assert header.actionName.equals(TransportHandshaker.HANDSHAKE_ACTION_NAME);
@@ -351,7 +375,7 @@ private void handleHandshakeRequest(TcpChannel channel, InboundMessage message)
351375
true,
352376
Releasables.assertOnce(message.takeBreakerReleaseControl())
353377
);
354-
try {
378+
try (message) {
355379
handshaker.handleHandshake(transportChannel, requestId, stream);
356380
} catch (Exception e) {
357381
logger.warn(
@@ -371,29 +395,30 @@ private static void sendErrorResponse(String actionName, TransportChannel transp
371395
}
372396
}
373397

398+
/**
399+
* @param message guaranteed to get closed by this method
400+
*/
374401
private <T extends TransportResponse> void handleResponse(
375402
InetSocketAddress remoteAddress,
376403
final StreamInput stream,
377404
final TransportResponseHandler<T> handler,
378-
final InboundMessage inboundMessage
405+
final InboundMessage message
379406
) {
380407
final var executor = handler.executor();
381408
if (executor == EsExecutors.DIRECT_EXECUTOR_SERVICE) {
382409
// no need to provide a buffer release here, we never escape the buffer when handling directly
383-
doHandleResponse(handler, remoteAddress, stream, inboundMessage.getHeader(), () -> {});
410+
doHandleResponse(handler, remoteAddress, stream, /* autocloses */ message);
384411
} else {
385-
inboundMessage.mustIncRef();
386412
// release buffer once we deserialize the message, but have a fail-safe in #onAfter below in case that didn't work out
387-
final Releasable releaseBuffer = Releasables.releaseOnce(inboundMessage::decRef);
388413
executor.execute(new ForkingResponseHandlerRunnable(handler, null) {
389414
@Override
390415
protected void doRun() {
391-
doHandleResponse(handler, remoteAddress, stream, inboundMessage.getHeader(), releaseBuffer);
416+
doHandleResponse(handler, remoteAddress, stream, /* autocloses */ message);
392417
}
393418

394419
@Override
395420
public void onAfter() {
396-
Releasables.closeExpectNoException(releaseBuffer);
421+
message.close();
397422
}
398423
});
399424
}
@@ -404,20 +429,19 @@ public void onAfter() {
404429
* @param handler response handler
405430
* @param remoteAddress remote address that the message was sent from
406431
* @param stream bytes stream for reading the message
407-
* @param header message header
408-
* @param releaseResponseBuffer releasable that will be released once the message has been read from the {@code stream}
432+
* @param inboundMessage inbound message, guaranteed to get closed by this method
409433
* @param <T> response message type
410434
*/
411435
private <T extends TransportResponse> void doHandleResponse(
412436
TransportResponseHandler<T> handler,
413437
InetSocketAddress remoteAddress,
414438
final StreamInput stream,
415-
final Header header,
416-
Releasable releaseResponseBuffer
439+
InboundMessage inboundMessage
417440
) {
418441
final T response;
419-
try (releaseResponseBuffer) {
442+
try (inboundMessage) {
420443
response = handler.read(stream);
444+
verifyResponseReadFully(inboundMessage.getHeader(), handler, stream);
421445
} catch (Exception e) {
422446
final TransportException serializationException = new TransportSerializationException(
423447
"Failed to deserialize response from handler [" + handler + "]",
@@ -429,7 +453,6 @@ private <T extends TransportResponse> void doHandleResponse(
429453
return;
430454
}
431455
try {
432-
verifyResponseReadFully(header, handler, stream);
433456
handler.handleResponse(response);
434457
} catch (Exception e) {
435458
doHandleException(handler, new ResponseHandlerFailureTransportException(e));
@@ -438,9 +461,12 @@ private <T extends TransportResponse> void doHandleResponse(
438461
}
439462
}
440463

464+
/**
465+
* @param message guaranteed to get closed by this method
466+
*/
441467
private void handlerResponseError(StreamInput stream, InboundMessage message, final TransportResponseHandler<?> handler) {
442468
Exception error;
443-
try {
469+
try (message) {
444470
error = stream.readException();
445471
verifyResponseReadFully(message.getHeader(), handler, stream);
446472
} catch (Exception e) {

server/src/main/java/org/elasticsearch/transport/InboundMessage.java

+21-4
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,15 @@
1212
import org.elasticsearch.ElasticsearchException;
1313
import org.elasticsearch.common.bytes.ReleasableBytesReference;
1414
import org.elasticsearch.common.io.stream.StreamInput;
15-
import org.elasticsearch.core.AbstractRefCounted;
1615
import org.elasticsearch.core.IOUtils;
1716
import org.elasticsearch.core.Releasable;
1817

1918
import java.io.IOException;
19+
import java.lang.invoke.MethodHandles;
20+
import java.lang.invoke.VarHandle;
2021
import java.util.Objects;
2122

22-
public class InboundMessage extends AbstractRefCounted {
23+
public class InboundMessage implements Releasable {
2324

2425
private final Header header;
2526
private final ReleasableBytesReference content;
@@ -28,6 +29,19 @@ public class InboundMessage extends AbstractRefCounted {
2829
private Releasable breakerRelease;
2930
private StreamInput streamInput;
3031

32+
@SuppressWarnings("unused") // updated via CLOSED (and _only_ via CLOSED)
33+
private boolean closed;
34+
35+
private static final VarHandle CLOSED;
36+
37+
static {
38+
try {
39+
CLOSED = MethodHandles.lookup().findVarHandle(InboundMessage.class, "closed", boolean.class);
40+
} catch (Exception e) {
41+
throw new ExceptionInInitializerError(e);
42+
}
43+
}
44+
3145
public InboundMessage(Header header, ReleasableBytesReference content, Releasable breakerRelease) {
3246
this.header = header;
3347
this.content = content;
@@ -84,7 +98,7 @@ public Releasable takeBreakerReleaseControl() {
8498

8599
public StreamInput openOrGetStreamInput() throws IOException {
86100
assert isPing == false && content != null;
87-
assert hasReferences();
101+
assert (boolean) CLOSED.getAcquire(this) == false;
88102
if (streamInput == null) {
89103
streamInput = content.streamInput();
90104
streamInput.setTransportVersion(header.getVersion());
@@ -98,7 +112,10 @@ public String toString() {
98112
}
99113

100114
@Override
101-
protected void closeInternal() {
115+
public void close() {
116+
if (CLOSED.compareAndSet(this, false, true) == false) {
117+
return;
118+
}
102119
try {
103120
IOUtils.close(streamInput, content, breakerRelease);
104121
} catch (Exception e) {

server/src/main/java/org/elasticsearch/transport/InboundPipeline.java

+2-7
Original file line numberDiff line numberDiff line change
@@ -112,13 +112,8 @@ private void forwardFragment(TcpChannel channel, Object fragment) throws IOExcep
112112
messageHandler.accept(channel, PING_MESSAGE);
113113
} else if (fragment == InboundDecoder.END_CONTENT) {
114114
assert aggregator.isAggregating();
115-
InboundMessage aggregated = aggregator.finishAggregation();
116-
try {
117-
statsTracker.markMessageReceived();
118-
messageHandler.accept(channel, aggregated);
119-
} finally {
120-
aggregated.decRef();
121-
}
115+
statsTracker.markMessageReceived();
116+
messageHandler.accept(channel, /* autocloses */ aggregator.finishAggregation());
122117
} else {
123118
assert aggregator.isAggregating();
124119
assert fragment instanceof ReleasableBytesReference;

server/src/main/java/org/elasticsearch/transport/TcpTransport.java

+6-1
Original file line numberDiff line numberDiff line change
@@ -813,9 +813,14 @@ protected void serverAcceptedChannel(TcpChannel channel) {
813813
*/
814814
public void inboundMessage(TcpChannel channel, InboundMessage message) {
815815
try {
816-
inboundHandler.inboundMessage(channel, message);
816+
inboundHandler.inboundMessage(channel, /* autocloses absent exception */ message);
817+
message = null;
817818
} catch (Exception e) {
818819
onException(channel, e);
820+
} finally {
821+
if (message != null) {
822+
message.close();
823+
}
819824
}
820825
}
821826

server/src/test/java/org/elasticsearch/transport/InboundAggregatorTests.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ public void testInboundAggregation() throws IOException {
9595
for (ReleasableBytesReference reference : references) {
9696
assertTrue(reference.hasReferences());
9797
}
98-
aggregated.decRef();
98+
aggregated.close();
9999
for (ReleasableBytesReference reference : references) {
100100
assertFalse(reference.hasReferences());
101101
}

server/src/test/java/org/elasticsearch/transport/InboundPipelineTests.java

+3-3
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ public void testPipelineHandling() throws IOException {
5050
final List<Tuple<MessageData, Exception>> actual = new ArrayList<>();
5151
final List<ReleasableBytesReference> toRelease = new ArrayList<>();
5252
final BiConsumer<TcpChannel, InboundMessage> messageHandler = (c, m) -> {
53-
try {
53+
try (m) {
5454
final Header header = m.getHeader();
5555
final MessageData actualData;
5656
final TransportVersion version = header.getVersion();
@@ -204,7 +204,7 @@ private static Compression.Scheme getCompressionScheme() {
204204
}
205205

206206
public void testDecodeExceptionIsPropagated() throws IOException {
207-
BiConsumer<TcpChannel, InboundMessage> messageHandler = (c, m) -> {};
207+
BiConsumer<TcpChannel, InboundMessage> messageHandler = (c, m) -> m.close();
208208
final StatsTracker statsTracker = new StatsTracker();
209209
final LongSupplier millisSupplier = () -> TimeValue.nsecToMSec(System.nanoTime());
210210
final InboundDecoder decoder = new InboundDecoder(recycler);
@@ -245,7 +245,7 @@ public void testDecodeExceptionIsPropagated() throws IOException {
245245
}
246246

247247
public void testEnsureBodyIsNotPrematurelyReleased() throws IOException {
248-
BiConsumer<TcpChannel, InboundMessage> messageHandler = (c, m) -> {};
248+
BiConsumer<TcpChannel, InboundMessage> messageHandler = (c, m) -> m.close();
249249
final StatsTracker statsTracker = new StatsTracker();
250250
final LongSupplier millisSupplier = () -> TimeValue.nsecToMSec(System.nanoTime());
251251
final InboundDecoder decoder = new InboundDecoder(recycler);

0 commit comments

Comments
 (0)