Skip to content

Commit 3c4f674

Browse files
committed
Improve: Test matmuls
1 parent 7b8787f commit 3c4f674

File tree

1 file changed

+191
-0
lines changed

1 file changed

+191
-0
lines changed

scripts/test.c

Lines changed: 191 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -472,8 +472,16 @@ void test_distance_from_itself(void) {
472472
simsimd_i8_t i8s[1536];
473473
simsimd_u8_t u8s[1536];
474474
simsimd_b8_t b8s[1536 / 8];
475+
simsimd_e4m3_t e4m3s[1536];
476+
simsimd_e5m2_t e5m2s[1536];
475477
simsimd_distance_t distance[2];
476478

479+
// Initialize FP8 arrays with small values (avoid overflow)
480+
for (int i = 0; i < 1536; i++) {
481+
e4m3s[i] = 0x3C; // ~1.0 in E4M3
482+
e5m2s[i] = 0x3C; // ~1.0 in E5M2
483+
}
484+
477485
// Angular distance
478486
simsimd_angular_i8(i8s, i8s, 1536, &distance[0]);
479487
simsimd_angular_u8(u8s, u8s, 1536, &distance[0]);
@@ -497,6 +505,8 @@ void test_distance_from_itself(void) {
497505
simsimd_dot_bf16(bf16s, bf16s, 1536, &distance[0]);
498506
simsimd_dot_f32(f32s, f32s, 1536, &distance[0]);
499507
simsimd_dot_f64(f64s, f64s, 1536, &distance[0]);
508+
simsimd_dot_e4m3(e4m3s, e4m3s, 1536, &distance[0]);
509+
simsimd_dot_e5m2(e5m2s, e5m2s, 1536, &distance[0]);
500510

501511
// Complex inner product
502512
simsimd_dot_bf16c(bf16cs, bf16cs, 768, &distance[0]);
@@ -1390,6 +1400,183 @@ void test_geospatial_vincenty(void) {
13901400

13911401
#pragma endregion
13921402

1403+
#pragma region Matrix Multiplication
1404+
1405+
/**
1406+
* @brief Reference BF16 matmul: C[m×n] = A[m×k] × B[n×k]ᵀ
1407+
*/
1408+
static void reference_matmul_bf16_f32(simsimd_bf16_t const *a, simsimd_bf16_t const *b, simsimd_f32_t *c,
1409+
simsimd_size_t m, simsimd_size_t n, simsimd_size_t k) {
1410+
for (simsimd_size_t i = 0; i < m; i++) {
1411+
for (simsimd_size_t j = 0; j < n; j++) {
1412+
simsimd_f32_t sum = 0;
1413+
for (simsimd_size_t kk = 0; kk < k; kk++) {
1414+
simsimd_f32_t a_val, b_val;
1415+
simsimd_bf16_to_f32(&a[i * k + kk], &a_val);
1416+
simsimd_bf16_to_f32(&b[j * k + kk], &b_val);
1417+
sum += a_val * b_val;
1418+
}
1419+
c[i * n + j] = sum;
1420+
}
1421+
}
1422+
}
1423+
1424+
/**
1425+
* @brief Reference I8 matmul: C[m×n] = A[m×k] × B[n×k]ᵀ
1426+
*/
1427+
static void reference_matmul_i8_i32(simsimd_i8_t const *a, simsimd_i8_t const *b, simsimd_i32_t *c,
1428+
simsimd_size_t m, simsimd_size_t n, simsimd_size_t k) {
1429+
for (simsimd_size_t i = 0; i < m; i++) {
1430+
for (simsimd_size_t j = 0; j < n; j++) {
1431+
simsimd_i32_t sum = 0;
1432+
for (simsimd_size_t kk = 0; kk < k; kk++) {
1433+
sum += (simsimd_i32_t)a[i * k + kk] * (simsimd_i32_t)b[j * k + kk];
1434+
}
1435+
c[i * n + j] = sum;
1436+
}
1437+
}
1438+
}
1439+
1440+
/**
1441+
* @brief Test BF16 matmul (pack + multiply) against reference.
1442+
*/
1443+
void test_matmul_bf16(void) {
1444+
simsimd_size_t m = 64, n = 64, k = 64;
1445+
1446+
// Allocate matrices
1447+
simsimd_bf16_t *a = (simsimd_bf16_t *)malloc(m * k * sizeof(simsimd_bf16_t));
1448+
simsimd_bf16_t *b = (simsimd_bf16_t *)malloc(n * k * sizeof(simsimd_bf16_t));
1449+
simsimd_f32_t *c_ref = (simsimd_f32_t *)malloc(m * n * sizeof(simsimd_f32_t));
1450+
simsimd_f32_t *c_test = (simsimd_f32_t *)malloc(m * n * sizeof(simsimd_f32_t));
1451+
1452+
// Initialize with random values
1453+
for (simsimd_size_t i = 0; i < m * k; i++) {
1454+
simsimd_f32_t val = (simsimd_f32_t)test_random_f64(-1.0, 1.0);
1455+
simsimd_f32_to_bf16(&val, &a[i]);
1456+
}
1457+
for (simsimd_size_t i = 0; i < n * k; i++) {
1458+
simsimd_f32_t val = (simsimd_f32_t)test_random_f64(-1.0, 1.0);
1459+
simsimd_f32_to_bf16(&val, &b[i]);
1460+
}
1461+
1462+
// Compute reference
1463+
reference_matmul_bf16_f32(a, b, c_ref, m, n, k);
1464+
1465+
// Test serial implementation
1466+
{
1467+
simsimd_size_t packed_size = simsimd_matmul_bf16_packed_size_serial(n, k);
1468+
void *b_packed = malloc(packed_size);
1469+
simsimd_matmul_bf16_pack_serial(b, n, k, k * sizeof(simsimd_bf16_t), b_packed);
1470+
simsimd_matmul_bf16_f32_serial(a, b_packed, c_test, m, n, k,
1471+
k * sizeof(simsimd_bf16_t), n * sizeof(simsimd_f32_t));
1472+
free(b_packed);
1473+
1474+
// Compare results with tolerance (BF16 has limited precision)
1475+
for (simsimd_size_t i = 0; i < m * n; i++) {
1476+
simsimd_f64_t diff = fabs((simsimd_f64_t)c_ref[i] - (simsimd_f64_t)c_test[i]);
1477+
simsimd_f64_t rel_err = fabs(c_ref[i]) > 1e-6 ? diff / fabs(c_ref[i]) : diff;
1478+
assert(rel_err < 0.02 || diff < 0.01); // 2% relative or 0.01 absolute
1479+
}
1480+
printf(" - serial: PASS\n");
1481+
}
1482+
1483+
#if SIMSIMD_TARGET_SAPPHIRE
1484+
// Test Sapphire (AMX) implementation
1485+
{
1486+
simsimd_capability_t enabled = simsimd_enable_capabilities(simsimd_cap_sapphire_k);
1487+
if (enabled & simsimd_cap_sapphire_k) {
1488+
simsimd_size_t packed_size = simsimd_matmul_bf16_packed_size_sapphire(n, k);
1489+
void *b_packed = malloc(packed_size);
1490+
simsimd_matmul_bf16_pack_sapphire(b, n, k, k * sizeof(simsimd_bf16_t), b_packed);
1491+
simsimd_matmul_bf16_f32_sapphire(a, b_packed, c_test, m, n, k,
1492+
k * sizeof(simsimd_bf16_t), n * sizeof(simsimd_f32_t));
1493+
free(b_packed);
1494+
1495+
// Compare results
1496+
for (simsimd_size_t i = 0; i < m * n; i++) {
1497+
simsimd_f64_t diff = fabs((simsimd_f64_t)c_ref[i] - (simsimd_f64_t)c_test[i]);
1498+
simsimd_f64_t rel_err = fabs(c_ref[i]) > 1e-6 ? diff / fabs(c_ref[i]) : diff;
1499+
assert(rel_err < 0.02 || diff < 0.01);
1500+
}
1501+
printf(" - sapphire (AMX): PASS\n");
1502+
} else {
1503+
printf(" - sapphire (AMX): SKIPPED (not available)\n");
1504+
}
1505+
}
1506+
#endif
1507+
1508+
free(a);
1509+
free(b);
1510+
free(c_ref);
1511+
free(c_test);
1512+
printf("Test matmul BF16: PASS\n");
1513+
}
1514+
1515+
/**
1516+
* @brief Test I8 matmul (pack + multiply) against reference.
1517+
*/
1518+
void test_matmul_i8(void) {
1519+
simsimd_size_t m = 64, n = 64, k = 64;
1520+
1521+
// Allocate matrices
1522+
simsimd_i8_t *a = (simsimd_i8_t *)malloc(m * k);
1523+
simsimd_i8_t *b = (simsimd_i8_t *)malloc(n * k);
1524+
simsimd_i32_t *c_ref = (simsimd_i32_t *)malloc(m * n * sizeof(simsimd_i32_t));
1525+
simsimd_i32_t *c_test = (simsimd_i32_t *)malloc(m * n * sizeof(simsimd_i32_t));
1526+
1527+
// Initialize with random values
1528+
for (simsimd_size_t i = 0; i < m * k; i++)
1529+
a[i] = (simsimd_i8_t)test_random_f64(-127, 127);
1530+
for (simsimd_size_t i = 0; i < n * k; i++)
1531+
b[i] = (simsimd_i8_t)test_random_f64(-127, 127);
1532+
1533+
// Compute reference
1534+
reference_matmul_i8_i32(a, b, c_ref, m, n, k);
1535+
1536+
// Test serial implementation
1537+
{
1538+
simsimd_size_t packed_size = simsimd_matmul_i8_packed_size_serial(n, k);
1539+
void *b_packed = malloc(packed_size);
1540+
simsimd_matmul_i8_pack_serial(b, n, k, k, b_packed);
1541+
simsimd_matmul_i8_i32_serial(a, b_packed, c_test, m, n, k, k, n * sizeof(simsimd_i32_t));
1542+
free(b_packed);
1543+
1544+
// Compare results (should be exact for integers)
1545+
for (simsimd_size_t i = 0; i < m * n; i++)
1546+
assert(c_ref[i] == c_test[i]);
1547+
printf(" - serial: PASS\n");
1548+
}
1549+
1550+
#if SIMSIMD_TARGET_SAPPHIRE
1551+
// Test Sapphire (AMX) implementation
1552+
{
1553+
simsimd_capability_t enabled = simsimd_enable_capabilities(simsimd_cap_sapphire_k);
1554+
if (enabled & simsimd_cap_sapphire_k) {
1555+
simsimd_size_t packed_size = simsimd_matmul_i8_packed_size_sapphire(n, k);
1556+
void *b_packed = malloc(packed_size);
1557+
simsimd_matmul_i8_pack_sapphire(b, n, k, k, b_packed);
1558+
simsimd_matmul_i8_i32_sapphire(a, b_packed, c_test, m, n, k, k, n * sizeof(simsimd_i32_t));
1559+
free(b_packed);
1560+
1561+
// Compare results
1562+
for (simsimd_size_t i = 0; i < m * n; i++)
1563+
assert(c_ref[i] == c_test[i]);
1564+
printf(" - sapphire (AMX): PASS\n");
1565+
} else {
1566+
printf(" - sapphire (AMX): SKIPPED (not available)\n");
1567+
}
1568+
}
1569+
#endif
1570+
1571+
free(a);
1572+
free(b);
1573+
free(c_ref);
1574+
free(c_test);
1575+
printf("Test matmul I8: PASS\n");
1576+
}
1577+
1578+
#pragma endregion
1579+
13931580
int main(int argc, char **argv) {
13941581
(void)argc;
13951582
(void)argv;
@@ -1426,6 +1613,10 @@ int main(int argc, char **argv) {
14261613
test_geospatial_haversine();
14271614
test_geospatial_vincenty();
14281615

1616+
// Matrix multiplication
1617+
test_matmul_bf16();
1618+
test_matmul_i8();
1619+
14291620
printf("All tests passed.\n");
14301621
return 0;
14311622
}

0 commit comments

Comments
 (0)