@@ -472,8 +472,16 @@ void test_distance_from_itself(void) {
472472 simsimd_i8_t i8s [1536 ];
473473 simsimd_u8_t u8s [1536 ];
474474 simsimd_b8_t b8s [1536 / 8 ];
475+ simsimd_e4m3_t e4m3s [1536 ];
476+ simsimd_e5m2_t e5m2s [1536 ];
475477 simsimd_distance_t distance [2 ];
476478
479+ // Initialize FP8 arrays with small values (avoid overflow)
480+ for (int i = 0 ; i < 1536 ; i ++ ) {
481+ e4m3s [i ] = 0x3C ; // ~1.0 in E4M3
482+ e5m2s [i ] = 0x3C ; // ~1.0 in E5M2
483+ }
484+
477485 // Angular distance
478486 simsimd_angular_i8 (i8s , i8s , 1536 , & distance [0 ]);
479487 simsimd_angular_u8 (u8s , u8s , 1536 , & distance [0 ]);
@@ -497,6 +505,8 @@ void test_distance_from_itself(void) {
497505 simsimd_dot_bf16 (bf16s , bf16s , 1536 , & distance [0 ]);
498506 simsimd_dot_f32 (f32s , f32s , 1536 , & distance [0 ]);
499507 simsimd_dot_f64 (f64s , f64s , 1536 , & distance [0 ]);
508+ simsimd_dot_e4m3 (e4m3s , e4m3s , 1536 , & distance [0 ]);
509+ simsimd_dot_e5m2 (e5m2s , e5m2s , 1536 , & distance [0 ]);
500510
501511 // Complex inner product
502512 simsimd_dot_bf16c (bf16cs , bf16cs , 768 , & distance [0 ]);
@@ -1390,6 +1400,183 @@ void test_geospatial_vincenty(void) {
13901400
13911401#pragma endregion
13921402
1403+ #pragma region Matrix Multiplication
1404+
1405+ /**
1406+ * @brief Reference BF16 matmul: C[m×n] = A[m×k] × B[n×k]ᵀ
1407+ */
1408+ static void reference_matmul_bf16_f32 (simsimd_bf16_t const * a , simsimd_bf16_t const * b , simsimd_f32_t * c ,
1409+ simsimd_size_t m , simsimd_size_t n , simsimd_size_t k ) {
1410+ for (simsimd_size_t i = 0 ; i < m ; i ++ ) {
1411+ for (simsimd_size_t j = 0 ; j < n ; j ++ ) {
1412+ simsimd_f32_t sum = 0 ;
1413+ for (simsimd_size_t kk = 0 ; kk < k ; kk ++ ) {
1414+ simsimd_f32_t a_val , b_val ;
1415+ simsimd_bf16_to_f32 (& a [i * k + kk ], & a_val );
1416+ simsimd_bf16_to_f32 (& b [j * k + kk ], & b_val );
1417+ sum += a_val * b_val ;
1418+ }
1419+ c [i * n + j ] = sum ;
1420+ }
1421+ }
1422+ }
1423+
1424+ /**
1425+ * @brief Reference I8 matmul: C[m×n] = A[m×k] × B[n×k]ᵀ
1426+ */
1427+ static void reference_matmul_i8_i32 (simsimd_i8_t const * a , simsimd_i8_t const * b , simsimd_i32_t * c ,
1428+ simsimd_size_t m , simsimd_size_t n , simsimd_size_t k ) {
1429+ for (simsimd_size_t i = 0 ; i < m ; i ++ ) {
1430+ for (simsimd_size_t j = 0 ; j < n ; j ++ ) {
1431+ simsimd_i32_t sum = 0 ;
1432+ for (simsimd_size_t kk = 0 ; kk < k ; kk ++ ) {
1433+ sum += (simsimd_i32_t )a [i * k + kk ] * (simsimd_i32_t )b [j * k + kk ];
1434+ }
1435+ c [i * n + j ] = sum ;
1436+ }
1437+ }
1438+ }
1439+
1440+ /**
1441+ * @brief Test BF16 matmul (pack + multiply) against reference.
1442+ */
1443+ void test_matmul_bf16 (void ) {
1444+ simsimd_size_t m = 64 , n = 64 , k = 64 ;
1445+
1446+ // Allocate matrices
1447+ simsimd_bf16_t * a = (simsimd_bf16_t * )malloc (m * k * sizeof (simsimd_bf16_t ));
1448+ simsimd_bf16_t * b = (simsimd_bf16_t * )malloc (n * k * sizeof (simsimd_bf16_t ));
1449+ simsimd_f32_t * c_ref = (simsimd_f32_t * )malloc (m * n * sizeof (simsimd_f32_t ));
1450+ simsimd_f32_t * c_test = (simsimd_f32_t * )malloc (m * n * sizeof (simsimd_f32_t ));
1451+
1452+ // Initialize with random values
1453+ for (simsimd_size_t i = 0 ; i < m * k ; i ++ ) {
1454+ simsimd_f32_t val = (simsimd_f32_t )test_random_f64 (-1.0 , 1.0 );
1455+ simsimd_f32_to_bf16 (& val , & a [i ]);
1456+ }
1457+ for (simsimd_size_t i = 0 ; i < n * k ; i ++ ) {
1458+ simsimd_f32_t val = (simsimd_f32_t )test_random_f64 (-1.0 , 1.0 );
1459+ simsimd_f32_to_bf16 (& val , & b [i ]);
1460+ }
1461+
1462+ // Compute reference
1463+ reference_matmul_bf16_f32 (a , b , c_ref , m , n , k );
1464+
1465+ // Test serial implementation
1466+ {
1467+ simsimd_size_t packed_size = simsimd_matmul_bf16_packed_size_serial (n , k );
1468+ void * b_packed = malloc (packed_size );
1469+ simsimd_matmul_bf16_pack_serial (b , n , k , k * sizeof (simsimd_bf16_t ), b_packed );
1470+ simsimd_matmul_bf16_f32_serial (a , b_packed , c_test , m , n , k ,
1471+ k * sizeof (simsimd_bf16_t ), n * sizeof (simsimd_f32_t ));
1472+ free (b_packed );
1473+
1474+ // Compare results with tolerance (BF16 has limited precision)
1475+ for (simsimd_size_t i = 0 ; i < m * n ; i ++ ) {
1476+ simsimd_f64_t diff = fabs ((simsimd_f64_t )c_ref [i ] - (simsimd_f64_t )c_test [i ]);
1477+ simsimd_f64_t rel_err = fabs (c_ref [i ]) > 1e-6 ? diff / fabs (c_ref [i ]) : diff ;
1478+ assert (rel_err < 0.02 || diff < 0.01 ); // 2% relative or 0.01 absolute
1479+ }
1480+ printf (" - serial: PASS\n" );
1481+ }
1482+
1483+ #if SIMSIMD_TARGET_SAPPHIRE
1484+ // Test Sapphire (AMX) implementation
1485+ {
1486+ simsimd_capability_t enabled = simsimd_enable_capabilities (simsimd_cap_sapphire_k );
1487+ if (enabled & simsimd_cap_sapphire_k ) {
1488+ simsimd_size_t packed_size = simsimd_matmul_bf16_packed_size_sapphire (n , k );
1489+ void * b_packed = malloc (packed_size );
1490+ simsimd_matmul_bf16_pack_sapphire (b , n , k , k * sizeof (simsimd_bf16_t ), b_packed );
1491+ simsimd_matmul_bf16_f32_sapphire (a , b_packed , c_test , m , n , k ,
1492+ k * sizeof (simsimd_bf16_t ), n * sizeof (simsimd_f32_t ));
1493+ free (b_packed );
1494+
1495+ // Compare results
1496+ for (simsimd_size_t i = 0 ; i < m * n ; i ++ ) {
1497+ simsimd_f64_t diff = fabs ((simsimd_f64_t )c_ref [i ] - (simsimd_f64_t )c_test [i ]);
1498+ simsimd_f64_t rel_err = fabs (c_ref [i ]) > 1e-6 ? diff / fabs (c_ref [i ]) : diff ;
1499+ assert (rel_err < 0.02 || diff < 0.01 );
1500+ }
1501+ printf (" - sapphire (AMX): PASS\n" );
1502+ } else {
1503+ printf (" - sapphire (AMX): SKIPPED (not available)\n" );
1504+ }
1505+ }
1506+ #endif
1507+
1508+ free (a );
1509+ free (b );
1510+ free (c_ref );
1511+ free (c_test );
1512+ printf ("Test matmul BF16: PASS\n" );
1513+ }
1514+
1515+ /**
1516+ * @brief Test I8 matmul (pack + multiply) against reference.
1517+ */
1518+ void test_matmul_i8 (void ) {
1519+ simsimd_size_t m = 64 , n = 64 , k = 64 ;
1520+
1521+ // Allocate matrices
1522+ simsimd_i8_t * a = (simsimd_i8_t * )malloc (m * k );
1523+ simsimd_i8_t * b = (simsimd_i8_t * )malloc (n * k );
1524+ simsimd_i32_t * c_ref = (simsimd_i32_t * )malloc (m * n * sizeof (simsimd_i32_t ));
1525+ simsimd_i32_t * c_test = (simsimd_i32_t * )malloc (m * n * sizeof (simsimd_i32_t ));
1526+
1527+ // Initialize with random values
1528+ for (simsimd_size_t i = 0 ; i < m * k ; i ++ )
1529+ a [i ] = (simsimd_i8_t )test_random_f64 (-127 , 127 );
1530+ for (simsimd_size_t i = 0 ; i < n * k ; i ++ )
1531+ b [i ] = (simsimd_i8_t )test_random_f64 (-127 , 127 );
1532+
1533+ // Compute reference
1534+ reference_matmul_i8_i32 (a , b , c_ref , m , n , k );
1535+
1536+ // Test serial implementation
1537+ {
1538+ simsimd_size_t packed_size = simsimd_matmul_i8_packed_size_serial (n , k );
1539+ void * b_packed = malloc (packed_size );
1540+ simsimd_matmul_i8_pack_serial (b , n , k , k , b_packed );
1541+ simsimd_matmul_i8_i32_serial (a , b_packed , c_test , m , n , k , k , n * sizeof (simsimd_i32_t ));
1542+ free (b_packed );
1543+
1544+ // Compare results (should be exact for integers)
1545+ for (simsimd_size_t i = 0 ; i < m * n ; i ++ )
1546+ assert (c_ref [i ] == c_test [i ]);
1547+ printf (" - serial: PASS\n" );
1548+ }
1549+
1550+ #if SIMSIMD_TARGET_SAPPHIRE
1551+ // Test Sapphire (AMX) implementation
1552+ {
1553+ simsimd_capability_t enabled = simsimd_enable_capabilities (simsimd_cap_sapphire_k );
1554+ if (enabled & simsimd_cap_sapphire_k ) {
1555+ simsimd_size_t packed_size = simsimd_matmul_i8_packed_size_sapphire (n , k );
1556+ void * b_packed = malloc (packed_size );
1557+ simsimd_matmul_i8_pack_sapphire (b , n , k , k , b_packed );
1558+ simsimd_matmul_i8_i32_sapphire (a , b_packed , c_test , m , n , k , k , n * sizeof (simsimd_i32_t ));
1559+ free (b_packed );
1560+
1561+ // Compare results
1562+ for (simsimd_size_t i = 0 ; i < m * n ; i ++ )
1563+ assert (c_ref [i ] == c_test [i ]);
1564+ printf (" - sapphire (AMX): PASS\n" );
1565+ } else {
1566+ printf (" - sapphire (AMX): SKIPPED (not available)\n" );
1567+ }
1568+ }
1569+ #endif
1570+
1571+ free (a );
1572+ free (b );
1573+ free (c_ref );
1574+ free (c_test );
1575+ printf ("Test matmul I8: PASS\n" );
1576+ }
1577+
1578+ #pragma endregion
1579+
13931580int main (int argc , char * * argv ) {
13941581 (void )argc ;
13951582 (void )argv ;
@@ -1426,6 +1613,10 @@ int main(int argc, char **argv) {
14261613 test_geospatial_haversine ();
14271614 test_geospatial_vincenty ();
14281615
1616+ // Matrix multiplication
1617+ test_matmul_bf16 ();
1618+ test_matmul_i8 ();
1619+
14291620 printf ("All tests passed.\n" );
14301621 return 0 ;
14311622}
0 commit comments