diff --git a/server/src/main/java/org/elasticsearch/common/bytes/ReleasableBytesReference.java b/server/src/main/java/org/elasticsearch/common/bytes/ReleasableBytesReference.java index 9cb53e728a8bd..ce0a490e3a6a1 100644 --- a/server/src/main/java/org/elasticsearch/common/bytes/ReleasableBytesReference.java +++ b/server/src/main/java/org/elasticsearch/common/bytes/ReleasableBytesReference.java @@ -78,6 +78,20 @@ public ReleasableBytesReference retain() { return this; } + public ReleasableBytesReference releasableSlice(int from) { + if (from == 0) { + return this; + } + return new ReleasableBytesReference(delegate.slice(from, length() - from), refCounted); + } + + public ReleasableBytesReference releasableSlice(int from, int length) { + if (from == 0 && length() == length) { + return this; + } + return new ReleasableBytesReference(delegate.slice(from, length), refCounted); + } + public ReleasableBytesReference retainedSlice(int from, int length) { if (from == 0 && length() == length) { return retain(); diff --git a/server/src/main/java/org/elasticsearch/transport/InboundDecoder.java b/server/src/main/java/org/elasticsearch/transport/InboundDecoder.java index 399614175410d..209738dda662f 100644 --- a/server/src/main/java/org/elasticsearch/transport/InboundDecoder.java +++ b/server/src/main/java/org/elasticsearch/transport/InboundDecoder.java @@ -10,6 +10,7 @@ import org.apache.lucene.util.BytesRef; import org.elasticsearch.Version; +import org.elasticsearch.common.CheckedBiConsumer; import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.bytes.ReleasableBytesReference; import org.elasticsearch.common.io.stream.StreamInput; @@ -18,7 +19,6 @@ import org.elasticsearch.core.Releasables; import java.io.IOException; -import java.util.function.Consumer; public class InboundDecoder implements Releasable { @@ -38,23 +38,31 @@ public InboundDecoder(Version version, Recycler recycler) { this.recycler = recycler; } - public int decode(ReleasableBytesReference reference, Consumer fragmentConsumer) throws IOException { + public int decode( + TcpChannel channel, + ReleasableBytesReference reference, + CheckedBiConsumer fragmentConsumer + ) throws IOException { ensureOpen(); try { - return internalDecode(reference, fragmentConsumer); + return internalDecode(channel, reference, fragmentConsumer); } catch (Exception e) { cleanDecodeState(); throw e; } } - public int internalDecode(ReleasableBytesReference reference, Consumer fragmentConsumer) throws IOException { + private int internalDecode( + TcpChannel channel, + ReleasableBytesReference reference, + CheckedBiConsumer fragmentConsumer + ) throws IOException { if (isOnHeader()) { int messageLength = TcpTransport.readMessageLength(reference); if (messageLength == -1) { return 0; } else if (messageLength == 0) { - fragmentConsumer.accept(PING); + fragmentConsumer.accept(channel, PING); return 6; } else { int headerBytesToRead = headerBytesToRead(reference); @@ -68,10 +76,10 @@ public int internalDecode(ReleasableBytesReference reference, Consumer f if (header.isCompressed()) { isCompressed = true; } - fragmentConsumer.accept(header); + fragmentConsumer.accept(channel, header); if (isDone()) { - finishMessage(fragmentConsumer); + finishMessage(channel, fragmentConsumer); } return headerBytesToRead; } @@ -84,33 +92,39 @@ public int internalDecode(ReleasableBytesReference reference, Consumer f return 0; } else { this.decompressor = decompressor; - fragmentConsumer.accept(this.decompressor.getScheme()); + fragmentConsumer.accept(channel, this.decompressor.getScheme()); } } int remainingToConsume = totalNetworkSize - bytesConsumed; int maxBytesToConsume = Math.min(reference.length(), remainingToConsume); - ReleasableBytesReference retainedContent; - if (maxBytesToConsume == remainingToConsume) { - retainedContent = reference.retainedSlice(0, maxBytesToConsume); - } else { - retainedContent = reference.retain(); - } int bytesConsumedThisDecode = 0; if (decompressor != null) { - bytesConsumedThisDecode += decompress(retainedContent); + bytesConsumedThisDecode += decompress( + maxBytesToConsume == remainingToConsume ? reference.slice(0, maxBytesToConsume) : reference + ); bytesConsumed += bytesConsumedThisDecode; ReleasableBytesReference decompressed; while ((decompressed = decompressor.pollDecompressedPage(isDone())) != null) { - fragmentConsumer.accept(decompressed); + try { + fragmentConsumer.accept(channel, decompressed); + } finally { + decompressed.decRef(); + } } } else { + ReleasableBytesReference contentToConsume; + if (maxBytesToConsume == remainingToConsume) { + contentToConsume = reference.releasableSlice(0, maxBytesToConsume); + } else { + contentToConsume = reference; + } bytesConsumedThisDecode += maxBytesToConsume; bytesConsumed += maxBytesToConsume; - fragmentConsumer.accept(retainedContent); + fragmentConsumer.accept(channel, contentToConsume); } if (isDone()) { - finishMessage(fragmentConsumer); + finishMessage(channel, fragmentConsumer); } return bytesConsumedThisDecode; @@ -123,9 +137,9 @@ public void close() { cleanDecodeState(); } - private void finishMessage(Consumer fragmentConsumer) { + private void finishMessage(TcpChannel channel, CheckedBiConsumer fragmentConsumer) throws IOException { cleanDecodeState(); - fragmentConsumer.accept(END_CONTENT); + fragmentConsumer.accept(channel, END_CONTENT); } private void cleanDecodeState() { @@ -139,10 +153,8 @@ private void cleanDecodeState() { } } - private int decompress(ReleasableBytesReference content) throws IOException { - try (content) { - return decompressor.decompress(content); - } + private int decompress(BytesReference content) throws IOException { + return decompressor.decompress(content); } private boolean isDone() { diff --git a/server/src/main/java/org/elasticsearch/transport/InboundPipeline.java b/server/src/main/java/org/elasticsearch/transport/InboundPipeline.java index 7a8cc9173abb4..25fb0c145d3e4 100644 --- a/server/src/main/java/org/elasticsearch/transport/InboundPipeline.java +++ b/server/src/main/java/org/elasticsearch/transport/InboundPipeline.java @@ -19,7 +19,6 @@ import java.io.IOException; import java.util.ArrayDeque; -import java.util.ArrayList; import java.util.function.BiConsumer; import java.util.function.Function; import java.util.function.LongSupplier; @@ -27,7 +26,6 @@ public class InboundPipeline implements Releasable { - private static final ThreadLocal> fragmentList = ThreadLocal.withInitial(ArrayList::new); private static final InboundMessage PING_MESSAGE = new InboundMessage(null, true); private final LongSupplier relativeTimeInMillis; @@ -93,99 +91,97 @@ public void handleBytes(TcpChannel channel, ReleasableBytesReference reference) public void doHandleBytes(TcpChannel channel, ReleasableBytesReference reference) throws IOException { channel.getChannelStats().markAccessed(relativeTimeInMillis.getAsLong()); statsTracker.markBytesRead(reference.length()); - pending.add(reference.retain()); - final ArrayList fragments = fragmentList.get(); - boolean continueHandling = true; - - while (continueHandling && isClosed == false) { - boolean continueDecoding = true; - while (continueDecoding && pending.isEmpty() == false) { - try (ReleasableBytesReference toDecode = getPendingBytes()) { - final int bytesDecoded = decoder.decode(toDecode, fragments::add); - if (bytesDecoded != 0) { - releasePendingBytes(bytesDecoded); - if (fragments.isEmpty() == false && endOfMessage(fragments.get(fragments.size() - 1))) { - continueDecoding = false; - } - } else { - continueDecoding = false; - } - } - } + if (pending.isEmpty() == false) { + // we already have pending bytes, so we queue these bytes after them and then try to decode from the combined message + pending.add(reference.retain()); + doHandleBytesWithPending(channel); + return; + } - if (fragments.isEmpty()) { - continueHandling = false; + while (isClosed == false && reference.length() > 0) { + final int bytesDecoded = decode(channel, reference); + if (bytesDecoded != 0) { + reference = reference.releasableSlice(bytesDecoded); } else { - try { - forwardFragments(channel, fragments); - } finally { - for (Object fragment : fragments) { - if (fragment instanceof ReleasableBytesReference) { - ((ReleasableBytesReference) fragment).close(); - } - } - fragments.clear(); - } + break; } } + // if handling the messages didn't cause the channel to get closed and we did not fully consume the buffer retain it + if (isClosed == false && reference.length() > 0) { + pending.add(reference.retain()); + } } - private void forwardFragments(TcpChannel channel, ArrayList fragments) throws IOException { - for (Object fragment : fragments) { - if (fragment instanceof Header) { - assert aggregator.isAggregating() == false; - aggregator.headerReceived((Header) fragment); - } else if (fragment instanceof Compression.Scheme) { - assert aggregator.isAggregating(); - aggregator.updateCompressionScheme((Compression.Scheme) fragment); - } else if (fragment == InboundDecoder.PING) { - assert aggregator.isAggregating() == false; - messageHandler.accept(channel, PING_MESSAGE); - } else if (fragment == InboundDecoder.END_CONTENT) { - assert aggregator.isAggregating(); - try (InboundMessage aggregated = aggregator.finishAggregation()) { - statsTracker.markMessageReceived(); - messageHandler.accept(channel, aggregated); + private int decode(TcpChannel channel, ReleasableBytesReference reference) throws IOException { + return decoder.decode(channel, reference, this::forwardFragment); + } + + private void doHandleBytesWithPending(TcpChannel channel) throws IOException { + do { + final int bytesDecoded; + if (pending.size() == 1) { + bytesDecoded = decode(channel, pending.peekFirst()); + } else { + try (ReleasableBytesReference toDecode = getPendingBytes()) { + bytesDecoded = decode(channel, toDecode); } + } + if (bytesDecoded != 0 && isClosed == false) { + releasePendingBytes(bytesDecoded); } else { - assert aggregator.isAggregating(); - assert fragment instanceof ReleasableBytesReference; - aggregator.aggregate((ReleasableBytesReference) fragment); + assert isClosed == false || pending.isEmpty() : "pending chunks should be empty if closed but saw [" + pending + "]"; + return; } - } + } while (pending.isEmpty() == false); } - private static boolean endOfMessage(Object fragment) { - return fragment == InboundDecoder.PING || fragment == InboundDecoder.END_CONTENT || fragment instanceof Exception; + private void forwardFragment(TcpChannel channel, Object fragment) throws IOException { + if (fragment instanceof Header) { + assert aggregator.isAggregating() == false; + aggregator.headerReceived((Header) fragment); + } else if (fragment instanceof Compression.Scheme) { + assert aggregator.isAggregating(); + aggregator.updateCompressionScheme((Compression.Scheme) fragment); + } else if (fragment == InboundDecoder.PING) { + assert aggregator.isAggregating() == false; + messageHandler.accept(channel, PING_MESSAGE); + } else if (fragment == InboundDecoder.END_CONTENT) { + assert aggregator.isAggregating(); + try (InboundMessage aggregated = aggregator.finishAggregation()) { + statsTracker.markMessageReceived(); + messageHandler.accept(channel, aggregated); + } + } else { + assert aggregator.isAggregating(); + assert fragment instanceof ReleasableBytesReference; + aggregator.aggregate((ReleasableBytesReference) fragment); + } } private ReleasableBytesReference getPendingBytes() { - if (pending.size() == 1) { - return pending.peekFirst().retain(); - } else { - final ReleasableBytesReference[] bytesReferences = new ReleasableBytesReference[pending.size()]; - int index = 0; - for (ReleasableBytesReference pendingReference : pending) { - bytesReferences[index] = pendingReference.retain(); - ++index; - } - final Releasable releasable = () -> Releasables.closeExpectNoException(bytesReferences); - return new ReleasableBytesReference(CompositeBytesReference.of(bytesReferences), releasable); + assert pending.size() > 1 : "must use this method with multiple pending references but used with " + pending; + final ReleasableBytesReference[] bytesReferences = new ReleasableBytesReference[pending.size()]; + int index = 0; + for (ReleasableBytesReference pendingReference : pending) { + bytesReferences[index] = pendingReference.retain(); + ++index; } + final Releasable releasable = () -> Releasables.closeExpectNoException(bytesReferences); + return new ReleasableBytesReference(CompositeBytesReference.of(bytesReferences), releasable); } private void releasePendingBytes(int bytesConsumed) { int bytesToRelease = bytesConsumed; while (bytesToRelease != 0) { - try (ReleasableBytesReference reference = pending.pollFirst()) { - assert reference != null; - if (bytesToRelease < reference.length()) { - pending.addFirst(reference.retainedSlice(bytesToRelease, reference.length() - bytesToRelease)); - bytesToRelease -= bytesToRelease; - } else { - bytesToRelease -= reference.length(); - } + ReleasableBytesReference reference = pending.pollFirst(); + assert reference != null; + if (bytesToRelease < reference.length()) { + pending.addFirst(reference.releasableSlice(bytesToRelease)); + return; + } else { + bytesToRelease -= reference.length(); + reference.decRef(); } } } diff --git a/server/src/test/java/org/elasticsearch/transport/InboundDecoderTests.java b/server/src/test/java/org/elasticsearch/transport/InboundDecoderTests.java index 65e3cb1ad4325..5f79c0d55997c 100644 --- a/server/src/test/java/org/elasticsearch/transport/InboundDecoderTests.java +++ b/server/src/test/java/org/elasticsearch/transport/InboundDecoderTests.java @@ -78,7 +78,7 @@ public void testDecode() throws IOException { InboundDecoder decoder = new InboundDecoder(Version.CURRENT, recycler); final ArrayList fragments = new ArrayList<>(); final ReleasableBytesReference releasable1 = ReleasableBytesReference.wrap(totalBytes); - int bytesConsumed = decoder.decode(releasable1, fragments::add); + int bytesConsumed = decoder.decode(null, releasable1, (c, f) -> fragments.add(f)); assertEquals(totalHeaderSize, bytesConsumed); assertTrue(releasable1.hasReferences()); @@ -100,7 +100,7 @@ public void testDecode() throws IOException { final BytesReference bytes2 = totalBytes.slice(bytesConsumed, totalBytes.length() - bytesConsumed); final ReleasableBytesReference releasable2 = ReleasableBytesReference.wrap(bytes2); - int bytesConsumed2 = decoder.decode(releasable2, fragments::add); + int bytesConsumed2 = decoder.decode(null, releasable2, (c, f) -> fragments.add(f)); assertEquals(totalBytes.length() - totalHeaderSize, bytesConsumed2); final Object content = fragments.get(0); @@ -109,8 +109,6 @@ public void testDecode() throws IOException { assertEquals(messageBytes, content); // Ref count is incremented since the bytes are forwarded as a fragment assertTrue(releasable2.hasReferences()); - releasable2.decRef(); - assertTrue(releasable2.hasReferences()); assertTrue(releasable2.decRef()); assertEquals(InboundDecoder.END_CONTENT, endMarker); } @@ -142,7 +140,7 @@ public void testDecodePreHeaderSizeVariableInt() throws IOException { InboundDecoder decoder = new InboundDecoder(Version.CURRENT, recycler); final ArrayList fragments = new ArrayList<>(); final ReleasableBytesReference releasable1 = ReleasableBytesReference.wrap(totalBytes); - int bytesConsumed = decoder.decode(releasable1, fragments::add); + int bytesConsumed = decoder.decode(null, releasable1, (c, f) -> fragments.add(f)); assertEquals(partialHeaderSize, bytesConsumed); assertTrue(releasable1.hasReferences()); @@ -161,14 +159,13 @@ public void testDecodePreHeaderSizeVariableInt() throws IOException { final BytesReference bytes2 = totalBytes.slice(bytesConsumed, totalBytes.length() - bytesConsumed); final ReleasableBytesReference releasable2 = ReleasableBytesReference.wrap(bytes2); - int bytesConsumed2 = decoder.decode(releasable2, fragments::add); + int bytesConsumed2 = decoder.decode(null, releasable2, (c, f) -> fragments.add(f)); if (compressionScheme == null) { assertEquals(2, fragments.size()); } else { assertEquals(3, fragments.size()); final Object body = fragments.get(1); assertThat(body, instanceOf(ReleasableBytesReference.class)); - ((ReleasableBytesReference) body).close(); } assertEquals(InboundDecoder.END_CONTENT, fragments.get(fragments.size() - 1)); assertEquals(totalBytes.length() - bytesConsumed, bytesConsumed2); @@ -199,7 +196,7 @@ public void testDecodeHandshakeCompatibility() throws IOException { InboundDecoder decoder = new InboundDecoder(Version.CURRENT, recycler); final ArrayList fragments = new ArrayList<>(); final ReleasableBytesReference releasable1 = ReleasableBytesReference.wrap(bytes); - int bytesConsumed = decoder.decode(releasable1, fragments::add); + int bytesConsumed = decoder.decode(null, releasable1, (c, f) -> fragments.add(f)); assertEquals(totalHeaderSize, bytesConsumed); assertTrue(releasable1.hasReferences()); @@ -248,7 +245,7 @@ public void testCompressedDecode() throws IOException { InboundDecoder decoder = new InboundDecoder(Version.CURRENT, recycler); final ArrayList fragments = new ArrayList<>(); final ReleasableBytesReference releasable1 = ReleasableBytesReference.wrap(totalBytes); - int bytesConsumed = decoder.decode(releasable1, fragments::add); + int bytesConsumed = decoder.decode(null, releasable1, (c, f) -> fragments.add(f)); assertEquals(totalHeaderSize, bytesConsumed); assertTrue(releasable1.hasReferences()); @@ -270,7 +267,13 @@ public void testCompressedDecode() throws IOException { final BytesReference bytes2 = totalBytes.slice(bytesConsumed, totalBytes.length() - bytesConsumed); final ReleasableBytesReference releasable2 = ReleasableBytesReference.wrap(bytes2); - int bytesConsumed2 = decoder.decode(releasable2, fragments::add); + int bytesConsumed2 = decoder.decode(null, releasable2, (c, e) -> { + if (e instanceof ReleasableBytesReference) { + fragments.add(((ReleasableBytesReference) e).retain()); + } else { + fragments.add(e); + } + }); assertEquals(totalBytes.length() - totalHeaderSize, bytesConsumed2); final Object compressionScheme = fragments.get(0); @@ -312,7 +315,7 @@ public void testCompressedDecodeHandshakeCompatibility() throws IOException { InboundDecoder decoder = new InboundDecoder(Version.CURRENT, recycler); final ArrayList fragments = new ArrayList<>(); final ReleasableBytesReference releasable1 = ReleasableBytesReference.wrap(bytes); - int bytesConsumed = decoder.decode(releasable1, fragments::add); + int bytesConsumed = decoder.decode(null, releasable1, (c, f) -> fragments.add(f)); assertEquals(totalHeaderSize, bytesConsumed); assertTrue(releasable1.hasReferences()); @@ -350,7 +353,7 @@ public void testVersionIncompatibilityDecodeException() throws IOException { final ArrayList fragments = new ArrayList<>(); try (ReleasableBytesReference r = ReleasableBytesReference.wrap(bytes)) { releasable1 = r; - expectThrows(IllegalStateException.class, () -> decoder.decode(releasable1, fragments::add)); + expectThrows(IllegalStateException.class, () -> decoder.decode(null, releasable1, (c, f) -> fragments.add(f))); } } // No bytes are retained diff --git a/server/src/test/java/org/elasticsearch/transport/InboundPipelineTests.java b/server/src/test/java/org/elasticsearch/transport/InboundPipelineTests.java index 726c876ba5598..9083ee0d606e1 100644 --- a/server/src/test/java/org/elasticsearch/transport/InboundPipelineTests.java +++ b/server/src/test/java/org/elasticsearch/transport/InboundPipelineTests.java @@ -297,6 +297,60 @@ public void testEnsureBodyIsNotPrematurelyReleased() throws IOException { } } + public void testDecodeDoesNotRetainOnClosedChannel() throws IOException { + BiConsumer messageHandler = (c, m) -> {}; + final StatsTracker statsTracker = new StatsTracker(); + final LongSupplier millisSupplier = () -> TimeValue.nsecToMSec(System.nanoTime()); + final InboundDecoder decoder = new InboundDecoder(Version.CURRENT, recycler); + final Supplier breaker = () -> new NoopCircuitBreaker("test"); + final InboundAggregator aggregator = new InboundAggregator(breaker, (Predicate) action -> true); + final InboundPipeline pipeline = new InboundPipeline(statsTracker, millisSupplier, decoder, aggregator, messageHandler); + try (RecyclerBytesStreamOutput streamOutput = new RecyclerBytesStreamOutput(recycler)) { + String actionName = "actionName"; + final Version version = Version.CURRENT; + final String value = randomAlphaOfLength(1000); + final boolean isRequest = randomBoolean(); + final long requestId = randomNonNegativeLong(); + + OutboundMessage message; + if (isRequest) { + message = new OutboundMessage.Request(threadContext, new TestRequest(value), version, actionName, requestId, false, null); + } else { + message = new OutboundMessage.Response(threadContext, new TestResponse(value), version, requestId, false, null); + } + + final BytesReference reference = message.serialize(streamOutput); + final int fixedHeaderSize = TcpHeader.headerSize(Version.CURRENT); + final int variableHeaderSize = reference.getInt(fixedHeaderSize - 4); + final int totalHeaderSize = fixedHeaderSize + variableHeaderSize; + for (int i = 0; i < totalHeaderSize - 1; ++i) { + try (ReleasableBytesReference slice = ReleasableBytesReference.wrap(reference.slice(i, 1))) { + pipeline.handleBytes(new FakeTcpChannel(), slice); + } + } + + final AtomicBoolean bodyPart1Released = new AtomicBoolean(false); + final int from = totalHeaderSize - 1; + final BytesReference partHeaderPartBody = reference.slice(from, reference.length() - from - 1); + try (ReleasableBytesReference slice = new ReleasableBytesReference(partHeaderPartBody, () -> bodyPart1Released.set(true))) { + pipeline.handleBytes(new FakeTcpChannel(), slice); + } + assertFalse(bodyPart1Released.get()); + pipeline.close(); + assertTrue(bodyPart1Released.get()); + final AtomicBoolean bodyPart2Released = new AtomicBoolean(false); + try ( + ReleasableBytesReference slice = new ReleasableBytesReference( + reference.slice(reference.length() - 1, 1), + () -> bodyPart2Released.set(true) + ) + ) { + pipeline.handleBytes(new FakeTcpChannel(), slice); + } + assertTrue(bodyPart2Released.get()); + } + } + private static class MessageData { private final Version version;