diff --git a/src/main/java/dev/zarr/zarrjava/v3/Array.java b/src/main/java/dev/zarr/zarrjava/v3/Array.java index 7fbd17d..0d240be 100644 --- a/src/main/java/dev/zarr/zarrjava/v3/Array.java +++ b/src/main/java/dev/zarr/zarrjava/v3/Array.java @@ -14,6 +14,7 @@ import java.util.Map; import java.util.function.Function; import java.util.stream.Collectors; +import java.util.stream.Stream; import javax.annotation.Nonnull; import javax.annotation.Nullable; import ucar.ma2.InvalidRangeException; @@ -119,6 +120,7 @@ public static ArrayMetadataBuilder metadataBuilder(ArrayMetadata existingMetadat /** * Reads the entire Zarr array into an ucar.ma2.Array. + * Utilizes no parallelism. * * @throws ZarrException */ @@ -129,6 +131,7 @@ public ucar.ma2.Array read() throws ZarrException { /** * Reads a part of the Zarr array based on a requested offset and shape into an ucar.ma2.Array. + * Utilizes no parallelism. * * @param offset * @param shape @@ -136,6 +139,30 @@ public ucar.ma2.Array read() throws ZarrException { */ @Nonnull public ucar.ma2.Array read(final long[] offset, final int[] shape) throws ZarrException { + return read(offset, shape, false); + } + + /** + * Reads the entire Zarr array into an ucar.ma2.Array. + * + * @param parallel if true, the chunks are read using a parallel stream + * @throws ZarrException + */ + @Nonnull + public ucar.ma2.Array read(final boolean parallel) throws ZarrException { + return read(new long[metadata.ndim()], Utils.toIntArray(metadata.shape), parallel); + } + + /** + * Reads a part of the Zarr array based on a requested offset and shape into an ucar.ma2.Array. 
+ * + * @param offset + * @param shape + * @param parallel if true, the chunks are read using a parallel stream + * @throws ZarrException + */ + @Nonnull + public ucar.ma2.Array read(final long[] offset, final int[] shape, final boolean parallel) throws ZarrException { if (offset.length != metadata.ndim()) { throw new IllegalArgumentException("'offset' needs to have rank '" + metadata.ndim() + "'."); } @@ -155,43 +182,46 @@ public ucar.ma2.Array read(final long[] offset, final int[] shape) throws ZarrEx final ucar.ma2.Array outputArray = ucar.ma2.Array.factory(metadata.dataType.getMA2DataType(), shape); - Arrays.stream(IndexingUtils.computeChunkCoords(metadata.shape, chunkShape, offset, shape)) - .forEach( - chunkCoords -> { - try { - final IndexingUtils.ChunkProjection chunkProjection = - IndexingUtils.computeProjection(chunkCoords, metadata.shape, chunkShape, offset, - shape - ); - - if (chunkIsInArray(chunkCoords)) { - MultiArrayUtils.copyRegion(metadata.allocateFillValueChunk(), - chunkProjection.chunkOffset, outputArray, chunkProjection.outOffset, - chunkProjection.shape - ); - } - - final String[] chunkKeys = metadata.chunkKeyEncoding.encodeChunkKey(chunkCoords); - final StoreHandle chunkHandle = storeHandle.resolve(chunkKeys); - if (!chunkHandle.exists()) { - return; - } - if (codecPipeline.supportsPartialDecode()) { - final ucar.ma2.Array chunkArray = codecPipeline.decodePartial(chunkHandle, - Utils.toLongArray(chunkProjection.chunkOffset), chunkProjection.shape); - MultiArrayUtils.copyRegion(chunkArray, new int[metadata.ndim()], outputArray, - chunkProjection.outOffset, chunkProjection.shape - ); - } else { - MultiArrayUtils.copyRegion(readChunk(chunkCoords), chunkProjection.chunkOffset, - outputArray, chunkProjection.outOffset, chunkProjection.shape - ); - } - - } catch (ZarrException e) { - throw new RuntimeException(e); - } - }); + Stream<long[]> chunkStream = Arrays.stream(IndexingUtils.computeChunkCoords(metadata.shape, chunkShape, offset, shape)); + if (parallel) { + chunkStream = chunkStream.parallel(); + } + 
chunkStream.forEach( + chunkCoords -> { + try { + final IndexingUtils.ChunkProjection chunkProjection = + IndexingUtils.computeProjection(chunkCoords, metadata.shape, chunkShape, offset, + shape + ); + + if (chunkIsInArray(chunkCoords)) { + MultiArrayUtils.copyRegion(metadata.allocateFillValueChunk(), + chunkProjection.chunkOffset, outputArray, chunkProjection.outOffset, + chunkProjection.shape + ); + } + + final String[] chunkKeys = metadata.chunkKeyEncoding.encodeChunkKey(chunkCoords); + final StoreHandle chunkHandle = storeHandle.resolve(chunkKeys); + if (!chunkHandle.exists()) { + return; + } + if (codecPipeline.supportsPartialDecode()) { + final ucar.ma2.Array chunkArray = codecPipeline.decodePartial(chunkHandle, + Utils.toLongArray(chunkProjection.chunkOffset), chunkProjection.shape); + MultiArrayUtils.copyRegion(chunkArray, new int[metadata.ndim()], outputArray, + chunkProjection.outOffset, chunkProjection.shape + ); + } else { + MultiArrayUtils.copyRegion(readChunk(chunkCoords), chunkProjection.chunkOffset, + outputArray, chunkProjection.outOffset, chunkProjection.shape + ); + } + + } catch (ZarrException e) { + throw new RuntimeException(e); + } + }); return outputArray; } @@ -235,6 +265,7 @@ public ucar.ma2.Array readChunk(long[] chunkCoords) /** * Writes a ucar.ma2.Array into the Zarr array at the beginning of the Zarr array. The shape of * the Zarr array needs be large enough for the write. + * Utilizes no parallelism. * * @param array */ @@ -245,11 +276,37 @@ public void write(ucar.ma2.Array array) { /** * Writes a ucar.ma2.Array into the Zarr array at a specified offset. The shape of the Zarr array * needs be large enough for the write. + * Utilizes no parallelism. * * @param offset * @param array */ public void write(long[] offset, ucar.ma2.Array array) { + write(offset, array, false); + } + + /** + * Writes a ucar.ma2.Array into the Zarr array at the beginning of the Zarr array. The shape of + * the Zarr array needs be large enough for the write. 
+ * + * @param array + * @param parallel if true, the chunks are written using a parallel stream + */ + public void write(ucar.ma2.Array array, boolean parallel) { + write(new long[metadata.ndim()], array, parallel); + } + + + + /** + * Writes a ucar.ma2.Array into the Zarr array at a specified offset. The shape of the Zarr array + * needs be large enough for the write. + * + * @param offset + * @param array + * @param parallel if true, the chunks are written using a parallel stream + */ + public void write(long[] offset, ucar.ma2.Array array, boolean parallel) { if (offset.length != metadata.ndim()) { throw new IllegalArgumentException("'offset' needs to have rank '" + metadata.ndim() + "'."); } @@ -260,34 +317,37 @@ public void write(long[] offset, ucar.ma2.Array array) { int[] shape = array.getShape(); final int[] chunkShape = metadata.chunkShape(); - Arrays.stream(IndexingUtils.computeChunkCoords(metadata.shape, chunkShape, offset, shape)) - .forEach( - chunkCoords -> { - try { - final IndexingUtils.ChunkProjection chunkProjection = - IndexingUtils.computeProjection(chunkCoords, metadata.shape, chunkShape, offset, - shape - ); - - ucar.ma2.Array chunkArray; - if (IndexingUtils.isFullChunk(chunkProjection.chunkOffset, chunkProjection.shape, - chunkShape - )) { - chunkArray = array.sectionNoReduce(chunkProjection.outOffset, - chunkProjection.shape, - null - ); - } else { - chunkArray = readChunk(chunkCoords); - MultiArrayUtils.copyRegion(array, chunkProjection.outOffset, chunkArray, - chunkProjection.chunkOffset, chunkProjection.shape - ); - } - writeChunk(chunkCoords, chunkArray); - } catch (ZarrException | InvalidRangeException e) { - throw new RuntimeException(e); - } - }); + Stream<long[]> chunkStream = Arrays.stream(IndexingUtils.computeChunkCoords(metadata.shape, chunkShape, offset, shape)); + if (parallel) { + chunkStream = chunkStream.parallel(); + } + chunkStream.forEach( + chunkCoords -> { + try { + final IndexingUtils.ChunkProjection chunkProjection = + IndexingUtils.computeProjection(chunkCoords, metadata.shape, chunkShape, offset, + shape + ); + + ucar.ma2.Array 
chunkArray; + if (IndexingUtils.isFullChunk(chunkProjection.chunkOffset, chunkProjection.shape, + chunkShape + )) { + chunkArray = array.sectionNoReduce(chunkProjection.outOffset, + chunkProjection.shape, + null + ); + } else { + chunkArray = readChunk(chunkCoords); + MultiArrayUtils.copyRegion(array, chunkProjection.outOffset, chunkArray, + chunkProjection.chunkOffset, chunkProjection.shape + ); + } + writeChunk(chunkCoords, chunkArray); + } catch (ZarrException | InvalidRangeException e) { + throw new RuntimeException(e); + } + }); } /** @@ -434,6 +494,5 @@ public void write(@Nonnull ucar.ma2.Array content) throws ZarrException { } array.write(offset, content); } - } } diff --git a/src/test/java/dev/zarr/zarrjava/ZarrTest.java b/src/test/java/dev/zarr/zarrjava/ZarrTest.java index 7d980ab..8b0b982 100644 --- a/src/test/java/dev/zarr/zarrjava/ZarrTest.java +++ b/src/test/java/dev/zarr/zarrjava/ZarrTest.java @@ -210,7 +210,7 @@ public void testWriteReadWithZarrita(String codec, String codecParam) throws Exc Assertions.assertArrayEquals(new int[]{2, 4, 8}, readArray.metadata.chunkShape()); Assertions.assertEquals("test_value", readArray.metadata.attributes.get("test_key")); - Assertions.assertArrayEquals(testData, (int[]) result.get1DJavaArray(ucar.ma2.DataType.INT)); + Assertions.assertArrayEquals(testData, (int[]) result.get1DJavaArray(ucar.ma2.DataType.UINT)); //read in zarrita String command = pythonPath(); @@ -274,7 +274,7 @@ public void testLargerChunkSizeThanArraySize() throws ZarrException, IOException Array readArray = Array.open(storeHandle); ucar.ma2.Array result = readArray.read(); - Assertions.assertArrayEquals(testData, (int[]) result.get1DJavaArray(ucar.ma2.DataType.INT)); + Assertions.assertArrayEquals(testData, (int[]) result.get1DJavaArray(ucar.ma2.DataType.UINT)); } static Stream invalidChunkSizes() { @@ -346,7 +346,7 @@ public void testZstdCodecReadWrite(int clevel, boolean checksum) throws ZarrExce Array readArray = Array.open(storeHandle); 
ucar.ma2.Array result = readArray.read(); - Assertions.assertArrayEquals(testData, (int[]) result.get1DJavaArray(ucar.ma2.DataType.INT)); + Assertions.assertArrayEquals(testData, (int[]) result.get1DJavaArray(ucar.ma2.DataType.UINT)); } @Test @@ -631,4 +631,32 @@ public void testReadL4Sample(String mag) throws IOException, ZarrException { assert MultiArrayUtils.allValuesEqual(httpData2, localData2); } + + @ParameterizedTest + @ValueSource(booleans = {false,true}) + public void testParallel(boolean useParallel) throws IOException, ZarrException { + int[] testData = new int[512 * 512 * 512]; + Arrays.setAll(testData, p -> p); + + StoreHandle storeHandle = new FilesystemStore(TESTOUTPUT).resolve("testParallelRead"); + ArrayMetadata metadata = Array.metadataBuilder() + .withShape(512, 512, 512) + .withDataType(DataType.UINT32) + .withChunkShape(100, 100, 100) + .withFillValue(0) + .build(); + // Clean up in finally so a failed round trip does not leave the 512^3-element store behind in TESTOUTPUT. + try { + Array writeArray = Array.create(storeHandle, metadata); + writeArray.write(ucar.ma2.Array.factory(ucar.ma2.DataType.UINT, new int[]{512, 512, 512}, testData), useParallel); + + Array readArray = Array.open(storeHandle); + ucar.ma2.Array result = readArray.read(useParallel); + + Assertions.assertArrayEquals(testData, (int[]) result.get1DJavaArray(ucar.ma2.DataType.UINT)); + } finally { + clearTestoutputFolder(); + } + } } +