
Commit d3aef24

Merge pull request #217 from ashvardanian/main-dev
v6: Future-Proofing Dense & Sparse Operations
2 parents fac489d + 2565d6f commit d3aef24

14 files changed: +423 -326 lines

.vscode/settings.json

Lines changed: 3 additions & 1 deletion

@@ -93,7 +93,9 @@
     "format": "c",
     "execution": "cpp",
     "math.h": "c",
-    "float.h": "c"
+    "float.h": "c",
+    "text_encoding": "cpp",
+    "stdio.h": "c"
   },
   "cSpell.words": [
     "allclose",

CONTRIBUTING.md

Lines changed: 31 additions & 2 deletions

@@ -101,7 +101,7 @@ You can also benchmark against other libraries, filter the numeric types, and di
 $ python scripts/bench_vectors.py --help
 > usage: bench.py [-h] [--ndim NDIM] [-n COUNT]
 >                 [--metric {all,dot,spatial,binary,probability,sparse}]
->                 [--dtype {all,bits,int8,uint16,uint32,float16,float32,float64,bfloat16,complex32,complex64,complex128}]
+>                 [--dtype {all,bin8,int8,uint16,uint32,float16,float32,float64,bfloat16,complex32,complex64,complex128}]
 >                 [--scipy] [--scikit] [--torch] [--tf] [--jax]
 >
 > Benchmark SimSIMD vs. other libraries
@@ -119,7 +119,7 @@ $ python scripts/bench_vectors.py --help
 >                         `cdist`.
 >   --metric {all,dot,spatial,binary,probability,sparse}
 >                         Distance metric to use, profiles everything by default
->   --dtype {all,bits,int8,uint16,uint32,float16,float32,float64,bfloat16,complex32,complex64,complex128}
+>   --dtype {all,bin8,int8,uint16,uint32,float16,float32,float64,bfloat16,complex32,complex64,complex128}
 >                         Defines numeric types to benchmark, profiles everything by default
 >   --scipy               Profile SciPy, must be installed
 >   --scikit              Profile scikit-learn, must be installed
@@ -203,6 +203,35 @@ bun test
 swift build && swift test -v
 ```
 
+Running Swift on Linux requires a couple of extra steps, as the Swift compiler is not available in the default repositories.
+Please get the most recent Swift tarball from the [official website](https://www.swift.org/install/).
+At the time of writing, for a 64-bit Arm CPU running Ubuntu 22.04, the following commands would work:
+
+```bash
+wget https://download.swift.org/swift-5.9.2-release/ubuntu2204-aarch64/swift-5.9.2-RELEASE/swift-5.9.2-RELEASE-ubuntu22.04-aarch64.tar.gz
+tar xzf swift-5.9.2-RELEASE-ubuntu22.04-aarch64.tar.gz
+sudo mv swift-5.9.2-RELEASE-ubuntu22.04-aarch64 /usr/share/swift
+echo "export PATH=/usr/share/swift/usr/bin:$PATH" >> ~/.bashrc
+source ~/.bashrc
+```
+
+You can check the available images on the [`swift.org/download` page](https://www.swift.org/download/#releases).
+For x86 CPUs, the following commands would work:
+
+```bash
+wget https://download.swift.org/swift-5.9.2-release/ubuntu2204/swift-5.9.2-RELEASE/swift-5.9.2-RELEASE-ubuntu22.04.tar.gz
+tar xzf swift-5.9.2-RELEASE-ubuntu22.04.tar.gz
+sudo mv swift-5.9.2-RELEASE-ubuntu22.04 /usr/share/swift
+echo "export PATH=/usr/share/swift/usr/bin:$PATH" >> ~/.bashrc
+source ~/.bashrc
+```
+
+Alternatively, on Linux, the official Swift Docker image can be used for builds and tests:
+
+```bash
+sudo docker run --rm -v "$PWD:/workspace" -w /workspace swift:5.9 /bin/bash -cl "swift build -c release --static-swift-stdlib && swift test -c release --enable-test-discovery"
+```
+
 ## GoLang
 
 ```sh

README.md

Lines changed: 90 additions & 36 deletions

@@ -69,9 +69,9 @@ Implemented distance functions include:
 
 Moreover, SimSIMD...
 
-- handles `f64`, `f32`, `f16`, and `bf16` real & complex vectors.
-- handles `i8` integral, `i4` sub-byte, and `b8` binary vectors.
-- handles sparse `u32` and `u16` sets, and weighted sparse vectors.
+- handles `float64`, `float32`, `float16`, and `bfloat16` real & complex vectors.
+- handles `int8` integral, `int4` sub-byte, and `b8` binary vectors.
+- handles sparse `uint32` and `uint16` sets, and weighted sparse vectors.
 - is a zero-dependency [header-only C 99](#using-simsimd-in-c) library.
 - has [Python](#using-simsimd-in-python), [Rust](#using-simsimd-in-rust), [JS](#using-simsimd-in-javascript), and [Swift](#using-simsimd-in-swift) bindings.
 - has Arm backends for NEON, Scalable Vector Extensions (SVE), and SVE2.
@@ -95,14 +95,14 @@ You can learn more about the technical implementation details in the following b
 For reference, we use 1536-dimensional vectors, like the embeddings produced by the OpenAI Ada API.
 Comparing the serial code throughput produced by GCC 12 to hand-optimized kernels in SimSIMD, we see the following single-core improvements for the two most common vector-vector similarity metrics - the Cosine similarity and the Euclidean distance:
 
-| Type   | Apple M2 Pro                  | Intel Sapphire Rapids            | AWS Graviton 4                  |
-| :----- | ----------------------------: | -------------------------------: | ------------------------------: |
-| `f64`  | 18.5 → 28.8 GB/s <br/> + 56 % | 21.9 → 41.4 GB/s <br/> + 89 %    | 20.7 → 41.3 GB/s <br/> + 99 %   |
-| `f32`  | 9.2 → 29.6 GB/s <br/> + 221 % | 10.9 → 95.8 GB/s <br/> + 779 %   | 4.9 → 41.9 GB/s <br/> + 755 %   |
-| `f16`  | 4.6 → 14.6 GB/s <br/> + 217 % | 3.1 → 108.4 GB/s <br/> + 3,397 % | 5.4 → 39.3 GB/s <br/> + 627 %   |
-| `bf16` | 4.6 → 26.3 GB/s <br/> + 472 % | 0.8 → 59.5 GB/s <br/> + 7,437 %  | 2.5 → 29.9 GB/s <br/> + 1,096 % |
-| `i8`   | 25.8 → 47.1 GB/s <br/> + 83 % | 33.1 → 65.3 GB/s <br/> + 97 %    | 35.2 → 43.5 GB/s <br/> + 24 %   |
-| `u8`   |                               | 32.5 → 66.5 GB/s <br/> + 105 %   |                                 |
+| Type       | Apple M2 Pro                  | Intel Sapphire Rapids            | AWS Graviton 4                  |
+| :--------- | ----------------------------: | -------------------------------: | ------------------------------: |
+| `float64`  | 18.5 → 28.8 GB/s <br/> + 56 % | 21.9 → 41.4 GB/s <br/> + 89 %    | 20.7 → 41.3 GB/s <br/> + 99 %   |
+| `float32`  | 9.2 → 29.6 GB/s <br/> + 221 % | 10.9 → 95.8 GB/s <br/> + 779 %   | 4.9 → 41.9 GB/s <br/> + 755 %   |
+| `float16`  | 4.6 → 14.6 GB/s <br/> + 217 % | 3.1 → 108.4 GB/s <br/> + 3,397 % | 5.4 → 39.3 GB/s <br/> + 627 %   |
+| `bfloat16` | 4.6 → 26.3 GB/s <br/> + 472 % | 0.8 → 59.5 GB/s <br/> + 7,437 %  | 2.5 → 29.9 GB/s <br/> + 1,096 % |
+| `int8`     | 25.8 → 47.1 GB/s <br/> + 83 % | 33.1 → 65.3 GB/s <br/> + 97 %    | 35.2 → 43.5 GB/s <br/> + 24 %   |
+| `uint8`    |                               | 32.5 → 66.5 GB/s <br/> + 105 %   |                                 |
 
 Similar speedups are often observed even when compared to BLAS and LAPACK libraries underlying most numerical computing libraries, including NumPy and SciPy in Python.
 Broader benchmarking results:
@@ -115,8 +115,8 @@ Broader benchmarking results:
 
 The package is intended to replace the usage of `numpy.inner`, `numpy.dot`, and `scipy.spatial.distance`.
 Aside from drastic performance improvements, SimSIMD significantly improves accuracy in mixed precision setups.
-NumPy and SciPy, processing `i8`, `u8` or `f16` vectors, will use the same types for accumulators, while SimSIMD can combine `i8` enumeration, `i16` multiplication, and `i32` accumulation to avoid overflows entirely.
-The same applies to processing `f16` and `bf16` values with `f32` precision.
+NumPy and SciPy, processing `int8`, `uint8`, or `float16` vectors, will use the same types for accumulators, while SimSIMD can combine `int8` enumeration, `int16` multiplication, and `int32` accumulation to avoid overflows entirely.
+The same applies to processing `float16` and `bfloat16` values with `float32` precision.
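
To make the accumulation difference concrete, a minimal sketch along these lines should reproduce it, using the `simsimd.dot` binding and deliberately saturated inputs:

```py
import numpy as np
import simsimd

a = np.full(1536, 100, dtype=np.int8)
b = np.full(1536, 100, dtype=np.int8)

# NumPy keeps the accumulator in the input type, so each 100 * 100
# product wraps around the int8 range and the sum is meaningless
overflowed = np.dot(a, b)

# SimSIMD multiplies in int16 and accumulates in int32, so the
# inner product is exact: 100 * 100 * 1536 == 15_360_000
exact = simsimd.dot(a, b)
```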
 
 ### Installation
 
@@ -155,14 +155,33 @@ dist = simsimd.vdot(vec1.astype(np.complex64), vec2.astype(np.complex64)) # conj
 ```
 
 Unlike SciPy, SimSIMD allows explicitly stating the precision of the input vectors, which is especially useful for mixed-precision setups.
+The `dtype` argument can be passed both by name and as a positional argument:
 
 ```py
-dist = simsimd.cosine(vec1, vec2, "i8")
-dist = simsimd.cosine(vec1, vec2, "f16")
-dist = simsimd.cosine(vec1, vec2, "f32")
-dist = simsimd.cosine(vec1, vec2, "f64")
-dist = simsimd.hamming(vec1, vec2, "bits")
-dist = simsimd.jaccard(vec1, vec2, "bits")
+dist = simsimd.cosine(vec1, vec2, "int8")
+dist = simsimd.cosine(vec1, vec2, "float16")
+dist = simsimd.cosine(vec1, vec2, "float32")
+dist = simsimd.cosine(vec1, vec2, "float64")
+dist = simsimd.hamming(vec1, vec2, "bin8")
+```
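
Since the argument is also accepted by name, the keyword form of the same call would look like this (a trivial sketch, reusing `vec1` and `vec2` from the snippet above):

```py
dist = simsimd.cosine(vec1, vec2, dtype="float16")  # same as the positional form
```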
+
+With other frameworks, like PyTorch, one can get a richer type system than NumPy, but the lack of good CPython interoperability makes it hard to pass data without copies.
+
+```py
+import numpy as np
+buf1 = np.empty(8, dtype=np.uint16)
+buf2 = np.empty(8, dtype=np.uint16)
+
+# View the same memory region with PyTorch and randomize it
+import torch
+vec1 = torch.asarray(memoryview(buf1), copy=False).view(torch.bfloat16)
+vec2 = torch.asarray(memoryview(buf2), copy=False).view(torch.bfloat16)
+torch.randn(8, out=vec1)
+torch.randn(8, out=vec2)
+
+# Both libraries look into the same memory buffers and report the same result
+dist_slow = 1 - torch.nn.functional.cosine_similarity(vec1, vec2, dim=0)
+dist_fast = simsimd.cosine(buf1, buf2, "bf16")
 ```
 
 It also allows using SimSIMD for half-precision complex numbers, which NumPy does not support.
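
A minimal sketch of what that can look like, assuming the `complex32` type name from the benchmark options above and a plain `float16` buffer holding interleaved real and imaginary parts:

```py
import numpy as np
import simsimd

# NumPy has no half-precision complex dtype, so keep the
# interleaved real/imaginary parts in a float16 buffer
vec1 = np.random.randn(1536).astype(np.float16)
vec2 = np.random.randn(1536).astype(np.float16)

# Ask SimSIMD to interpret each buffer as 768 half-precision complex numbers
dist = simsimd.dot(vec1, vec2, "complex32")
```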
@@ -235,6 +254,48 @@ distances: DistancesTensor = simsimd.cdist(matrix1, matrix2, metric="cosine")
 distances_array: np.ndarray = np.array(distances, copy=True) # now managed by NumPy
 ```
 
+### Elementwise Kernels
+
+SimSIMD also provides mixed-precision elementwise kernels, where the input vectors and the output have the same numeric type, but the intermediate accumulators are of a higher precision.
+
+```py
+import numpy as np
+from simsimd import fma, wsum
+
+# Let's take two FullHD video frames
+first_frame = np.random.randint(0, 256, size=1920 * 1080, dtype=np.uint8)
+second_frame = np.random.randint(0, 256, size=1920 * 1080, dtype=np.uint8)
+average_frame = np.empty_like(first_frame)
+wsum(first_frame, second_frame, alpha=0.5, beta=0.5, out=average_frame)
+
+# Slow analog with NumPy:
+slow_average_frame = (0.5 * first_frame + 0.5 * second_frame).astype(np.uint8)
+```
+
+Similarly, `fma` takes three arguments and computes the fused multiply-add operation.
+In applications like Machine Learning, you may also benefit from using the "brain-float" format, which is not natively supported by NumPy.
+In 3D Graphics, for example, we can use FMA to compute the [Phong shading model](https://en.wikipedia.org/wiki/Phong_shading):
+
+```py
+# Assume a FullHD frame with random values for simplicity
+light_intensity = np.random.rand(1920 * 1080).astype(np.float16)    # Intensity of light on each pixel
+diffuse_component = np.random.rand(1920 * 1080).astype(np.float16)  # Diffuse reflectance on the surface
+specular_component = np.random.rand(1920 * 1080).astype(np.float16) # Specular reflectance for highlights
+output_color = np.empty_like(light_intensity)                       # Array to store the resulting color intensity
+
+# Define the scaling factors for diffuse and specular contributions
+alpha = 0.7 # Weight for the diffuse component
+beta = 0.3  # Weight for the specular component
+
+# Formula: color = alpha * light_intensity * diffuse_component + beta * specular_component
+fma(light_intensity, diffuse_component, specular_component,
+    dtype="float16", # Optional, unless it can't be inferred from the input
+    alpha=alpha, beta=beta, out=output_color)
+
+# Slow analog with NumPy for comparison
+slow_output_color = (alpha * light_intensity * diffuse_component + beta * specular_component).astype(np.float16)
+```
+
 ### Multithreading and Memory Usage
 
 By default, computations use a single CPU core.
@@ -248,15 +309,15 @@ matrix1 = np.packbits(np.random.randint(2, size=(10_000, ndim)).astype(np.uint8)
 matrix2 = np.packbits(np.random.randint(2, size=(1_000, ndim)).astype(np.uint8))
 
 distances = simsimd.cdist(matrix1, matrix2,
-    metric="hamming",  # Unlike SciPy, SimSIMD doesn't divide by the number of dimensions
-    out_dtype="u8",    # so we can use `u8` instead of `f64` to save memory.
-    threads=0,         # Use all CPU cores with OpenMP.
-    dtype="b8",        # Override input argument type to `b8` eight-bit words.
+    metric="hamming",   # Unlike SciPy, SimSIMD doesn't divide by the number of dimensions
+    out_dtype="uint8",  # so we can use `uint8` instead of `float64` to save memory.
+    threads=0,          # Use all CPU cores with OpenMP.
+    dtype="bin8",       # Override input argument type to `bin8` eight-bit words.
 )
 ```
 
-By default, the output distances will be stored in double-precision `f64` floating-point numbers.
-That behavior may not be space-efficient, especially if you are computing the hamming distance between short binary vectors, that will generally fit into 8x smaller `u8` or `u16` types.
+By default, the output distances will be stored in double-precision `float64` floating-point numbers.
+That behavior may not be space-efficient, especially when computing the Hamming distance between short binary vectors, which will generally fit into 8x smaller `uint8` or `uint16` types.
 To override this behavior, use the `out_dtype` argument.
 
 ### Helper Functions
@@ -575,7 +636,7 @@ Simplest of all, you can include the headers, and the compiler will automaticall
 int main() {
     simsimd_f32_t vector_a[1536];
     simsimd_f32_t vector_b[1536];
-    simsimd_metric_punned_t distance_function = simsimd_metric_punned(
+    simsimd_kernel_punned_t distance_function = simsimd_metric_punned(
        simsimd_metric_cos_k,   // Metric kind, like the angular cosine distance
        simsimd_datatype_f32_k, // Data type, like: f16, f32, f64, i8, b8, and complex variants
        simsimd_cap_any_k);     // Which CPU capabilities are we allowed to use
@@ -663,7 +724,6 @@ int main() {
     simsimd_vdot_f16c(f16s, f16s, 1536, &distance);
     simsimd_vdot_f32c(f32s, f32s, 1536, &distance);
     simsimd_vdot_f64c(f64s, f64s, 1536, &distance);
-
     return 0;
 }
 ```
@@ -676,13 +736,8 @@ int main() {
 int main() {
     simsimd_b8_t b8s[1536 / 8]; // 8 bits per word
     simsimd_distance_t distance;
-
-    // Hamming distance between two vectors
     simsimd_hamming_b8(b8s, b8s, 1536 / 8, &distance);
-
-    // Jaccard distance between two vectors
     simsimd_jaccard_b8(b8s, b8s, 1536 / 8, &distance);
-
     return 0;
 }
 ```
@@ -707,7 +762,6 @@ int main() {
     simsimd_kl_f16(f16s, f16s, 1536, &distance);
     simsimd_kl_f32(f32s, f32s, 1536, &distance);
     simsimd_kl_f64(f64s, f64s, 1536, &distance);
-
     return 0;
 }
 ```
@@ -949,10 +1003,10 @@ In NumPy terms, the implementation may look like:
 
 ```py
 import numpy as np
-def wsum(A: np.ndarray, B: np.ndarray, Alpha: float, Beta: float) -> np.ndarray:
+def wsum(A: np.ndarray, B: np.ndarray, /, Alpha: float, Beta: float) -> np.ndarray:
     assert A.dtype == B.dtype, "Input types must match and affect the output style"
     return (Alpha * A + Beta * B).astype(A.dtype)
-def fma(A: np.ndarray, B: np.ndarray, C: np.ndarray, Alpha: float, Beta: float) -> np.ndarray:
+def fma(A: np.ndarray, B: np.ndarray, C: np.ndarray, /, Alpha: float, Beta: float) -> np.ndarray:
     assert A.dtype == B.dtype and A.dtype == C.dtype, "Input types must match and affect the output style"
     return (Alpha * A * B + Beta * C).astype(A.dtype)
 ```
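
To sanity-check the native kernels against these reference definitions, a sketch along these lines should work, assuming the `alpha`/`beta`/`out` keywords shown in the Python examples above and a loose tolerance for `float32` rounding:

```py
import numpy as np
import simsimd

A = np.random.randn(1536).astype(np.float32)
B = np.random.randn(1536).astype(np.float32)
C = np.random.randn(1536).astype(np.float32)
out = np.empty_like(A)

# Fused kernel vs. the NumPy reference defined above
simsimd.fma(A, B, C, alpha=0.7, beta=0.3, out=out)
ref = (0.7 * A * B + 0.3 * C).astype(A.dtype)
assert np.allclose(out, ref, atol=1e-5)
```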
@@ -1095,7 +1149,7 @@ All of the function names follow the same pattern: `simsimd_{function}_{type}_{b
 - The type can be `f64`, `f32`, `f16`, `bf16`, `f64c`, `f32c`, `f16c`, `bf16c`, `i8`, or `b8`.
 - The function can be `dot`, `vdot`, `cos`, `l2sq`, `hamming`, `jaccard`, `kl`, `js`, or `intersect`.
 
-To avoid hard-coding the backend, you can use the `simsimd_metric_punned_t` to pun the function pointer and the `simsimd_capabilities` function to get the available backends at runtime.
+To avoid hard-coding the backend, you can use the `simsimd_kernel_punned_t` to pun the function pointer and the `simsimd_capabilities` function to get the available backends at runtime.
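
The same runtime introspection is available from Python; a minimal sketch, assuming the `get_capabilities` helper exposed by the Python binding:

```py
import simsimd

# Maps ISA extension names to availability flags,
# e.g. {"serial": 1, "neon": 1, "sve": 0, ...}
print(simsimd.get_capabilities())
```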
 To match all the function names, consider a RegEx:
 
 ```regex
