Skip to content

Commit cf5a5c8

Browse files
authored
[ci] use netty bufferPools for direct buffers (#3437)
Signed-off-by: Itai Segall <itai.segall@digitalasset.com>
1 parent 9c58bcc commit cf5a5c8

File tree

2 files changed

+33
-30
lines changed

2 files changed

+33
-30
lines changed

apps/scan/src/main/scala/org/lfdecentralizedtrust/splice/scan/store/bulk/ZstdGroupedWeight.scala

Lines changed: 19 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,11 @@
44
package org.lfdecentralizedtrust.splice.scan.store.bulk
55

66
import com.github.luben.zstd.ZstdDirectBufferCompressingStreamNoFinalizer
7+
import io.netty.buffer.PooledByteBufAllocator
78
import org.apache.pekko.stream.{Attributes, FlowShape, Inlet, Outlet}
89
import org.apache.pekko.stream.stage.{GraphStage, GraphStageLogic, InHandler, OutHandler}
910
import org.apache.pekko.util.ByteString
1011

11-
import java.nio.ByteBuffer
1212
import java.util.concurrent.atomic.AtomicReference
1313

1414
/** A Pekko GraphStage that zstd-compresses a stream of bytestrings, and splits the output into zstd objects of size (minWeight + delta).
@@ -41,45 +41,44 @@ case class ZstdGroupedWeight(minSize: Long) extends GraphStage[FlowShape[ByteStr
4141
}
4242

4343
class ZSTD(
44-
val tmpBuffer: ByteBuffer,
45-
val compressionLevel: Int = 3,
44+
val compressionLevel: Int = 3
4645
) extends AutoCloseable {
4746

47+
val bufferAllocator = PooledByteBufAllocator.DEFAULT
48+
val tmpBuffer = bufferAllocator.directBuffer(zstdTmpBufferSize)
49+
val tmpNioBuffer = tmpBuffer.nioBuffer(0, tmpBuffer.capacity())
4850
val compressingStream =
49-
new ZstdDirectBufferCompressingStreamNoFinalizer(tmpBuffer, compressionLevel)
51+
new ZstdDirectBufferCompressingStreamNoFinalizer(tmpNioBuffer, compressionLevel)
5052

5153
def compress(input: ByteString): ByteString = {
52-
// TODO(#3429): use a buffer pool to avoid allocating a new ByteBuffer for each compress call
53-
val inputBB = ByteBuffer.allocateDirect(input.size)
54-
inputBB.put(input.toArrayUnsafe())
55-
inputBB.flip()
56-
compressingStream.compress(inputBB)
54+
val inputBB = bufferAllocator.directBuffer(input.size)
55+
inputBB.writeBytes(input.toArrayUnsafe())
56+
compressingStream.compress(inputBB.nioBuffer())
57+
inputBB.release()
5758
compressingStream.flush()
58-
tmpBuffer.flip()
59-
val result = ByteString.fromByteBuffer(tmpBuffer)
60-
tmpBuffer.clear()
59+
tmpNioBuffer.flip()
60+
val result = ByteString.fromByteBuffer(tmpNioBuffer)
61+
tmpNioBuffer.clear()
6162
result
6263
}
6364

6465
def zstdFinish(): ByteString = {
65-
tmpBuffer.flip()
66-
val result = ByteString.fromByteBuffer(tmpBuffer)
67-
tmpBuffer.clear()
66+
tmpNioBuffer.flip()
67+
val result = ByteString.fromByteBuffer(tmpNioBuffer)
68+
tmpNioBuffer.clear()
6869
compressingStream.close()
6970
result
7071
}
7172

7273
override def close(): Unit = {
7374
compressingStream.close()
75+
val _ = tmpBuffer.release()
7476
}
7577
}
7678

7779
override def createLogic(inheritedAttributes: Attributes): GraphStageLogic =
7880
new GraphStageLogic(shape) with InHandler with OutHandler {
79-
// TODO(#3429): consider implementing a pool of tmp buffers to avoid allocating a new one for each stage,
80-
// and moving some initialization into preStart(), otherwise we allocate even if the stream never runs or fails before starting.
81-
private val tmpBuffer = ByteBuffer.allocateDirect(zstdTmpBufferSize)
82-
private val zstd = new AtomicReference[ZSTD](new ZSTD(tmpBuffer, 3))
81+
private val zstd = new AtomicReference[ZSTD](new ZSTD(3))
8382
private val state: AtomicReference[State] = new AtomicReference[State](State.empty())
8483

8584
override def postStop(): Unit = {
@@ -90,9 +89,8 @@ case class ZstdGroupedWeight(minSize: Long) extends GraphStage[FlowShape[ByteStr
9089
}
9190

9291
private def reset(): Unit = {
93-
tmpBuffer.clear()
9492
zstd.get().close()
95-
zstd.set(new ZSTD(tmpBuffer, 3))
93+
zstd.set(new ZSTD(3))
9694
state.set(State.empty())
9795
}
9896

apps/scan/src/test/scala/org/lfdecentralizedtrust/splice/scan/store/AcsSnapshotBulkStorageTest.scala

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import com.digitalasset.canton.tracing.TraceContext
1111
import com.digitalasset.canton.{HasActorSystem, HasExecutionContext}
1212
import com.github.luben.zstd.ZstdDirectBufferDecompressingStream
1313
import com.google.protobuf.ByteString
14+
import io.netty.buffer.PooledByteBufAllocator
1415
import org.lfdecentralizedtrust.splice.scan.admin.http.CompactJsonScanHttpEncodings
1516
import org.lfdecentralizedtrust.splice.scan.store.AcsSnapshotStore.QueryAcsSnapshotResult
1617
import org.lfdecentralizedtrust.splice.scan.store.bulk.{
@@ -19,15 +20,14 @@ import org.lfdecentralizedtrust.splice.scan.store.bulk.{
1920
S3BucketConnection,
2021
S3Config,
2122
}
22-
import org.lfdecentralizedtrust.splice.store.{Limit, StoreTest, HardLimit}
23+
import org.lfdecentralizedtrust.splice.store.{HardLimit, Limit, StoreTest}
2324
import org.lfdecentralizedtrust.splice.store.events.SpliceCreatedEvent
2425
import org.lfdecentralizedtrust.splice.util.{EventId, PackageQualifiedName, ValueJsonCodecCodegen}
2526
import software.amazon.awssdk.auth.credentials.AwsBasicCredentials
2627
import software.amazon.awssdk.regions.Region
2728
import software.amazon.awssdk.services.s3.model.{ListObjectsRequest, S3Object}
2829

2930
import java.net.URI
30-
import java.nio.ByteBuffer
3131
import java.nio.charset.StandardCharsets
3232
import scala.concurrent.Future
3333
import scala.jdk.FutureConverters.*
@@ -113,15 +113,20 @@ class AcsSnapshotBulkStorageTest extends StoreTest with HasExecutionContext with
113113
def readUncompressAndDecode(
114114
s3BucketConnection: S3BucketConnection
115115
)(s3obj: S3Object): Array[httpApi.CreatedEvent] = {
116+
val bufferAllocator = PooledByteBufAllocator.DEFAULT
116117
val compressed = s3BucketConnection.readFullObject(s3obj.key()).futureValue
117-
val compressedDirect = ByteBuffer.allocateDirect(compressed.capacity())
118-
val uncompressed = ByteBuffer.allocateDirect(compressed.capacity() * 200)
119-
compressedDirect.put(compressed)
120-
compressedDirect.flip()
121-
Using(new ZstdDirectBufferDecompressingStream(compressedDirect)) { _.read(uncompressed) }
122-
uncompressed.flip()
123-
val allContractsStr = StandardCharsets.UTF_8.newDecoder().decode(uncompressed).toString
118+
val compressedDirect = bufferAllocator.directBuffer(compressed.capacity())
119+
val uncompressedDirect = bufferAllocator.directBuffer(compressed.capacity() * 200)
120+
val uncompressedNio = uncompressedDirect.nioBuffer(0, uncompressedDirect.capacity())
121+
compressedDirect.writeBytes(compressed)
122+
Using(new ZstdDirectBufferDecompressingStream(compressedDirect.nioBuffer())) {
123+
_.read(uncompressedNio)
124+
}
125+
uncompressedNio.flip()
126+
val allContractsStr = StandardCharsets.UTF_8.newDecoder().decode(uncompressedNio).toString
124127
val allContracts = allContractsStr.split("\n")
128+
compressedDirect.release()
129+
uncompressedDirect.release()
125130
allContracts.map(io.circe.parser.decode[httpApi.CreatedEvent](_).value)
126131
}
127132

0 commit comments

Comments (0)