Skip to content

Commit b53e163

Browse files
authored
Merge pull request #6 from scalableminds/check-sharding-bounds
Check sharding bounds
2 parents bdb1bb8 + 7090614 commit b53e163

File tree

3 files changed

+96
-6
lines changed

3 files changed

+96
-6
lines changed

src/main/java/dev/zarr/zarrjava/v3/ArrayMetadata.java

+36
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,12 @@
1010
import dev.zarr.zarrjava.v3.chunkgrid.RegularChunkGrid;
1111
import dev.zarr.zarrjava.v3.chunkkeyencoding.ChunkKeyEncoding;
1212
import dev.zarr.zarrjava.v3.codec.Codec;
13+
import dev.zarr.zarrjava.v3.codec.core.ShardingIndexedCodec;
14+
1315
import java.nio.ByteBuffer;
1416
import java.util.Arrays;
1517
import java.util.Map;
18+
import java.util.Optional;
1619
import javax.annotation.Nonnull;
1720
import javax.annotation.Nullable;
1821

@@ -91,6 +94,35 @@ public ArrayMetadata(
9194
"Expected node type '" + this.nodeType + "', got '" + nodeType + "'.");
9295
}
9396

97+
if (chunkGrid instanceof RegularChunkGrid) {
98+
int[] chunkShape = ((RegularChunkGrid) chunkGrid).configuration.chunkShape;
99+
if (shape.length != chunkShape.length) {
100+
throw new ZarrException("Shape (ndim=" + shape.length + ") and chunk shape (ndim=" +
101+
chunkShape.length + ") need to have the same number of dimensions.");
102+
}
103+
for (int i = 0; i < shape.length; i++) {
104+
if (shape[i] < chunkShape[i]) {
105+
throw new ZarrException("Shape " + Arrays.toString(shape) + " can not contain chunk shape "
106+
+ Arrays.toString(chunkShape));
107+
}
108+
}
109+
110+
Optional<Codec> shardingCodec = getShardingIndexedCodec(codecs);
111+
int[] outerChunkShape = chunkShape;
112+
while (shardingCodec.isPresent()) {
113+
ShardingIndexedCodec.Configuration shardingConfig = ((ShardingIndexedCodec) shardingCodec.get()).configuration;
114+
int[] innerChunkShape = shardingConfig.chunkShape;
115+
if (outerChunkShape.length != innerChunkShape.length)
116+
throw new ZarrException("Sharding dimensions mismatch of outer chunk shape " + Arrays.toString(outerChunkShape) + " and inner chunk shape" + Arrays.toString(innerChunkShape));
117+
for (int i = 0; i < outerChunkShape.length; i++) {
118+
if (outerChunkShape[i] < innerChunkShape[i])
119+
throw new ZarrException("Sharding outer chunk shape " + Arrays.toString(outerChunkShape) + " can not contain inner chunk shape " + Arrays.toString(innerChunkShape));
120+
}
121+
outerChunkShape = innerChunkShape;
122+
shardingCodec = getShardingIndexedCodec(shardingConfig.codecs);
123+
}
124+
}
125+
94126
this.shape = shape;
95127
this.dataType = dataType;
96128
this.chunkGrid = chunkGrid;
@@ -227,6 +259,10 @@ public int ndim() {
227259
return shape.length;
228260
}
229261

262+
public static Optional<Codec> getShardingIndexedCodec(Codec[] codecs) {
263+
return Arrays.stream(codecs).filter(codec -> codec instanceof ShardingIndexedCodec).findFirst();
264+
}
265+
230266
public int[] chunkShape() {
231267
return ((RegularChunkGrid) this.chunkGrid).configuration.chunkShape;
232268
}

src/main/java/dev/zarr/zarrjava/v3/ArrayMetadataBuilder.java

