44package org .lfdecentralizedtrust .splice .scan .store .bulk
55
66import com .github .luben .zstd .ZstdDirectBufferCompressingStreamNoFinalizer
7+ import io .netty .buffer .PooledByteBufAllocator
78import org .apache .pekko .stream .{Attributes , FlowShape , Inlet , Outlet }
89import org .apache .pekko .stream .stage .{GraphStage , GraphStageLogic , InHandler , OutHandler }
910import org .apache .pekko .util .ByteString
1011
11- import java .nio .ByteBuffer
1212import java .util .concurrent .atomic .AtomicReference
1313
1414/** A Pekko GraphStage that zstd-compresses a stream of bytestrings, and splits the output into zstd objects of size (minWeight + delta).
@@ -41,45 +41,44 @@ case class ZstdGroupedWeight(minSize: Long) extends GraphStage[FlowShape[ByteStr
4141 }
4242
4343 class ZSTD (
44- val tmpBuffer : ByteBuffer ,
45- val compressionLevel : Int = 3 ,
44+ val compressionLevel : Int = 3
4645 ) extends AutoCloseable {
4746
47+ val bufferAllocator = PooledByteBufAllocator .DEFAULT
48+ val tmpBuffer = bufferAllocator.directBuffer(zstdTmpBufferSize)
49+ val tmpNioBuffer = tmpBuffer.nioBuffer(0 , tmpBuffer.capacity())
4850 val compressingStream =
49- new ZstdDirectBufferCompressingStreamNoFinalizer (tmpBuffer , compressionLevel)
51+ new ZstdDirectBufferCompressingStreamNoFinalizer (tmpNioBuffer , compressionLevel)
5052
5153 def compress (input : ByteString ): ByteString = {
52- // TODO(#3429): use a buffer pool to avoid allocating a new ByteBuffer for each compress call
53- val inputBB = ByteBuffer .allocateDirect(input.size)
54- inputBB.put(input.toArrayUnsafe())
55- inputBB.flip()
56- compressingStream.compress(inputBB)
54+ val inputBB = bufferAllocator.directBuffer(input.size)
55+ inputBB.writeBytes(input.toArrayUnsafe())
56+ compressingStream.compress(inputBB.nioBuffer())
57+ inputBB.release()
5758 compressingStream.flush()
58- tmpBuffer .flip()
59- val result = ByteString .fromByteBuffer(tmpBuffer )
60- tmpBuffer .clear()
59+ tmpNioBuffer .flip()
60+ val result = ByteString .fromByteBuffer(tmpNioBuffer )
61+ tmpNioBuffer .clear()
6162 result
6263 }
6364
6465 def zstdFinish (): ByteString = {
65- tmpBuffer .flip()
66- val result = ByteString .fromByteBuffer(tmpBuffer )
67- tmpBuffer .clear()
66+ tmpNioBuffer .flip()
67+ val result = ByteString .fromByteBuffer(tmpNioBuffer )
68+ tmpNioBuffer .clear()
6869 compressingStream.close()
6970 result
7071 }
7172
7273 override def close (): Unit = {
7374 compressingStream.close()
75+ val _ = tmpBuffer.release()
7476 }
7577 }
7678
7779 override def createLogic (inheritedAttributes : Attributes ): GraphStageLogic =
7880 new GraphStageLogic (shape) with InHandler with OutHandler {
79- // TODO(#3429): consider implementing a pool of tmp buffers to avoid allocating a new one for each stage,
80- // and moving some initialization into preStart(), otherwise we allocate even if the stream never runs or fails before starting.
81- private val tmpBuffer = ByteBuffer .allocateDirect(zstdTmpBufferSize)
82- private val zstd = new AtomicReference [ZSTD ](new ZSTD (tmpBuffer, 3 ))
81+ private val zstd = new AtomicReference [ZSTD ](new ZSTD (3 ))
8382 private val state : AtomicReference [State ] = new AtomicReference [State ](State .empty())
8483
8584 override def postStop (): Unit = {
@@ -90,9 +89,8 @@ case class ZstdGroupedWeight(minSize: Long) extends GraphStage[FlowShape[ByteStr
9089 }
9190
9291 private def reset (): Unit = {
93- tmpBuffer.clear()
9492 zstd.get().close()
95- zstd.set(new ZSTD (tmpBuffer, 3 ))
93+ zstd.set(new ZSTD (3 ))
9694 state.set(State .empty())
9795 }
9896
0 commit comments