@@ -2,7 +2,10 @@ package grpcclients
 
 import (
 	"context"
+	"errors"
 	"io"
+	"slices"
+	"sync"
 
 	remoteexecution "github.com/bazelbuild/remote-apis/build/bazel/remote/execution/v2"
 	"github.com/buildbarn/bb-storage/pkg/blobstore"
@@ -11,10 +14,13 @@ import (
 	"github.com/buildbarn/bb-storage/pkg/digest"
 	"github.com/buildbarn/bb-storage/pkg/util"
 	"github.com/google/uuid"
+	"github.com/klauspost/compress/zstd"
 
 	"google.golang.org/genproto/googleapis/bytestream"
 	"google.golang.org/grpc"
+	"google.golang.org/grpc/codes"
 	"google.golang.org/grpc/metadata"
+	"google.golang.org/grpc/status"
 )
 
 type casBlobAccess struct {
@@ -23,20 +29,29 @@ type casBlobAccess struct {
 	capabilitiesClient              remoteexecution.CapabilitiesClient
 	uuidGenerator                   util.UUIDGenerator
 	readChunkSize                   int
+	compressionThresholdBytes       int64
+	supportedCompressors            []remoteexecution.Compressor_Value
+	supportedCompressorsMutex       sync.RWMutex
+	capabilitiesOnce                sync.Once
 }
 
 // NewCASBlobAccess creates a BlobAccess handle that relays any requests
 // to a GRPC service that implements the bytestream.ByteStream and
 // remoteexecution.ContentAddressableStorage services. Those are the
 // services that Bazel uses to access blobs stored in the Content
 // Addressable Storage.
-func NewCASBlobAccess(client grpc.ClientConnInterface, uuidGenerator util.UUIDGenerator, readChunkSize int) blobstore.BlobAccess {
+//
+// If compressionThresholdBytes is greater than zero, the client will
+// attempt to use ZSTD compression for blobs at least that size. The
+// compressors supported by the server are checked via GetCapabilities().
+func NewCASBlobAccess(client grpc.ClientConnInterface, uuidGenerator util.UUIDGenerator, readChunkSize int, compressionThresholdBytes int64) blobstore.BlobAccess {
 	return &casBlobAccess{
 		byteStreamClient:                bytestream.NewByteStreamClient(client),
 		contentAddressableStorageClient: remoteexecution.NewContentAddressableStorageClient(client),
 		capabilitiesClient:              remoteexecution.NewCapabilitiesClient(client),
 		uuidGenerator:                   uuidGenerator,
 		readChunkSize:                   readChunkSize,
+		compressionThresholdBytes:       compressionThresholdBytes,
 	}
 }
 
@@ -62,11 +77,137 @@ func (r *byteStreamChunkReader) Close() {
 	}
 }
 
+type zstdByteStreamChunkReader struct {
+	client        bytestream.ByteStream_ReadClient
+	cancel        context.CancelFunc
+	zstdReader    io.ReadCloser
+	readChunkSize int
+	wg            sync.WaitGroup
+}
+
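+// Read lazily sets up the decompression pipeline on first use: a
+// goroutine copies chunks received from the gRPC stream into an io.Pipe,
+// whose read end feeds a streaming zstd decoder. Decompressed data is
+// handed out at most readChunkSize bytes at a time.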
+func (r *zstdByteStreamChunkReader) Read() ([]byte, error) {
+	if r.zstdReader == nil {
+		pr, pw := io.Pipe()
+
+		r.wg.Add(1)
+		go func() {
+			defer r.wg.Done()
+			defer pw.Close()
+			for {
+				chunk, err := r.client.Recv()
+				if err != nil {
+					if err != io.EOF {
+						pw.CloseWithError(err)
+					}
+					return
+				}
+				if _, writeErr := pw.Write(chunk.Data); writeErr != nil {
+					pw.CloseWithError(writeErr)
+					return
+				}
+			}
+		}()
+
+		var err error
+		r.zstdReader, err = util.NewZstdReadCloser(pr, zstd.WithDecoderConcurrency(1))
+		if err != nil {
+			pr.CloseWithError(err)
+			return nil, err
+		}
+	}
+
+	buf := make([]byte, r.readChunkSize)
+	n, err := r.zstdReader.Read(buf)
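+	// If data was returned together with an error other than io.EOF,
+	// hand out the data now; the error will resurface on the next call.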
+	if n > 0 {
+		if err != nil && err != io.EOF {
+			err = nil
+		}
+		return buf[:n], err
+	}
+	return nil, err
+}
+
+func (r *zstdByteStreamChunkReader) Close() {
+	if r.zstdReader != nil {
+		r.zstdReader.Close()
+	}
+	r.cancel()
+
+	// Drain the gRPC stream.
+	for {
+		if _, err := r.client.Recv(); err != nil {
+			break
+		}
+	}
+	r.wg.Wait()
+}
+
+type zstdByteStreamWriter struct {
+	client       bytestream.ByteStream_WriteClient
+	resourceName string
+	writeOffset  int64
+	cancel       context.CancelFunc
+}
+
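+// Write translates each chunk of compressed data into a ByteStream
+// WriteRequest. Per the ByteStream protocol, the resource name only
+// needs to be provided on the first request of a write, so it is
+// cleared after the initial Send().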
+func (w *zstdByteStreamWriter) Write(p []byte) (int, error) {
+	if err := w.client.Send(&bytestream.WriteRequest{
+		ResourceName: w.resourceName,
+		WriteOffset:  w.writeOffset,
+		Data:         p,
+	}); err != nil {
+		return 0, err
+	}
+	w.writeOffset += int64(len(p))
+	w.resourceName = ""
+	return len(p), nil
+}
+
+func (w *zstdByteStreamWriter) Close() error {
+	if err := w.client.Send(&bytestream.WriteRequest{
+		ResourceName: w.resourceName,
+		WriteOffset:  w.writeOffset,
+		FinishWrite:  true,
+	}); err != nil {
+		w.cancel()
+		w.client.CloseAndRecv()
+		return err
+	}
+	_, err := w.client.CloseAndRecv()
+	w.cancel()
+	return err
+}
+
 const resourceNameHeader = "build.bazel.remote.execution.v2.resource-name"
 
+// shouldUseCompression reports whether ZSTD compression should be used
+// for the blob with the given digest. It also ensures GetCapabilities()
+// has been called once, so that compression support is negotiated with
+// the server.
+func (ba *casBlobAccess) shouldUseCompression(ctx context.Context, digest digest.Digest) bool {
+	if ba.compressionThresholdBytes <= 0 || digest.GetSizeBytes() < ba.compressionThresholdBytes {
+		return false
+	}
+
+	// If GetCapabilities() fails, fall back to not using compression.
+	ba.capabilitiesOnce.Do(func() {
+		ba.GetCapabilities(ctx, digest.GetDigestFunction().GetInstanceName())
+	})
+
+	ba.supportedCompressorsMutex.RLock()
+	supportedCompressors := ba.supportedCompressors
+	ba.supportedCompressorsMutex.RUnlock()
+
+	return slices.Contains(supportedCompressors, remoteexecution.Compressor_ZSTD)
+}
+
 func (ba *casBlobAccess) Get(ctx context.Context, digest digest.Digest) buffer.Buffer {
+	useCompression := ba.shouldUseCompression(ctx, digest)
+
+	compressor := remoteexecution.Compressor_IDENTITY
+	if useCompression {
+		compressor = remoteexecution.Compressor_ZSTD
+	}
+
 	ctxWithCancel, cancel := context.WithCancel(ctx)
-	resourceName := digest.GetByteStreamReadPath(remoteexecution.Compressor_IDENTITY)
+	resourceName := digest.GetByteStreamReadPath(compressor)
 	client, err := ba.byteStreamClient.Read(
 		metadata.AppendToOutgoingContext(ctxWithCancel, resourceNameHeader, resourceName),
 		&bytestream.ReadRequest{
@@ -77,6 +218,15 @@ func (ba *casBlobAccess) Get(ctx context.Context, digest digest.Digest) buffer.B
 		cancel()
 		return buffer.NewBufferFromError(err)
 	}
+
+	if useCompression {
+		return buffer.NewCASBufferFromChunkReader(digest, &zstdByteStreamChunkReader{
+			client:        client,
+			cancel:        cancel,
+			readChunkSize: ba.readChunkSize,
+		}, buffer.BackendProvided(buffer.Irreparable(digest)))
+	}
+
 	return buffer.NewCASBufferFromChunkReader(digest, &byteStreamChunkReader{
 		client: client,
 		cancel: cancel,
@@ -89,19 +239,65 @@ func (ba *casBlobAccess) GetFromComposite(ctx context.Context, parentDigest, chi
 }
 
 func (ba *casBlobAccess) Put(ctx context.Context, digest digest.Digest, b buffer.Buffer) error {
-	r := b.ToChunkReader(0, ba.readChunkSize)
-	defer r.Close()
+	useCompression := ba.shouldUseCompression(ctx, digest)
+
+	compressor := remoteexecution.Compressor_IDENTITY
+	if useCompression {
+		compressor = remoteexecution.Compressor_ZSTD
+	}
 
 	ctxWithCancel, cancel := context.WithCancel(ctx)
-	resourceName := digest.GetByteStreamWritePath(uuid.Must(ba.uuidGenerator()), remoteexecution.Compressor_IDENTITY)
+	resourceName := digest.GetByteStreamWritePath(uuid.Must(ba.uuidGenerator()), compressor)
 	client, err := ba.byteStreamClient.Write(
 		metadata.AppendToOutgoingContext(ctxWithCancel, resourceNameHeader, resourceName),
 	)
 	if err != nil {
 		cancel()
+		b.Discard()
 		return err
 	}
 
+	if useCompression {
+		byteStreamWriter := &zstdByteStreamWriter{
+			client:       client,
+			resourceName: resourceName,
+			writeOffset:  0,
+			cancel:       cancel,
+		}
+
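+		// Stream the buffer through a single-threaded zstd encoder whose
+		// compressed output is written directly to the ByteStream.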
+		zstdWriter, err := zstd.NewWriter(byteStreamWriter, zstd.WithEncoderConcurrency(1))
+		if err != nil {
+			cancel()
+			b.Discard()
+			if _, closeErr := client.CloseAndRecv(); closeErr != nil {
+				return status.Errorf(codes.Internal, "Failed to create zstd writer: %v (also failed to close client: %v)", err, closeErr)
+			}
+			return status.Errorf(codes.Internal, "Failed to create zstd writer: %v", err)
+		}
+
+		if err := b.IntoWriter(zstdWriter); err != nil {
+			if zstdCloseErr := zstdWriter.Close(); zstdCloseErr != nil {
+				err = errors.Join(err, zstdCloseErr)
+			}
+			if closeErr := byteStreamWriter.Close(); closeErr != nil {
+				err = errors.Join(err, closeErr)
+			}
+			return err
+		}
+
+		if err := zstdWriter.Close(); err != nil {
+			if closeErr := byteStreamWriter.Close(); closeErr != nil {
+				err = errors.Join(err, closeErr)
+			}
+			return err
+		}
+
+		return byteStreamWriter.Close()
+	}
+
+	// Non-compressed path.
+	r := b.ToChunkReader(0, ba.readChunkSize)
+	defer r.Close()
+
 	writeOffset := int64(0)
 	for {
 		if data, err := r.Read(); err == nil {
@@ -140,6 +336,10 @@ func (ba *casBlobAccess) Put(ctx context.Context, digest digest.Digest, b buffer
 }
 
 func (ba *casBlobAccess) FindMissing(ctx context.Context, digests digest.Set) (digest.Set, error) {
+	return findMissingBlobsInternal(ctx, digests, ba.contentAddressableStorageClient)
+}
+
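+// findMissingBlobsInternal issues FindMissingBlobs() calls against the
+// provided ContentAddressableStorage client, partitioning the digests by
+// instance name and digest function.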
+func findMissingBlobsInternal(ctx context.Context, digests digest.Set, cas remoteexecution.ContentAddressableStorageClient) (digest.Set, error) {
 	// Partition all digests by digest function, as the
 	// FindMissingBlobs() RPC can only process digests for a single
 	// instance name and digest function.
@@ -157,7 +357,7 @@ func (ba *casBlobAccess) FindMissing(ctx context.Context, digests digest.Set) (d
 			BlobDigests:    blobDigests,
 			DigestFunction: digestFunction.GetEnumValue(),
 		}
-		response, err := ba.contentAddressableStorageClient.FindMissingBlobs(ctx, &request)
+		response, err := cas.FindMissingBlobs(ctx, &request)
 		if err != nil {
 			return digest.EmptySet, err
 		}
@@ -180,11 +380,17 @@ func (ba *casBlobAccess) GetCapabilities(ctx context.Context, instanceName diges
 		return nil, err
 	}
 
+	cacheCapabilities := serverCapabilities.CacheCapabilities
+
+	// Store the supported compressors for compression negotiation.
+	ba.supportedCompressorsMutex.Lock()
+	ba.supportedCompressors = cacheCapabilities.SupportedCompressors
+	ba.supportedCompressorsMutex.Unlock()
+
 	// Only return fields that pertain to the Content Addressable
 	// Storage. Don't set 'max_batch_total_size_bytes', as we don't
-	// issue batch operations. The same holds for fields related to
-	// compression support.
-	cacheCapabilities := serverCapabilities.CacheCapabilities
+	// issue batch operations. Don't propagate 'supported_compressors',
+	// as it would be merged with bb_storage's configuration.
 	return &remoteexecution.ServerCapabilities{
 		CacheCapabilities: &remoteexecution.CacheCapabilities{
 			DigestFunctions: digest.RemoveUnsupportedDigestFunctions(cacheCapabilities.DigestFunctions),
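
For reference, a minimal sketch of how a caller might wire up the new constructor. The endpoint, chunk size, and 1 MiB threshold are illustrative assumptions, not part of this change; the import path assumes the package lives at pkg/blobstore/grpcclients as in bb-storage.

	package main

	import (
		"github.com/buildbarn/bb-storage/pkg/blobstore/grpcclients"
		"github.com/google/uuid"

		"google.golang.org/grpc"
		"google.golang.org/grpc/credentials/insecure"
	)

	func main() {
		// Hypothetical CAS endpoint; any grpc.ClientConnInterface works.
		conn, err := grpc.NewClient(
			"cas.example.com:8980",
			grpc.WithTransportCredentials(insecure.NewCredentials()),
		)
		if err != nil {
			panic(err)
		}
		blobAccess := grpcclients.NewCASBlobAccess(
			conn,
			uuid.NewRandom, // satisfies util.UUIDGenerator
			64*1024,        // readChunkSize: 64 KiB ByteStream chunks
			1<<20,          // compressionThresholdBytes: ZSTD for blobs >= 1 MiB
		)
		_ = blobAccess
	}

Passing a threshold of zero (or a negative value) disables compression entirely, preserving the previous behavior of always using Compressor_IDENTITY.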