Skip to content

Commit 062f390

Browse files
committed
Fix: Missing Python APIs for probability dists
In one of the previous commits the `kullbackleibler` and `jensenshannon` interfaces were removed.
1 parent 1f228c0 commit 062f390

File tree

2 files changed

+34
-4
lines changed

2 files changed

+34
-4
lines changed

python/bench.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -212,8 +212,8 @@ def raise_(ex):
212212
[np.float64, np.float32, np.float16],
213213
np.array,
214214
),
215-
("scipy.hamming", spd.hamming, simd.hamming, [np.uint8], np.array),
216-
("scipy.jaccard", spd.jaccard, simd.jaccard, [np.uint8], np.array),
215+
("scipy.hamming", spd.hamming, lambda a, b: simd.hamming(a, b, "b8"), [np.uint8], np.array),
216+
("scipy.jaccard", spd.jaccard, lambda a, b: simd.jaccard(a, b, "b8"), [np.uint8], np.array),
217217
]
218218
)
219219

@@ -348,8 +348,8 @@ def raise_(ex):
348348
simd.kullbackleibler,
349349
[np.float64, np.float32, np.float16],
350350
),
351-
("scipy.hamming", wrap_rowwise(spd.hamming), simd.hamming, [np.uint8]),
352-
("scipy.jaccard", wrap_rowwise(spd.jaccard), simd.jaccard, [np.uint8]),
351+
("scipy.hamming", wrap_rowwise(spd.hamming), lambda a, b: simd.hamming(a, b, "b8"), [np.uint8]),
352+
("scipy.jaccard", wrap_rowwise(spd.jaccard), lambda a, b: simd.jaccard(a, b, "b8"), [np.uint8]),
353353
]
354354
)
355355

python/lib.c

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1236,6 +1236,36 @@ static PyMethodDef simsimd_methods[] = {
12361236
"Notes:\n"
12371237
" * `a` and `b` are positional-only arguments, while `dtype` is a keyword-only argument.",
12381238
},
1239+
{
1240+
"kullbackleibler",
1241+
(PyCFunction)api_kl,
1242+
METH_FASTCALL,
1243+
"Compute Kullback-Leibler divergences between two matrices.\n\n"
1244+
"Args:\n"
1245+
" a (NDArray): First floating-point matrix or vector.\n"
1246+
" b (NDArray): Second floating-point matrix or vector.\n"
1247+
" dtype (IntegralType, optional): Override the presumed input type.\n\n"
1248+
"Returns:\n"
1249+
" DistancesTensor: The Kullback-Leibler divergences distances.\n\n"
1250+
"Equivalent to: `scipy.special.kl_div`.\n"
1251+
"Notes:\n"
1252+
" * `a` and `b` are positional-only arguments, while `dtype` is a keyword-only argument.",
1253+
},
1254+
{
1255+
"jensenshannon",
1256+
(PyCFunction)api_js,
1257+
METH_FASTCALL,
1258+
"Compute Jensen-Shannon divergences between two matrices.\n\n"
1259+
"Args:\n"
1260+
" a (NDArray): First floating-point matrix or vector.\n"
1261+
" b (NDArray): Second floating-point matrix or vector.\n"
1262+
" dtype (IntegralType, optional): Override the presumed input type.\n\n"
1263+
"Returns:\n"
1264+
" DistancesTensor: The Jensen-Shannon divergences distances.\n\n"
1265+
"Equivalent to: `scipy.spatial.distance.jensenshannon`.\n"
1266+
"Notes:\n"
1267+
" * `a` and `b` are positional-only arguments, while `dtype` is a keyword-only argument.",
1268+
},
12391269

12401270
// Conventional `cdist` interface for pairwise distances
12411271
{

0 commit comments

Comments
 (0)