Skip to content

Commit 6e51602

Browse files
authored
Merge pull request #19 from asdf-format/eslavich/float16-data-type
Support float16 ndarray data type
2 parents 27373a1 + f71f169 commit 6e51602

14 files changed

Lines changed: 508 additions & 1 deletion

File tree

.gitignore

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,3 +26,8 @@ replay_pid*
2626
build
2727

2828
.flattened-pom.xml
29+
30+
.classpath
31+
.factorypath
32+
.project
33+
.settings

asdf-core/src/main/java/org/asdfformat/asdf/io/impl/InlineBlockV1_0_0.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import org.asdfformat.asdf.ndarray.DataType;
55
import org.asdfformat.asdf.ndarray.DataTypeFamilyType;
66
import org.asdfformat.asdf.ndarray.DataTypes;
7+
import org.asdfformat.asdf.ndarray.impl.Float16Utils;
78
import org.asdfformat.asdf.node.AsdfNode;
89
import org.asdfformat.asdf.util.AsdfCharsets;
910

@@ -25,6 +26,7 @@ public class InlineBlockV1_0_0 implements Block {
2526
SIMPLE_VALUE_WRITERS.put(DataTypes.INT32, (b, n) -> b.putInt(n.asInt()));
2627
SIMPLE_VALUE_WRITERS.put(DataTypes.UINT64, (b, n) -> b.putLong(n.asLong()));
2728
SIMPLE_VALUE_WRITERS.put(DataTypes.INT64, (b, n) -> b.putLong(n.asLong()));
29+
SIMPLE_VALUE_WRITERS.put(DataTypes.FLOAT16, (b, n) -> b.putShort(Float16Utils.floatToFloat16(n.asFloat())));
2830
SIMPLE_VALUE_WRITERS.put(DataTypes.FLOAT32, (b, n) -> b.putFloat(n.asFloat()));
2931
SIMPLE_VALUE_WRITERS.put(DataTypes.FLOAT64, (b, n) -> b.putDouble(n.asDouble()));
3032
SIMPLE_VALUE_WRITERS.put(DataTypes.BOOL8, (b, n) -> b.put((byte)(n.asBoolean() ? 1 : 0)));

asdf-core/src/main/java/org/asdfformat/asdf/ndarray/DataTypes.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,12 @@ public class DataTypes {
5858
new HashSet<>(Arrays.asList(Long.TYPE, BigInteger.class))
5959
);
6060

61+
public static final DataType FLOAT16 = new SimpleDataTypeImpl(
62+
DataTypeFamilyType.FLOAT,
63+
2,
64+
new HashSet<>(Arrays.asList(Float.TYPE, Double.TYPE, BigDecimal.class))
65+
);
66+
6167
public static final DataType FLOAT32 = new SimpleDataTypeImpl(
6268
DataTypeFamilyType.FLOAT,
6369
4,

asdf-core/src/main/java/org/asdfformat/asdf/ndarray/impl/BigDecimalNdArrayImpl.java

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
import java.util.function.BiConsumer;
1313
import java.util.function.Function;
1414

15+
import static org.asdfformat.asdf.ndarray.impl.Float16Utils.float16ToFloat;
16+
1517
public class BigDecimalNdArrayImpl extends NdArrayBase<BigDecimalNdArray> implements BigDecimalNdArray {
1618
public BigDecimalNdArrayImpl(final DataType dataType, final int[] shape, final ByteOrder byteOrder, final int[] strides, final int offset, final Block block) {
1719
super(dataType, shape, byteOrder, strides, offset, block);
@@ -46,6 +48,8 @@ public BigDecimal get(final int... indices) {
4648
} else {
4749
throw new RuntimeException("Unhandled datatype: " + dataType);
4850
}
51+
} else if (dataType.equals(DataTypes.FLOAT16)) {
52+
return BigDecimal.valueOf(float16ToFloat(byteBuffer.getShort()));
4953
} else if (dataType.equals(DataTypes.FLOAT32)) {
5054
return BigDecimal.valueOf(byteBuffer.getFloat());
5155
} else if (dataType.equals(DataTypes.FLOAT64)) {
@@ -84,6 +88,12 @@ public <ARRAY> ARRAY toArray(final ARRAY array) {
8488
arr[index + i] = valueCreator.apply(buffer);
8589
}
8690
};
91+
} else if (dataType.equals(DataTypes.FLOAT16)) {
92+
setter = (byteBuffer, arr, index, length) -> {
93+
for (int i = 0; i < length; i++) {
94+
arr[index + i] = BigDecimal.valueOf(float16ToFloat(byteBuffer.getShort()));
95+
}
96+
};
8797
} else if (dataType.equals(DataTypes.FLOAT32)) {
8898
setter = (byteBuffer, arr, index, length) -> {
8999
for (int i = 0; i < length; i++) {

asdf-core/src/main/java/org/asdfformat/asdf/ndarray/impl/DoubleNdArrayImpl.java

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
import java.nio.ByteBuffer;
99
import java.nio.ByteOrder;
1010

11+
import static org.asdfformat.asdf.ndarray.impl.Float16Utils.float16ToFloat;
12+
1113
public class DoubleNdArrayImpl extends NdArrayBase<DoubleNdArray> implements DoubleNdArray {
1214
public DoubleNdArrayImpl(final DataType dataType, final int[] shape, final ByteOrder byteOrder, final int[] strides, final int offset, final Block block) {
1315
super(dataType, shape, byteOrder, strides, offset, block);
@@ -34,6 +36,8 @@ public double get(int... indices) {
3436
return byteBuffer.getDouble();
3537
} else if (dataType.equals(DataTypes.FLOAT32)) {
3638
return byteBuffer.getFloat();
39+
} else if (dataType.equals(DataTypes.FLOAT16)) {
40+
return float16ToFloat(byteBuffer.getShort());
3741
} else {
3842
throw new RuntimeException("Unhandled datatype: " + dataType);
3943
}
@@ -52,6 +56,12 @@ public <ARRAY> ARRAY toArray(final ARRAY array) {
5256
arr[index + i] = floatArr[i];
5357
}
5458
};
59+
} else if (dataType.equals(DataTypes.FLOAT16)) {
60+
setter = (byteBuffer, arr, index, length) -> {
61+
for (int i = 0; i < length; i++) {
62+
arr[index + i] = float16ToFloat(byteBuffer.getShort());
63+
}
64+
};
5565
} else {
5666
throw new RuntimeException("Unhandled datatype: " + dataType);
5767
}
Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
package org.asdfformat.asdf.ndarray.impl;
2+
3+
public class Float16Utils {
4+
public static float float16ToFloat(final short bits) {
5+
final int halfBits = bits & 0xFFFF;
6+
final int sign = (halfBits >>> 15) & 0x1;
7+
final int exponent = (halfBits >>> 10) & 0x1F;
8+
final int mantissa = halfBits & 0x3FF;
9+
10+
final int floatBits;
11+
if (exponent == 0) {
12+
if (mantissa == 0) {
13+
floatBits = sign << 31;
14+
} else {
15+
// Subnormal: normalize by shifting mantissa until the leading 1 is in bit 10
16+
int m = mantissa;
17+
int e = -14 + 127;
18+
while ((m & 0x400) == 0) {
19+
m <<= 1;
20+
e--;
21+
}
22+
m &= 0x3FF;
23+
floatBits = (sign << 31) | (e << 23) | (m << 13);
24+
}
25+
} else if (exponent == 31) {
26+
// Inf or NaN: rebased exponent to float32's 255
27+
floatBits = (sign << 31) | (0xFF << 23) | (mantissa << 13);
28+
} else {
29+
// Normal: rebase exponent from bias-15 to bias-127
30+
final int floatExponent = exponent - 15 + 127;
31+
floatBits = (sign << 31) | (floatExponent << 23) | (mantissa << 13);
32+
}
33+
34+
return Float.intBitsToFloat(floatBits);
35+
}
36+
37+
public static short floatToFloat16(final float value) {
38+
final int floatBits = Float.floatToIntBits(value);
39+
final int sign = (floatBits >>> 31) & 0x1;
40+
final int exponent = (floatBits >>> 23) & 0xFF;
41+
final int mantissa = floatBits & 0x7FFFFF;
42+
43+
final int halfBits;
44+
if (exponent == 0) {
45+
halfBits = sign << 15;
46+
} else if (exponent == 0xFF) {
47+
if (mantissa == 0) {
48+
halfBits = (sign << 15) | (0x1F << 10);
49+
} else {
50+
final int halfMantissa = mantissa >>> 13;
51+
halfBits = (sign << 15) | (0x1F << 10) | (halfMantissa != 0 ? halfMantissa : 0x1);
52+
}
53+
} else {
54+
final int halfExponent = exponent - 127 + 15;
55+
if (halfExponent >= 31) {
56+
halfBits = (sign << 15) | (0x1F << 10);
57+
} else if (halfExponent <= 0) {
58+
if (halfExponent < -10) {
59+
halfBits = sign << 15;
60+
} else {
61+
final int shift = 1 - halfExponent + 13;
62+
final int m = (mantissa | 0x800000) >>> shift;
63+
final int roundBit = ((mantissa | 0x800000) >>> (shift - 1)) & 0x1;
64+
final int stickyBit = ((mantissa | 0x800000) & ((1 << (shift - 1)) - 1)) != 0 ? 1 : 0;
65+
final int rounded = m + (roundBit & (stickyBit | (m & 1)));
66+
halfBits = (sign << 15) | rounded;
67+
}
68+
} else {
69+
final int truncated = (mantissa >>> 13);
70+
final int roundBit = (mantissa >>> 12) & 0x1;
71+
final int stickyBit = (mantissa & 0xFFF) != 0 ? 1 : 0;
72+
final int rounded = truncated + (roundBit & (stickyBit | (truncated & 1)));
73+
halfBits = (sign << 15) | ((halfExponent << 10) + rounded);
74+
}
75+
}
76+
77+
return (short) halfBits;
78+
}
79+
80+
private Float16Utils() {}
81+
}

asdf-core/src/main/java/org/asdfformat/asdf/ndarray/impl/FloatNdArrayImpl.java

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
import java.nio.ByteBuffer;
99
import java.nio.ByteOrder;
1010

11+
import static org.asdfformat.asdf.ndarray.impl.Float16Utils.float16ToFloat;
12+
1113
public class FloatNdArrayImpl extends NdArrayBase<FloatNdArray> implements FloatNdArray {
1214
public FloatNdArrayImpl(final DataType dataType, final int[] shape, final ByteOrder byteOrder, final int[] strides, final int offset, final Block block) {
1315
super(dataType, shape, byteOrder, strides, offset, block);
@@ -32,14 +34,27 @@ public float get(int... indices) {
3234
final ByteBuffer byteBuffer = getByteBufferAt(indices);
3335
if (dataType.equals(DataTypes.FLOAT32)) {
3436
return byteBuffer.getFloat();
37+
} else if (dataType.equals(DataTypes.FLOAT16)) {
38+
return float16ToFloat(byteBuffer.getShort());
3539
} else {
3640
throw new RuntimeException("Unhandled datatype: " + dataType);
3741
}
3842
}
3943

4044
@Override
4145
public <ARRAY> ARRAY toArray(final ARRAY array) {
42-
final ArraySetter<float[]> setter = (byteBuffer, arr, index, length) -> byteBuffer.asFloatBuffer().get(arr, index, length);
46+
final ArraySetter<float[]> setter;
47+
if (dataType.equals(DataTypes.FLOAT32)) {
48+
setter = (byteBuffer, arr, index, length) -> byteBuffer.asFloatBuffer().get(arr, index, length);
49+
} else if (dataType.equals(DataTypes.FLOAT16)) {
50+
setter = (byteBuffer, arr, index, length) -> {
51+
for (int i = 0; i < length; i++) {
52+
arr[index + i] = float16ToFloat(byteBuffer.getShort());
53+
}
54+
};
55+
} else {
56+
throw new RuntimeException("Unhandled datatype: " + dataType);
57+
}
4358
return toArray(array, Float.TYPE, setter);
4459
}
4560
}

asdf-core/src/main/java/org/asdfformat/asdf/standard/impl/NdArrayHandler_1_x.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ public class NdArrayHandler_1_x implements NdArrayHandler {
3535
SIMPLE_DATA_TYPES.put("uint8", DataTypes.UINT8);
3636
SIMPLE_DATA_TYPES.put("uint16", DataTypes.UINT16);
3737
SIMPLE_DATA_TYPES.put("uint32", DataTypes.UINT32);
38+
SIMPLE_DATA_TYPES.put("float16", DataTypes.FLOAT16);
3839
SIMPLE_DATA_TYPES.put("float32", DataTypes.FLOAT32);
3940
SIMPLE_DATA_TYPES.put("float64", DataTypes.FLOAT64);
4041
SIMPLE_DATA_TYPES.put("complex64", DataTypes.COMPLEX64);
Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
package org.asdfformat.asdf.ndarray;
2+
3+
import org.asdfformat.asdf.Asdf;
4+
import org.asdfformat.asdf.AsdfFile;
5+
import org.asdfformat.asdf.standard.AsdfStandardType;
6+
import org.asdfformat.asdf.testing.CoreReferenceFileType;
7+
import org.asdfformat.asdf.testing.ReferenceFileUtils;
8+
import org.asdfformat.asdf.util.Version;
9+
import org.junit.jupiter.api.Tag;
10+
import org.junit.jupiter.params.ParameterizedTest;
11+
import org.junit.jupiter.params.provider.Arguments;
12+
import org.junit.jupiter.params.provider.MethodSource;
13+
14+
import java.io.IOException;
15+
import java.math.BigDecimal;
16+
import java.nio.file.Path;
17+
import java.util.Arrays;
18+
import java.util.stream.Stream;
19+
20+
import static org.asdfformat.asdf.testing.TestCategories.REFERENCE_TESTS;
21+
import static org.junit.jupiter.api.Assertions.assertEquals;
22+
import static org.junit.jupiter.api.Assertions.assertTrue;
23+
24+
@Tag(REFERENCE_TESTS)
25+
public class NdArrayFloat16ReferenceTest {
26+
private static final Version FLOAT16_MIN_VERSION = new Version(1, 6, 0);
27+
28+
private static final CoreReferenceFileType[] FILE_TYPES = {
29+
CoreReferenceFileType.NDARRAY_FLOAT16_1D_BLOCK_BIG,
30+
CoreReferenceFileType.NDARRAY_FLOAT16_1D_BLOCK_LITTLE,
31+
CoreReferenceFileType.NDARRAY_FLOAT16_1D_INLINE,
32+
};
33+
34+
private static Stream<Arguments> float16Args() {
35+
return Arrays.stream(FILE_TYPES)
36+
.flatMap(fileType -> Arrays.stream(AsdfStandardType.values())
37+
.filter(std -> std.getVersion().compareTo(FLOAT16_MIN_VERSION) >= 0)
38+
.map(std -> Arguments.of(fileType, std)));
39+
}
40+
41+
@ParameterizedTest
42+
@MethodSource("float16Args")
43+
public void testFloat1d(final CoreReferenceFileType coreTestFileType, final AsdfStandardType asdfStandardType) throws IOException {
44+
final Path path = ReferenceFileUtils.getPath(coreTestFileType, asdfStandardType.getVersion());
45+
46+
try (final AsdfFile asdfFile = Asdf.open(path)) {
47+
final FloatNdArray floatNdArray = asdfFile.getTree().get("arr").asNdArray().asFloatNdArray();
48+
49+
assertEquals(-65504.0f, floatNdArray.get(0));
50+
assertEquals(65504.0f, floatNdArray.get(1));
51+
assertEquals(5.9604645E-8f, floatNdArray.get(2));
52+
assertEquals(0.0f, floatNdArray.get(3));
53+
assertTrue(Float.isNaN(floatNdArray.get(4)));
54+
assertEquals(Float.POSITIVE_INFINITY, floatNdArray.get(5));
55+
assertEquals(Float.NEGATIVE_INFINITY, floatNdArray.get(6));
56+
assertEquals(3.140625f, floatNdArray.get(7));
57+
assertEquals(-3.140625f, floatNdArray.get(8));
58+
59+
final float[] arr = floatNdArray.toArray(new float[9]);
60+
assertEquals(-65504.0f, arr[0]);
61+
assertEquals(65504.0f, arr[1]);
62+
assertEquals(5.9604645E-8f, arr[2]);
63+
assertEquals(0.0f, arr[3]);
64+
assertTrue(Float.isNaN(arr[4]));
65+
assertEquals(Float.POSITIVE_INFINITY, arr[5]);
66+
assertEquals(Float.NEGATIVE_INFINITY, arr[6]);
67+
assertEquals(3.140625f, arr[7]);
68+
assertEquals(-3.140625f, arr[8]);
69+
}
70+
}
71+
72+
@ParameterizedTest
73+
@MethodSource("float16Args")
74+
public void testDouble1d(final CoreReferenceFileType coreTestFileType, final AsdfStandardType asdfStandardType) throws IOException {
75+
final Path path = ReferenceFileUtils.getPath(coreTestFileType, asdfStandardType.getVersion());
76+
77+
try (final AsdfFile asdfFile = Asdf.open(path)) {
78+
final DoubleNdArray doubleNdArray = asdfFile.getTree().get("arr").asNdArray().asDoubleNdArray();
79+
80+
assertEquals(-65504.0, doubleNdArray.get(0));
81+
assertEquals(65504.0, doubleNdArray.get(1));
82+
assertEquals(5.960464477539063E-8, doubleNdArray.get(2));
83+
assertEquals(0.0, doubleNdArray.get(3));
84+
assertTrue(Double.isNaN(doubleNdArray.get(4)));
85+
assertEquals(Double.POSITIVE_INFINITY, doubleNdArray.get(5));
86+
assertEquals(Double.NEGATIVE_INFINITY, doubleNdArray.get(6));
87+
assertEquals(3.140625, doubleNdArray.get(7));
88+
assertEquals(-3.140625, doubleNdArray.get(8));
89+
90+
final double[] arr = doubleNdArray.toArray(new double[9]);
91+
assertEquals(-65504.0, arr[0]);
92+
assertEquals(65504.0, arr[1]);
93+
assertEquals(5.960464477539063E-8, arr[2]);
94+
assertEquals(0.0, arr[3]);
95+
assertTrue(Double.isNaN(arr[4]));
96+
assertEquals(Double.POSITIVE_INFINITY, arr[5]);
97+
assertEquals(Double.NEGATIVE_INFINITY, arr[6]);
98+
assertEquals(3.140625, arr[7]);
99+
assertEquals(-3.140625, arr[8]);
100+
}
101+
}
102+
103+
@ParameterizedTest
104+
@MethodSource("float16Args")
105+
public void testBigDecimal1d(final CoreReferenceFileType coreTestFileType, final AsdfStandardType asdfStandardType) throws IOException {
106+
final Path path = ReferenceFileUtils.getPath(coreTestFileType, asdfStandardType.getVersion());
107+
108+
try (final AsdfFile asdfFile = Asdf.open(path)) {
109+
final BigDecimalNdArray bigDecimalNdArray = asdfFile.getTree().get("arr").asNdArray().asBigDecimalNdArray();
110+
111+
assertEquals(BigDecimal.valueOf(-65504.0), bigDecimalNdArray.get(0));
112+
assertEquals(BigDecimal.valueOf(65504.0), bigDecimalNdArray.get(1));
113+
assertEquals(BigDecimal.valueOf(5.960464477539063E-8), bigDecimalNdArray.get(2));
114+
assertEquals(BigDecimal.valueOf(0.0), bigDecimalNdArray.get(3));
115+
assertEquals(BigDecimal.valueOf(3.140625), bigDecimalNdArray.get(7));
116+
assertEquals(BigDecimal.valueOf(-3.140625), bigDecimalNdArray.get(8));
117+
118+
}
119+
}
120+
}

0 commit comments

Comments
 (0)