+5-6
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,9 @@
1111
import dev.zarr.zarrjava.v3.codec.CodecBuilder;
1212
import dev.zarr.zarrjava.v3.codec.core.BytesCodec;
1313
import dev.zarr.zarrjava.v3.codec.core.BytesCodec.Endian;
14+
import dev.zarr.zarrjava.v3.codec.core.ShardingIndexedCodec;
15+
16+
import java.util.Arrays;
1417
import java.util.HashMap;
1518
import java.util.Map;
1619
import java.util.function.Function;
@@ -133,12 +136,8 @@ public ArrayMetadata build() throws ZarrException {
133136
if (chunkGrid == null) {
134137
throw new ZarrException("Chunk grid needs to be provided. Please call `.withChunkShape`.");
135138
}
136-
if (chunkGrid instanceof RegularChunkGrid
137-
&& shape.length != ((RegularChunkGrid) chunkGrid).configuration.chunkShape.length) {
138-
throw new ZarrException("Shape (ndim=" + shape.length + ") and chunk shape (ndim=" +
139-
((RegularChunkGrid) chunkGrid).configuration.chunkShape.length +
140-
") need to have the same number of dimensions.");
141-
}
139+
140+
142141
return new ArrayMetadata(shape, dataType, chunkGrid, chunkKeyEncoding, fillValue, codecs,
143142
dimensionNames,
144143
attributes

src/test/java/dev/zarr/zarrjava/ZarrTest.java

+55
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import org.junit.jupiter.api.Test;
2020
import org.junit.jupiter.params.ParameterizedTest;
2121
import org.junit.jupiter.params.provider.CsvSource;
22+
import org.junit.jupiter.params.provider.MethodSource;
2223
import org.junit.jupiter.params.provider.ValueSource;
2324
import ucar.ma2.MAMath;
2425

@@ -230,6 +231,60 @@ public void testWriteReadWithZarrita(String codec, String codecParam) throws Exc
230231
assert exitCode == 0;
231232
}
232233

234+
static Stream<int[]> invalidchunkSizes() {
235+
return Stream.of(
236+
new int[] {1} ,
237+
new int[] {1, 1, 1},
238+
new int[] {5, 1},
239+
new int[] {1, 5}
240+
);
241+
}
242+
243+
@ParameterizedTest
244+
@MethodSource("invalidchunkSizes")
245+
public void testCheckInvalidChunkBounds(int[] chunkSize) throws Exception {
246+
long[] shape = new long[] {4, 4};
247+
248+
StoreHandle storeHandle = new FilesystemStore(TESTOUTPUT).resolve("invalid_chunksize");
249+
ArrayMetadataBuilder builder = Array.metadataBuilder()
250+
.withShape(shape)
251+
.withDataType(DataType.UINT32)
252+
.withChunkShape(chunkSize);
253+
254+
assertThrows(ZarrException.class, builder::build);
255+
}
256+
257+
@ParameterizedTest
258+
@ValueSource(strings = {"large", "small", "nested", "wrong dims", "correct"})
259+
public void testCheckShardingBounds(String scenario) throws Exception {
260+
long[] shape = new long[] {4, 4};
261+
int[] shardSize = new int[] {2, 2};
262+
int[] chunkSize = new int[] {2, 2};
263+
264+
if (scenario.equals("large"))
265+
shardSize = new int[] {8, 8};
266+
if (scenario.equals("small"))
267+
shardSize = new int[] {1, 1};
268+
if (scenario.equals("wrong dims"))
269+
shardSize = new int[] {1};
270+
StoreHandle storeHandle = new FilesystemStore(TESTOUTPUT).resolve("illegal_shardsize");
271+
ArrayMetadataBuilder builder = Array.metadataBuilder()
272+
.withShape(shape)
273+
.withDataType(DataType.UINT32).withChunkShape(shardSize);
274+
275+
if (scenario.equals("nested")) {
276+
int[] nestedChunkSize = new int[]{4, 4};
277+
builder = builder.withCodecs(c -> c.withSharding(chunkSize, c1 -> c1.withSharding(nestedChunkSize, c2 -> c2.withBytes("LITTLE"))));
278+
} else {
279+
builder = builder.withCodecs(c -> c.withSharding(chunkSize, c1 -> c1.withBytes("LITTLE")));
280+
}
281+
if (scenario.equals("correct")){
282+
builder.build();
283+
}else{
284+
assertThrows(ZarrException.class, builder::build);
285+
}
286+
}
287+
233288
@ParameterizedTest
234289
@CsvSource({"0,true", "0,false", "5, true", "5, false"})
235290
public void testZstdCodecReadWrite(int clevel, boolean checksum) throws ZarrException, IOException {

0 commit comments

Comments
 (0)