3434import java .util .Arrays ;
3535import java .util .Collection ;
3636import java .util .Collections ;
37+ import java .util .HashSet ;
3738import java .util .List ;
39+ import java .util .Optional ;
3840import java .util .Set ;
3941import java .util .concurrent .CompletableFuture ;
42+ import java .util .concurrent .CountDownLatch ;
4043import java .util .concurrent .ExecutorService ;
4144import java .util .concurrent .Executors ;
4245import java .util .concurrent .TimeUnit ;
46+ import java .util .concurrent .atomic .AtomicInteger ;
4347import java .util .stream .Collectors ;
4448
4549import static java .util .Arrays .asList ;
@@ -53,11 +57,13 @@ class CompletedSnapshotStoreTest {
5357
5458 private ExecutorService executorService ;
5559 private TestCompletedSnapshotHandleStore .Builder builder ;
60+ private TestCompletedSnapshotHandleStore defaultHandleStore ;
5661 private @ TempDir Path tempDir ;
5762
5863 @ BeforeEach
5964 void setup () {
6065 builder = TestCompletedSnapshotHandleStore .newBuilder ();
66+ defaultHandleStore = builder .build ();
6167 executorService = Executors .newFixedThreadPool (2 , new ExecutorThreadFactory ("IO-Executor" ));
6268 }
6369
@@ -171,6 +177,334 @@ void testAddSnapshotFailedShouldNotRemoveOldOnes() {
171177 assertThat (completedSnapshotStore .getLatestSnapshot ().get ().getSnapshotID ()).isOne ();
172178 }
173179
180+ @ Test
181+ void testConcurrentAdds () throws Exception {
182+ final CompletedSnapshotStore completedSnapshotStore =
183+ createCompletedSnapshotStore (10 , defaultHandleStore , Collections .emptyList ());
184+
185+ final int numThreads = 10 ;
186+ final int snapshotsPerThread = 5 ;
187+ final ExecutorService testExecutor =
188+ Executors .newFixedThreadPool (
189+ numThreads , new ExecutorThreadFactory ("concurrent-add-thread" ));
190+
191+ try {
192+ CountDownLatch startLatch = new CountDownLatch (1 );
193+ CountDownLatch completionLatch = new CountDownLatch (numThreads );
194+ AtomicInteger exceptionCount = new AtomicInteger (0 );
195+
196+ // Spin up threads to add snapshots concurrently
197+ for (int threadId = 0 ; threadId < numThreads ; threadId ++) {
198+ final int finalThreadId = threadId ;
199+ testExecutor .submit (
200+ () -> {
201+ try {
202+ startLatch .await ();
203+ for (int i = 0 ; i < snapshotsPerThread ; i ++) {
204+ long snapshotId =
205+ (long ) finalThreadId * snapshotsPerThread + i + 1 ;
206+ CompletedSnapshot snapshot = getSnapshot (snapshotId );
207+ completedSnapshotStore .add (snapshot );
208+ }
209+ } catch (Exception e ) {
210+ exceptionCount .incrementAndGet ();
211+ } finally {
212+ completionLatch .countDown ();
213+ }
214+ });
215+ }
216+
217+ // Start all threads simultaneously
218+ startLatch .countDown ();
219+ boolean completed = completionLatch .await (30 , TimeUnit .SECONDS );
220+ assertThat (completed ).as ("All threads should complete" ).isTrue ();
221+
222+ // Ensure time for async cleanup to finish
223+ Thread .sleep (100 );
224+
225+ assertThat (exceptionCount .get ()).as ("No exceptions should occur" ).isEqualTo (0 );
226+
227+ List <CompletedSnapshot > allSnapshots = completedSnapshotStore .getAllSnapshots ();
228+ assertThat (allSnapshots .size ())
229+ .as ("Should retain at most maxNumberOfSnapshotsToRetain snapshots" )
230+ .isLessThanOrEqualTo (10 );
231+
232+ Set <Long > snapshotIds = new HashSet <>();
233+ for (CompletedSnapshot snapshot : allSnapshots ) {
234+ assertThat (snapshotIds .add (snapshot .getSnapshotID ()))
235+ .as ("Snapshot IDs should be unique (no corruption)" )
236+ .isTrue ();
237+ }
238+
239+ long numSnapshots = completedSnapshotStore .getNumSnapshots ();
240+ assertThat (numSnapshots )
241+ .as ("getNumSnapshots() should match getAllSnapshots().size()" )
242+ .isEqualTo (allSnapshots .size ());
243+
244+ if (!allSnapshots .isEmpty ()) {
245+ Optional <CompletedSnapshot > latest = completedSnapshotStore .getLatestSnapshot ();
246+ assertThat (latest ).as ("Latest snapshot should be present" ).isPresent ();
247+ assertThat (latest .get ())
248+ .as ("Latest snapshot should match last in getAllSnapshots()" )
249+ .isEqualTo (allSnapshots .get (allSnapshots .size () - 1 ));
250+ }
251+ } finally {
252+ testExecutor .shutdown ();
253+ }
254+ }
255+
256+ @ Test
257+ void testConcurrentReadsAndWrites () throws Exception {
258+ final CompletedSnapshotStore completedSnapshotStore =
259+ createCompletedSnapshotStore (5 , defaultHandleStore , Collections .emptyList ());
260+
261+ final int numWriterThreads = 5 ;
262+ final int numReaderThreads = 3 ;
263+ final int snapshotsPerWriter = 3 ;
264+ final ExecutorService testExecutor =
265+ Executors .newFixedThreadPool (
266+ numWriterThreads + numReaderThreads ,
267+ new ExecutorThreadFactory ("concurrent-read-thread" ));
268+
269+ try {
270+ CountDownLatch startLatch = new CountDownLatch (1 );
271+ CountDownLatch completionLatch =
272+ new CountDownLatch (numWriterThreads + numReaderThreads );
273+ AtomicInteger exceptionCount = new AtomicInteger (0 );
274+
275+ // Spin up snapshot writer threads
276+ for (int threadId = 0 ; threadId < numWriterThreads ; threadId ++) {
277+ final int finalThreadId = threadId ;
278+ testExecutor .submit (
279+ () -> {
280+ try {
281+ startLatch .await ();
282+ for (int i = 0 ; i < snapshotsPerWriter ; i ++) {
283+ long snapshotId =
284+ (long ) finalThreadId * snapshotsPerWriter + i + 1 ;
285+ CompletedSnapshot snapshot = getSnapshot (snapshotId );
286+ completedSnapshotStore .add (snapshot );
287+ }
288+ } catch (Exception e ) {
289+ exceptionCount .incrementAndGet ();
290+ } finally {
291+ completionLatch .countDown ();
292+ }
293+ });
294+ }
295+
296+ // Spin up snapshot reader threads (during writes)
297+ for (int threadId = 0 ; threadId < numReaderThreads ; threadId ++) {
298+ testExecutor .submit (
299+ () -> {
300+ try {
301+ startLatch .await ();
302+ for (int i = 0 ; i < 50 ; i ++) {
303+ // Read operations
304+ completedSnapshotStore .getNumSnapshots ();
305+ completedSnapshotStore .getAllSnapshots ();
306+ completedSnapshotStore .getLatestSnapshot ();
307+ // Introduce tiny wait to intersperse reads/writes
308+ Thread .sleep (2 );
309+ }
310+ } catch (InterruptedException e ) {
311+ Thread .currentThread ().interrupt ();
312+ exceptionCount .incrementAndGet ();
313+ } catch (Exception e ) {
314+ exceptionCount .incrementAndGet ();
315+ } finally {
316+ completionLatch .countDown ();
317+ }
318+ });
319+ }
320+
321+ // Start all threads simultaneously
322+ startLatch .countDown ();
323+ boolean completed = completionLatch .await (30 , TimeUnit .SECONDS );
324+ assertThat (completed ).as ("All threads should complete" ).isTrue ();
325+
326+ // Ensure time for async cleanup to finish
327+ Thread .sleep (100 );
328+
329+ assertThat (exceptionCount .get ()).as ("No exceptions should occur" ).isEqualTo (0 );
330+
331+ long numSnapshots = completedSnapshotStore .getNumSnapshots ();
332+ List <CompletedSnapshot > allSnapshots = completedSnapshotStore .getAllSnapshots ();
333+
334+ assertThat (numSnapshots )
335+ .as ("getNumSnapshots() should match getAllSnapshots().size()" )
336+ .isEqualTo (allSnapshots .size ());
337+
338+ assertThat (numSnapshots )
339+ .as ("Should retain at most maxNumberOfSnapshotsToRetain snapshots" )
340+ .isLessThanOrEqualTo (5 );
341+
342+ if (!allSnapshots .isEmpty ()) {
343+ Set <Long > snapshotIds = new HashSet <>();
344+ for (CompletedSnapshot snapshot : allSnapshots ) {
345+ assertThat (snapshotIds .add (snapshot .getSnapshotID ()))
346+ .as ("Snapshot IDs should be unique (no corruption)" )
347+ .isTrue ();
348+ }
349+ }
350+
351+ if (!allSnapshots .isEmpty ()) {
352+ Optional <CompletedSnapshot > latest = completedSnapshotStore .getLatestSnapshot ();
353+ assertThat (latest ).as ("Latest snapshot should be present" ).isPresent ();
354+ assertThat (latest .get ())
355+ .as ("Latest snapshot should match last in getAllSnapshots()" )
356+ .isEqualTo (allSnapshots .get (allSnapshots .size () - 1 ));
357+ }
358+ } finally {
359+ testExecutor .shutdown ();
360+ }
361+ }
362+
363+ @ Test
364+ void testConcurrentAddsWithSnapshotRetention () throws Exception {
365+ final int maxRetain = 3 ;
366+ final CompletedSnapshotStore completedSnapshotStore =
367+ createCompletedSnapshotStore (
368+ maxRetain , defaultHandleStore , Collections .emptyList ());
369+
370+ final int numThreads = 5 ;
371+ final int snapshotsPerThread = 3 ;
372+ final ExecutorService testExecutor =
373+ Executors .newFixedThreadPool (
374+ numThreads , new ExecutorThreadFactory ("concurrent-add-retention-thread" ));
375+
376+ try {
377+ CountDownLatch startLatch = new CountDownLatch (1 );
378+ CountDownLatch completionLatch = new CountDownLatch (numThreads );
379+ AtomicInteger exceptionCount = new AtomicInteger (0 );
380+
381+ // Spin up threads to add snapshots concurrently
382+ for (int threadId = 0 ; threadId < numThreads ; threadId ++) {
383+ final int finalThreadId = threadId ;
384+ testExecutor .submit (
385+ () -> {
386+ try {
387+ startLatch .await ();
388+ for (int i = 0 ; i < snapshotsPerThread ; i ++) {
389+ long snapshotId =
390+ (long ) finalThreadId * snapshotsPerThread + i + 1 ;
391+ CompletedSnapshot snapshot = getSnapshot (snapshotId );
392+ completedSnapshotStore .add (snapshot );
393+ }
394+ } catch (Exception e ) {
395+ exceptionCount .incrementAndGet ();
396+ } finally {
397+ completionLatch .countDown ();
398+ }
399+ });
400+ }
401+
402+ // Start all threads simultaneously
403+ startLatch .countDown ();
404+ boolean completed = completionLatch .await (30 , TimeUnit .SECONDS );
405+ assertThat (completed ).as ("All threads should complete" ).isTrue ();
406+
407+ // Ensure time for async cleanup to finish
408+ Thread .sleep (100 );
409+
410+ assertThat (exceptionCount .get ()).as ("No exceptions should occur" ).isEqualTo (0 );
411+
412+ List <CompletedSnapshot > allSnapshots = completedSnapshotStore .getAllSnapshots ();
413+
414+ assertThat (allSnapshots .size ())
415+ .as ("Should retain at most maxNumberOfSnapshotsToRetain snapshots" )
416+ .isLessThanOrEqualTo (maxRetain );
417+
418+ Set <Long > snapshotIds = new HashSet <>();
419+ for (CompletedSnapshot snapshot : allSnapshots ) {
420+ assertThat (snapshotIds .add (snapshot .getSnapshotID ()))
421+ .as ("Snapshot IDs should be unique (no corruption)" )
422+ .isTrue ();
423+ }
424+
425+ long numSnapshots = completedSnapshotStore .getNumSnapshots ();
426+ assertThat (numSnapshots )
427+ .as ("getNumSnapshots() should match getAllSnapshots().size()" )
428+ .isEqualTo (allSnapshots .size ());
429+
430+ if (!allSnapshots .isEmpty ()) {
431+ Optional <CompletedSnapshot > latest = completedSnapshotStore .getLatestSnapshot ();
432+ assertThat (latest ).as ("Latest snapshot should be present" ).isPresent ();
433+ assertThat (latest .get ())
434+ .as ("Latest snapshot should match last in getAllSnapshots()" )
435+ .isEqualTo (allSnapshots .get (allSnapshots .size () - 1 ));
436+ }
437+ } finally {
438+ testExecutor .shutdown ();
439+ }
440+ }
441+
442+ @ Test
443+ void testConcurrentGetNumSnapshotsAccuracy () throws Exception {
444+ final CompletedSnapshotStore completedSnapshotStore =
445+ createCompletedSnapshotStore (10 , defaultHandleStore , Collections .emptyList ());
446+
447+ final int numOperations = 30 ;
448+ final ExecutorService testExecutor =
449+ Executors .newFixedThreadPool (
450+ 10 , new ExecutorThreadFactory ("concurrent-read-thread" ));
451+
452+ try {
453+ CountDownLatch startLatch = new CountDownLatch (1 );
454+ CountDownLatch completionLatch = new CountDownLatch (numOperations );
455+ AtomicInteger exceptionCount = new AtomicInteger (0 );
456+
457+ // Spin up various different snapshot operations
458+ for (int i = 0 ; i < numOperations ; i ++) {
459+ final int operationId = i ;
460+ testExecutor .submit (
461+ () -> {
462+ try {
463+ startLatch .await ();
464+ if (operationId % 2 == 0 ) {
465+ // Add snapshot
466+ CompletedSnapshot snapshot = getSnapshot (operationId + 1 );
467+ completedSnapshotStore .add (snapshot );
468+ } else {
469+ // Read reapshot
470+ long numSnapshots = completedSnapshotStore .getNumSnapshots ();
471+ List <CompletedSnapshot > allSnapshots =
472+ completedSnapshotStore .getAllSnapshots ();
473+ assertThat (numSnapshots )
474+ .as (
475+ "getNumSnapshots() should match getAllSnapshots().size()" )
476+ .isEqualTo (allSnapshots .size ());
477+ }
478+ } catch (AssertionError e ) {
479+ throw e ;
480+ } catch (Exception e ) {
481+ exceptionCount .incrementAndGet ();
482+ } finally {
483+ completionLatch .countDown ();
484+ }
485+ });
486+ }
487+
488+ // Start all operations simultaneously
489+ startLatch .countDown ();
490+ boolean completed = completionLatch .await (30 , TimeUnit .SECONDS );
491+ assertThat (completed ).as ("All operations should complete" ).isTrue ();
492+
493+ // Ensure time for async cleanup to finish
494+ Thread .sleep (100 );
495+
496+ assertThat (exceptionCount .get ()).as ("No exceptions should occur" ).isEqualTo (0 );
497+
498+ long numSnapshots = completedSnapshotStore .getNumSnapshots ();
499+ List <CompletedSnapshot > allSnapshots = completedSnapshotStore .getAllSnapshots ();
500+ assertThat (numSnapshots )
501+ .as ("Final getNumSnapshots() should match getAllSnapshots().size()" )
502+ .isEqualTo (allSnapshots .size ());
503+ } finally {
504+ testExecutor .shutdown ();
505+ }
506+ }
507+
174508 private List <CompletedSnapshot > mapToCompletedSnapshot (
175509 List <Tuple2 <CompletedSnapshotHandle , String >> snapshotHandles ) {
176510 return snapshotHandles .stream ()
0 commit comments