Skip to content

Commit 249c263

Browse files
committed
[type-classes] 8377322: Upgrade Float16 to use operators
Reviewed-by: liach
1 parent d4663b4 commit 249c263

File tree

3 files changed

+190
-5
lines changed

3 files changed

+190
-5
lines changed

src/jdk.incubator.vector/share/classes/jdk/incubator/vector/Float16.java

Lines changed: 98 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright (c) 2023, 2025, Oracle and/or its affiliates. All rights reserved.
2+
* Copyright (c) 2023, 2026, Oracle and/or its affiliates. All rights reserved.
33
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
44
*
55
* This code is free software; you can redistribute it and/or modify it
@@ -101,10 +101,106 @@
101101
// expected to be aligned with Value Classes and Object as described in
102102
// JEP-401 (https://openjdk.org/jeps/401).
103103
@jdk.internal.ValueBased
104-
public final class Float16
104+
public final /* value */ class Float16
105105
extends Number
106106
implements Comparable<Float16> {
107107

108+
private static final StandardFloatingPoint<Float16> SFP = new StandardFloatingPoint<Float16>() {
109+
public Float16 add(Float16 addend, Float16 augend) {
110+
return Float16.add(addend, augend);
111+
}
112+
113+
@Override
114+
public Float16 subtract(Float16 minuend, Float16 subtrahend) {
115+
return Float16.subtract(minuend, subtrahend);
116+
}
117+
118+
public Float16 multiply(Float16 multiplier, Float16 multiplicand) {
119+
return Float16.multiply(multiplier, multiplicand);
120+
}
121+
122+
public Float16 remainder(Float16 dividend, Float16 divisor) {
123+
throw new UnsupportedOperationException("tbd");
124+
}
125+
126+
public Float16 negate(Float16 operand) {
127+
return Float16.negate(operand);
128+
}
129+
130+
public Float16 divide(Float16 dividend, Float16 divisor) {
131+
return Float16.divide(dividend, divisor);
132+
}
133+
134+
@Override
135+
public boolean equalsStd(Float16 op1, Float16 op2) {
136+
return op1.floatValue() == op2.floatValue();
137+
}
138+
139+
@Override
140+
public boolean lessThan(Float16 op1, Float16 op2) {
141+
return op1.floatValue() < op2.floatValue();
142+
}
143+
144+
// If the following three methods are commented out, the default
145+
// implementations in StandardFloatingPoint will be used.
146+
// @Override
147+
// public boolean lessThanEqual(Float16 op1, Float16 op2) {
148+
// return op1.floatValue() <= op2.floatValue();
149+
// }
150+
151+
// @Override
152+
// public boolean greaterThan(Float16 op1, Float16 op2) {
153+
// return op1.floatValue() > op2.floatValue();
154+
// }
155+
156+
// @Override
157+
// public boolean greaterThanEqual(Float16 op1, Float16 op2) {
158+
// return op1.floatValue() >= op2.floatValue();
159+
// }
160+
161+
public Float16 min(Float16 op1, Float16 op2) {
162+
return Float16.min(op1, op2);
163+
}
164+
165+
public Float16 max(Float16 op1, Float16 op2) {
166+
return Float16.max(op1, op2);
167+
}
168+
169+
public Float16 sqrt(Float16 radicand) {
170+
return Float16.sqrt(radicand);
171+
}
172+
173+
public Float16 fma(Float16 a, Float16 b, Float16 c) {
174+
return Float16.fma(a, b, c);
175+
}
176+
177+
public boolean isNaN(Float16 operand) {
178+
return Float16.isNaN(operand);
179+
}
180+
181+
public boolean isInfinite(Float16 operand) {
182+
return Float16.isInfinite(operand);
183+
}
184+
185+
public Float16 ulp(Float16 operand) {
186+
return Float16.ulp(operand);
187+
}
188+
189+
public String toHexString(Float16 operand) {
190+
return Float16.toHexString(operand);
191+
}
192+
};
193+
194+
/**
195+
* Witness for the {@code Numerical} interface.
196+
*/
197+
public __witness Numerical<Float16> NUM = SFP;
198+
199+
/**
200+
* Witness for the {@code Orderable} interface.
201+
*/
202+
public __witness Orderable<Float16> ORD = SFP;
203+
108204
/**
109205
* Primitive {@code short} field to hold the bits of the {@code Float16}.
110206
* @serial

test/jdk/jdk/incubator/vector/BasicFloat16ArithTests.java

Lines changed: 91 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright (c) 2016, 2025, Oracle and/or its affiliates. All rights reserved.
2+
* Copyright (c) 2016, 2026, Oracle and/or its affiliates. All rights reserved.
33
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
44
*
55
* This code is free software; you can redistribute it and/or modify it
@@ -49,6 +49,7 @@ public static void main(String... args) {
4949
checkFiniteness();
5050
checkMinMax();
5151
checkArith();
52+
checkOrderable();
5253
checkSqrt();
5354
checkGetExponent();
5455
checkUlp();
@@ -154,6 +155,15 @@ private static void checkConstants() {
154155
checkFloat16(NaN, NaNf, "NaN");
155156
}
156157

158+
private static void checkBoolean(Float16 op1, Float16 op2, boolean result, boolean expected, String operator) {
159+
if (result != expected) {
160+
throwRE(String.format("Didn't get expected value for " +
161+
"%s %s %s %nexpected %b, got %b%n",
162+
op1, operator, op2,
163+
expected, result));
164+
}
165+
}
166+
157167
private static void checkInt(int value, int expected, String message) {
158168
if (value != expected) {
159169
throwRE(String.format("Didn't get expected value for %s;%nexpected %d, got %d",
@@ -185,12 +195,18 @@ private static void checkNegate() {
185195

186196
for(var testCase : testCases) {
187197
float arg = testCase[0];
198+
Float16 argF16 = valueOfExact(arg);
188199
float expected = testCase[1];
189200
Float16 result = negate(valueOfExact(arg));
201+
Float16 resultOp = -argF16;
190202

191203
if (Float.compare(expected, result.floatValue()) != 0) {
192204
checkFloat16(result, expected, "negate(" + arg + ")");
193205
}
206+
207+
if (Float.compare(expected, resultOp.floatValue()) != 0) {
208+
checkFloat16(result, expected, "negate(" + arg + ")");
209+
}
194210
}
195211

196212
return;
@@ -314,7 +330,8 @@ private static void checkMinMax() {
314330

315331
/*
316332
* Cursory checks to make sure correct operation is being called
317-
* with arguments in proper order.
333+
* with arguments in proper order for both two-argument methods
334+
* and binary operators of the Numerical interface.
318335
*/
319336
private static void checkArith() {
320337
float a = 1.0f;
@@ -323,33 +340,105 @@ private static void checkArith() {
323340
float b = 2.0f;
324341
Float16 b16 = valueOfExact(b);
325342

343+
// Addition
326344
if (add(a16, b16).floatValue() != (a + b)) {
327345
throwRE("failure with " + a16 + " + " + b16);
328346
}
347+
if ((a16 + b16).floatValue() != (a + b)) { // check + operator
348+
throwRE("failure with " + a16 + " + " + b16);
349+
}
350+
329351
if (add(b16, a16).floatValue() != (b + a)) {
330352
throwRE("failure with " + b16 + " + " + a16);
331353
}
354+
if ((b16 + a16).floatValue() != (b + a)) { // check + operator
355+
throwRE("failure with " + b16 + " + " + a16);
356+
}
332357

358+
// Subtraction
333359
if (subtract(a16, b16).floatValue() != (a - b)) {
334360
throwRE("failure with " + a16 + " - " + b16);
335361
}
362+
if ((a16 - b16).floatValue() != (a - b)) { // check - operator
363+
throwRE("failure with " + a16 + " - " + b16);
364+
}
365+
336366
if (subtract(b16, a16).floatValue() != (b - a)) {
337367
throwRE("failure with " + b16 + " - " + a16);
338368
}
369+
if ((b16 - a16).floatValue() != (b - a)) { // check - operator
370+
throwRE("failure with " + b16 + " - " + a16);
371+
}
339372

373+
// Multiplication
340374
if (multiply(a16, b16).floatValue() != (a * b)) {
341375
throwRE("failure with " + a16 + " * " + b16);
342376
}
377+
if ((a16 * b16).floatValue() != (a * b)) { // check * operator
378+
throwRE("failure with " + a16 + " * " + b16);
379+
}
380+
343381
if (multiply(b16, a16).floatValue() != (b * a)) {
344382
throwRE("failure with " + b16 + " * " + a16);
345383
}
384+
if ((b16 * a16).floatValue() != (b * a)) { // check * operator
385+
throwRE("failure with " + b16 + " * " + a16);
386+
}
346387

388+
// Division
347389
if (divide(a16, b16).floatValue() != (a / b)) {
348390
throwRE("failure with " + a16 + " / " + b16);
349391
}
392+
if ((a16 / b16).floatValue() != (a / b)) { // check / operator
393+
throwRE("failure with " + a16 + " / " + b16);
394+
}
395+
350396
if (divide(b16, a16).floatValue() != (b / a)) {
351397
throwRE("failure with " + b16 + " / " + a16);
352398
}
399+
if ((b16 / a16).floatValue() != (b / a)) { // check / operator
400+
throwRE("failure with " + b16 + " / " + a16);
401+
}
402+
403+
return;
404+
}
405+
406+
/*
407+
* Cursory checks to make sure the ordered comparison operators
408+
* are behaving as expected.
409+
*/
410+
private static void checkOrderable() {
411+
float[] testCases = {NaNf,
412+
-InfinityF,
413+
-1.0f,
414+
-0.0f,
415+
+0.0f,
416+
1.0f,
417+
InfinityF};
418+
419+
for (float op1_f : testCases) {
420+
for (float op2_f : testCases) {
421+
422+
Float16 op1_f16 = valueOfExact(op1_f);
423+
Float16 op2_f16 = valueOfExact(op2_f);
424+
425+
checkBoolean(op1_f16, op2_f16,
426+
op1_f16 < op2_f16,
427+
op1_f < op2_f, "<");
428+
429+
checkBoolean(op1_f16, op2_f16,
430+
op1_f16 <= op2_f16,
431+
op1_f <= op2_f, "<=");
432+
433+
checkBoolean(op1_f16, op2_f16,
434+
op1_f16 > op2_f16,
435+
op1_f > op2_f, ">");
436+
437+
checkBoolean(op1_f16, op2_f16,
438+
op1_f16 >= op2_f16,
439+
op1_f >= op2_f, ">=");
440+
}
441+
}
353442
return;
354443
}
355444

test/jdk/jdk/incubator/vector/Bfloat16/Bfloat16.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ public final value class Bfloat16
9494
extends Number
9595
implements Comparable<Bfloat16> {
9696

97-
private static StandardFloatingPoint<Bfloat16> SFP = new StandardFloatingPoint<Bfloat16>() {
97+
private static final StandardFloatingPoint<Bfloat16> SFP = new StandardFloatingPoint<Bfloat16>() {
9898
public Bfloat16 add(Bfloat16 addend, Bfloat16 augend) {
9999
return Bfloat16.add(addend, augend);
100100
}

0 commit comments

Comments
 (0)