Skip to content

Commit 9ae803a

Browse files
committed
Add fp16 addition and sqrt()
1 parent 22bbf07 commit 9ae803a

File tree

3 files changed

+232
-33
lines changed

3 files changed

+232
-33
lines changed

src/commonMain/kotlin/dev/romainguy/kotlin/math/Half.kt

+117-5
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@
1414
* limitations under the License.
1515
*/
1616

17+
// Operators +, *, / based on http://half.sourceforge.net/ by Christian Rau
18+
// and licensed under MIT
19+
1720
@file:Suppress("NOTHING_TO_INLINE")
1821

1922
package dev.romainguy.kotlin.math
@@ -344,7 +347,7 @@ value class Half(private val v: UShort) : Comparable<Half> {
344347
get() = when {
345348
isNaN() -> NaN
346349
isInfinite() -> POSITIVE_INFINITY
347-
// 0x7bff == MAX_VALUE
350+
// 0x7bff == MAX_VALUE, return 2^4
348351
v.toInt() and FP16_ABS == 0x7bff -> Half(0x4c00.toUShort())
349352
else -> {
350353
val d = absoluteValue
@@ -465,7 +468,7 @@ value class Half(private val v: UShort) : Comparable<Half> {
465468
fun nextUp(): Half = when {
466469
isNaN() || v == POSITIVE_INFINITY.v -> this
467470
isZero() -> MIN_VALUE
468-
else -> fromBits(toBits() + if (v.toInt() and FP16_SIGN_MASK == 0) 1 else -1)
471+
else -> Half((toBits() + if (v.toInt() and FP16_SIGN_MASK == 0) 1 else -1).toUShort())
469472
}
470473

471474
/**
@@ -474,7 +477,7 @@ value class Half(private val v: UShort) : Comparable<Half> {
474477
fun nextDown(): Half = when {
475478
isNaN() || v == NEGATIVE_INFINITY.v -> this
476479
isZero() -> -MIN_VALUE
477-
else -> fromBits(toBits() + if (v.toInt() and FP16_SIGN_MASK == 0) -1 else 1)
480+
else -> Half((toBits() + if (v.toInt() and FP16_SIGN_MASK == 0) -1 else 1).toUShort())
478481
}
479482

480483
/**
@@ -519,7 +522,73 @@ value class Half(private val v: UShort) : Comparable<Half> {
519522
operator fun unaryPlus() = Half(v)
520523

521524
operator fun plus(other: Half): Half {
522-
TODO("Not yet implemented")
525+
val xbits = toBits()
526+
val ybits = other.toBits()
527+
528+
val sub = ((xbits xor ybits) and FP16_SIGN_MASK) != 0
529+
530+
var ax = xbits and FP16_ABS
531+
var ay = ybits and FP16_ABS
532+
533+
// Handle NaNs and infinities
534+
if (ax >= FP16_EXPONENT_MAX || ay >= FP16_EXPONENT_MAX) {
535+
return Half((
536+
if (ax > FP16_EXPONENT_MAX || ay > FP16_EXPONENT_MAX) quiet(ax, ay)
537+
else if (ay != FP16_EXPONENT_MAX) xbits
538+
else if (sub && ax == FP16_EXPONENT_MAX) FP16_QUIET_NAN
539+
else ybits
540+
).toUShort())
541+
}
542+
543+
// Handle zero operands, including signs
544+
if (ax == 0) return if (ay != 0) other else Half((xbits and ybits).toUShort())
545+
if (ay == 0) return this
546+
547+
// Compute the sign of the result
548+
val s = (if (sub && ay > ax) ybits else xbits) and FP16_SIGN_MASK
549+
550+
if (ay > ax) {
551+
val t = ax
552+
ax = ay
553+
ay = t
554+
}
555+
556+
var e = (ax shr 10) + if (ax <= FP16_SIGNIFICAND_MASK) 1 else 0
557+
val d = e - (ay shr 10) - if (ay <= FP16_SIGNIFICAND_MASK) 1 else 0
558+
559+
var mx = ((ax and FP16_SIGNIFICAND_MASK) or
560+
((if (ax > FP16_SIGNIFICAND_MASK) 1 else 0) shl 10)) shl 3
561+
var my: Int
562+
563+
if (d < 13) {
564+
my = ((ay and FP16_SIGNIFICAND_MASK) or
565+
((if (ay > FP16_SIGNIFICAND_MASK) 1 else 0) shl 10)) shl 3
566+
my = (my shr d) or (if ((my and ((1 shl d) - 1)) != 0) 1 else 0)
567+
} else {
568+
my = 1
569+
}
570+
571+
if (sub) {
572+
mx -= my
573+
if (mx == 0) return POSITIVE_ZERO
574+
while (mx < 0x2000 && e > 1) {
575+
mx = mx shl 1
576+
e--
577+
}
578+
} else {
579+
mx += my
580+
val i = mx shr 14
581+
e += i
582+
if (e > 30) return Half((s or FP16_EXPONENT_MAX).toUShort())
583+
mx = (mx shr i) or (mx and i)
584+
}
585+
586+
// Guard and sticky bits
587+
val v = s +((e - 1) shl 10) + (mx shr 3)
588+
val G = (mx shr 2) and 1
589+
val S = if (mx and 0x3 != 0) 1 else 0
590+
591+
return Half((v + (G and (S or v))).toUShort())
523592
}
524593

525594
operator fun minus(other: Half) = this + (-other)
@@ -704,7 +773,50 @@ value class Half(private val v: UShort) : Comparable<Half> {
704773
}
705774
}
706775

707-
fun sqrt(x: Half): Half = TODO("Not implemented yet")
776+
fun sqrt(x: Half): Half {
777+
val bits = x.toBits()
778+
var a = bits and FP16_ABS
779+
var e = 15
780+
781+
if (a == 0 || a >= FP16_EXPONENT_MAX) {
782+
return Half((when {
783+
a > FP16_EXPONENT_MAX -> quiet(bits)
784+
bits > FP16_SIGN_MASK -> FP16_QUIET_NAN
785+
else -> bits
786+
}).toUShort())
787+
}
788+
789+
while (a < 0x400) {
790+
a = a shl 1
791+
e--
792+
}
793+
794+
// Bring back 1.
795+
var r = ((a and FP16_SIGNIFICAND_MASK) or 0x400).toUInt() shl 10
796+
e += a shr 10
797+
val i = e and 1
798+
r = r shl i
799+
e = (e - i) / 2
800+
801+
var m = 0U
802+
var b = 1U shl 20
803+
while (b != 0U) {
804+
if (r < m + b) {
805+
m = m shr 1
806+
} else {
807+
r -= m + b
808+
m = (m shr 1) + b
809+
}
810+
b = b shr 2
811+
}
812+
813+
// Guard and sticky bits
814+
val v = (e shl 10).toUInt() + (m and 0x3ffU)
815+
val G = if (r > m) 1U else 0U
816+
val S = if (r != 0U) 1U else 0U
817+
818+
return Half((v + (G and (S or v))).toUShort())
819+
}
708820

709821
/**
710822
* Returns the absolute value of the specified half-precision float.

src/commonMain/kotlin/dev/romainguy/kotlin/math/Scalar.kt

+2-2
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,8 @@ const val INV_PI = 1.0f / PI
2828
const val INV_TWO_PI = INV_PI * 0.5f
2929
const val INV_FOUR_PI = INV_PI * 0.25f
3030

31-
val HALF_ONE = Half(1.0f)
32-
val HALF_TWO = Half(2.0f)
31+
val HALF_ONE = Half(0x3c00.toUShort())
32+
val HALF_TWO = Half(0x4000.toUShort())
3333

3434
inline fun clamp(x: Float, min: Float, max: Float) = if (x < min) min else (if (x > max) max else x)
3535

src/commonTest/kotlin/dev/romainguy/kotlin/math/HalfTest.kt

+113-26
Original file line numberDiff line numberDiff line change
@@ -525,6 +525,17 @@ class HalfTest {
525525

526526
assertEquals(Half.MIN_VALUE, Half.NEGATIVE_ZERO.nextUp())
527527
assertEquals(-Half.MIN_VALUE, Half.NEGATIVE_ZERO.nextDown())
528+
529+
assertTrue(Half.NaN.nextTowards(HALF_TWO).isNaN())
530+
assertTrue(HALF_TWO.nextTowards(Half.NaN).isNaN())
531+
assertEquals(HALF_ONE, HALF_ONE.nextTowards(HALF_ONE))
532+
assertEquals(-HALF_ONE, (-HALF_ONE).nextTowards(-HALF_ONE))
533+
534+
assertEquals(Half(1025.0f), Half(1024.0f).nextTowards(Half(32768.0f)))
535+
assertEquals(Half(1023.5f), Half(1024.0f).nextTowards(Half(-32768.0f)))
536+
537+
assertEquals(Half(0.50048830f), Half(0.5f).nextTowards(Half(32768.0f)))
538+
assertEquals(Half(0.49975586f), Half(0.5f).nextTowards(Half(-32768.0f)))
528539
}
529540

530541
@Test
@@ -549,8 +560,8 @@ class HalfTest {
549560

550561
@Test
551562
fun multiplication() {
552-
assertTrue((Half(2.0f) * Half.NaN).isNaN())
553-
assertTrue((Half.NaN * Half(2.0f)).isNaN())
563+
assertTrue((HALF_TWO * Half.NaN).isNaN())
564+
assertTrue((Half.NaN * HALF_TWO).isNaN())
554565
assertTrue((Half.POSITIVE_INFINITY * Half.NaN).isNaN())
555566
assertTrue((Half.NaN * Half.POSITIVE_INFINITY).isNaN())
556567
assertTrue((Half.NEGATIVE_INFINITY * Half.NaN).isNaN())
@@ -560,29 +571,29 @@ class HalfTest {
560571
assertTrue((Half.NEGATIVE_ZERO * Half.NaN).isNaN())
561572
assertTrue((Half.NaN * Half.NEGATIVE_ZERO).isNaN())
562573

563-
assertTrue((Half(2.0f) * Half.POSITIVE_INFINITY).isInfinite())
564-
assertTrue((Half.POSITIVE_INFINITY * Half(2.0f)).isInfinite())
574+
assertTrue((HALF_TWO * Half.POSITIVE_INFINITY).isInfinite())
575+
assertTrue((Half.POSITIVE_INFINITY * HALF_TWO).isInfinite())
565576

566-
assertTrue((Half(2.0f) * Half.NEGATIVE_INFINITY).isInfinite())
567-
assertTrue((Half.NEGATIVE_INFINITY * Half(2.0f)).isInfinite())
577+
assertTrue((HALF_TWO * Half.NEGATIVE_INFINITY).isInfinite())
578+
assertTrue((Half.NEGATIVE_INFINITY * HALF_TWO).isInfinite())
568579

569-
assertTrue((Half(2.0f) * Half.POSITIVE_ZERO).isZero())
570-
assertTrue((Half.POSITIVE_ZERO * Half(2.0f)).isZero())
580+
assertTrue((HALF_TWO * Half.POSITIVE_ZERO).isZero())
581+
assertTrue((Half.POSITIVE_ZERO * HALF_TWO).isZero())
571582

572-
assertTrue((Half(2.0f) * Half.NEGATIVE_ZERO).isZero())
573-
assertTrue((Half.NEGATIVE_ZERO * Half(2.0f)).isZero())
583+
assertTrue((HALF_TWO * Half.NEGATIVE_ZERO).isZero())
584+
assertTrue((Half.NEGATIVE_ZERO * HALF_TWO).isZero())
574585

575586
// Overflow
576-
assertEquals(Half.POSITIVE_INFINITY, Half(2.0f) * Half.MAX_VALUE)
577-
assertEquals(Half.POSITIVE_INFINITY, Half.MAX_VALUE * Half(2.0f))
587+
assertEquals(Half.POSITIVE_INFINITY, HALF_TWO * Half.MAX_VALUE)
588+
assertEquals(Half.POSITIVE_INFINITY, Half.MAX_VALUE * HALF_TWO)
578589
assertEquals(Half.NEGATIVE_INFINITY, Half(-2.0f) * Half.MAX_VALUE)
579590
assertEquals(Half.NEGATIVE_INFINITY, Half.MAX_VALUE * Half(-2.0f))
580591

581592
// Underflow
582593
assertEquals(Half.POSITIVE_ZERO, Half.MIN_VALUE * Half.MIN_NORMAL)
583594
assertEquals(Half.NEGATIVE_ZERO, Half.MIN_VALUE * -Half.MIN_NORMAL)
584595

585-
assertEquals(Half(8.0f), Half(2.0f) * Half(4.0f))
596+
assertEquals(Half(8.0f), HALF_TWO * Half(4.0f))
586597
assertEquals(Half(2.88f), Half(1.2f) * Half(2.4f))
587598
assertEquals(Half(-2.88f), Half(1.2f) * Half(-2.4f))
588599
assertEquals(Half(-2.88f), Half(-1.2f) * Half(2.4f))
@@ -597,8 +608,8 @@ class HalfTest {
597608

598609
@Test
599610
fun division() {
600-
assertTrue((Half(2.0f) / Half.NaN).isNaN())
601-
assertTrue((Half.NaN / Half(2.0f)).isNaN())
611+
assertTrue((HALF_TWO / Half.NaN).isNaN())
612+
assertTrue((Half.NaN / HALF_TWO).isNaN())
602613
assertTrue((Half.POSITIVE_INFINITY / Half.NaN).isNaN())
603614
assertTrue((Half.NaN / Half.POSITIVE_INFINITY).isNaN())
604615
assertTrue((Half.NEGATIVE_INFINITY / Half.NaN).isNaN())
@@ -608,17 +619,17 @@ class HalfTest {
608619
assertTrue((Half.NEGATIVE_ZERO / Half.NaN).isNaN())
609620
assertTrue((Half.NaN / Half.NEGATIVE_ZERO).isNaN())
610621

611-
assertTrue((Half(2.0f) / Half.POSITIVE_INFINITY).isZero())
612-
assertTrue((Half.POSITIVE_INFINITY / Half(2.0f)).isInfinite())
622+
assertTrue((HALF_TWO / Half.POSITIVE_INFINITY).isZero())
623+
assertTrue((Half.POSITIVE_INFINITY / HALF_TWO).isInfinite())
613624

614-
assertTrue((Half(2.0f) / Half.NEGATIVE_INFINITY).isZero())
615-
assertTrue((Half.NEGATIVE_INFINITY / Half(2.0f)).isInfinite())
625+
assertTrue((HALF_TWO / Half.NEGATIVE_INFINITY).isZero())
626+
assertTrue((Half.NEGATIVE_INFINITY / HALF_TWO).isInfinite())
616627

617-
assertTrue((Half(2.0f) / Half.POSITIVE_ZERO).isInfinite())
618-
assertTrue((Half.POSITIVE_ZERO / Half(2.0f)).isZero())
628+
assertTrue((HALF_TWO / Half.POSITIVE_ZERO).isInfinite())
629+
assertTrue((Half.POSITIVE_ZERO / HALF_TWO).isZero())
619630

620-
assertTrue((Half(2.0f) / Half.NEGATIVE_ZERO).isInfinite())
621-
assertTrue((Half.NEGATIVE_ZERO / Half(2.0f)).isZero())
631+
assertTrue((HALF_TWO / Half.NEGATIVE_ZERO).isInfinite())
632+
assertTrue((Half.NEGATIVE_ZERO / HALF_TWO).isZero())
622633

623634
// Underflow
624635
assertEquals(Half.POSITIVE_ZERO, Half.MIN_VALUE / Half.MAX_VALUE)
@@ -628,7 +639,7 @@ class HalfTest {
628639
assertEquals(Half.POSITIVE_INFINITY, Half.MAX_VALUE / Half.MIN_VALUE)
629640
assertEquals(Half.NEGATIVE_INFINITY, (-Half.MAX_VALUE) / Half.MIN_VALUE)
630641

631-
assertEquals(Half(0.5f), Half(2.0f) / Half(4.0f))
642+
assertEquals(Half(0.5f), HALF_TWO / Half(4.0f))
632643
assertEquals(Half(0.5f), Half(1.2f) / Half(2.4f))
633644
assertEquals(Half(-0.5f), Half(1.2f) / Half(-2.4f))
634645
assertEquals(Half(-0.5f), Half(-1.2f) / Half(2.4f))
@@ -637,10 +648,86 @@ class HalfTest {
637648
assertEquals(Half(16_000.0f), Half(48_000.0f) / Half(3.0f))
638649
assertEquals(Half(-16_000.0f), Half(48_000.0f) / Half(-3.0f))
639650

640-
assertEquals(Half(2.0861626e-5f), Half(1.0f) / Half(48_000.0f))
641-
assertEquals(Half(-2.0861626e-5), Half(1.0f) / Half(-48_000.0f))
651+
assertEquals(Half(2.0861626e-5f), HALF_ONE / Half(48_000.0f))
652+
assertEquals(Half(-2.0861626e-5), HALF_ONE / Half(-48_000.0f))
642653

643654
assertEquals(Half(75.0f), Half(0.03f) / Half(0.0004f))
644655
assertEquals(Half(-75.0f), Half(0.03f) / Half(-0.0004f))
645656
}
657+
658+
@Test
659+
fun addition() {
660+
assertTrue((Half.NaN + Half.NaN).isNaN())
661+
662+
assertTrue((Half.NaN + HALF_ONE).isNaN())
663+
assertTrue((Half.NaN - HALF_ONE).isNaN())
664+
assertTrue((HALF_ONE + Half.NaN).isNaN())
665+
assertTrue((Half(-1.0f) + Half.NaN).isNaN())
666+
667+
assertTrue((Half.NaN + Half.POSITIVE_INFINITY).isNaN())
668+
assertTrue((Half.POSITIVE_INFINITY + Half.NaN).isNaN())
669+
assertTrue((Half.NaN + Half.NEGATIVE_INFINITY).isNaN())
670+
assertTrue((Half.NEGATIVE_INFINITY + Half.NaN).isNaN())
671+
672+
assertTrue((Half.NaN + Half.POSITIVE_ZERO).isNaN())
673+
assertTrue((Half.POSITIVE_ZERO + Half.NaN).isNaN())
674+
assertTrue((Half.NaN + Half.NEGATIVE_ZERO).isNaN())
675+
assertTrue((Half.NEGATIVE_ZERO + Half.NaN).isNaN())
676+
677+
assertTrue((Half.POSITIVE_INFINITY + HALF_ONE).isInfinite())
678+
assertTrue((Half.POSITIVE_INFINITY - HALF_ONE).isInfinite())
679+
assertTrue((HALF_ONE + Half.POSITIVE_INFINITY).isInfinite())
680+
assertTrue((HALF_ONE - Half.POSITIVE_INFINITY).isInfinite())
681+
assertTrue((Half.POSITIVE_INFINITY + Half.POSITIVE_INFINITY).isInfinite())
682+
assertTrue((Half.POSITIVE_INFINITY - Half.POSITIVE_INFINITY).isNaN())
683+
684+
assertTrue((Half.NEGATIVE_INFINITY - HALF_ONE).isInfinite())
685+
assertTrue((Half.NEGATIVE_INFINITY + HALF_ONE).isInfinite())
686+
assertTrue((HALF_ONE + Half.NEGATIVE_INFINITY).isInfinite())
687+
assertTrue((HALF_ONE - Half.NEGATIVE_INFINITY).isInfinite())
688+
assertTrue((Half.NEGATIVE_INFINITY + Half.NEGATIVE_INFINITY).isInfinite())
689+
assertTrue((Half.NEGATIVE_INFINITY - Half.NEGATIVE_INFINITY).isNaN())
690+
691+
assertEquals(Half(3.0f), HALF_ONE + HALF_TWO)
692+
693+
// Overflow
694+
assertEquals(Half.POSITIVE_INFINITY, Half(32768.0f) + Half(32768.0f))
695+
// Underflow
696+
assertEquals(Half.NEGATIVE_INFINITY, Half(-32768.0f) - Half(32768.0f))
697+
698+
for (i in 0x0..0xffff) {
699+
val v1 = Half(i.toUShort())
700+
if (v1.isFinite()) {
701+
assertTrue((v1 - v1).isZero())
702+
assertEquals(v1 * HALF_TWO, v1 + v1)
703+
}
704+
}
705+
}
706+
707+
@Test
708+
fun ulp() {
709+
assertTrue(Half.NaN.ulp.isNaN())
710+
711+
assertTrue(Half.POSITIVE_INFINITY.ulp.isInfinite())
712+
assertTrue(Half.NEGATIVE_INFINITY.ulp.isInfinite())
713+
714+
assertTrue((Half.MAX_VALUE + Half.MAX_VALUE.ulp).isInfinite())
715+
716+
assertEquals(Half.MIN_VALUE, Half.POSITIVE_ZERO.ulp)
717+
assertEquals(Half.MIN_VALUE, Half.NEGATIVE_ZERO.ulp)
718+
719+
assertEquals(HALF_ONE, Half(1024.0f).ulp)
720+
assertEquals(HALF_ONE, Half(-1024.0f).ulp)
721+
}
722+
723+
@Test
724+
fun sqrt() {
725+
for (i in 0x0..0xffff) {
726+
val v1 = Half(i.toUShort())
727+
if (v1.isFinite()) {
728+
val v2 = sqrt(v1)
729+
assertTrue(v1 - (v2 * v2) <= HALF_TWO * v1.ulp)
730+
}
731+
}
732+
}
646733
}

0 commit comments

Comments
 (0)