Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
173 changes: 143 additions & 30 deletions faiss/utils/simd_impl/distances_arm_sve.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,22 @@ struct ElementOpIP {
}
};

struct ElementOpL2 {
static svfloat32_t op(svbool_t pg, svfloat32_t x, svfloat32_t y) {
const svfloat32_t diff = svsub_f32_x(pg, x, y);
return svmul_f32_x(pg, diff, diff);
}

static svfloat32_t merge(
svbool_t pg,
svfloat32_t z,
svfloat32_t x,
svfloat32_t y) {
const svfloat32_t diff = svsub_f32_x(pg, x, y);
return svmla_f32_x(pg, z, diff, diff);
}
};

template <typename ElementOp>
void fvec_op_ny_sve_d1(float* dis, const float* x, const float* y, size_t ny) {
const size_t lanes = svcntw();
Expand Down Expand Up @@ -513,10 +529,38 @@ void fvec_L2sqr_ny<SIMDLevel::ARM_SVE>(
const float* y,
size_t d,
size_t ny) {
// Use autovectorized L2sqr in a loop
for (size_t i = 0; i < ny; i++) {
dis[i] = fvec_L2sqr<SIMDLevel::ARM_SVE>(x, y, d);
y += d;

const size_t lanes = static_cast<size_t>(svcntw());

switch (d) {
case 1:
fvec_op_ny_sve_d1<ElementOpL2>(dis, x, y, ny);
break;

case 2:
fvec_op_ny_sve_d2<ElementOpL2>(dis, x, y, ny);
break;

case 4:
fvec_op_ny_sve_d4<ElementOpL2>(dis, x, y, ny);
break;

case 8:
fvec_op_ny_sve_d8<ElementOpL2>(dis, x, y, ny);
break;

default:
if (d == lanes)
fvec_op_ny_sve_lanes1<ElementOpL2>(dis, x, y, ny);
else if (d == lanes * 2)
fvec_op_ny_sve_lanes2<ElementOpL2>(dis, x, y, ny);
else if (d == lanes * 3)
fvec_op_ny_sve_lanes3<ElementOpL2>(dis, x, y, ny);
else if (d == lanes * 4)
fvec_op_ny_sve_lanes4<ElementOpL2>(dis, x, y, ny);
else
fvec_L2sqr_ny_ref(dis, x, y, d, ny);
break;
}
}

Expand All @@ -527,22 +571,45 @@ size_t fvec_L2sqr_ny_nearest<SIMDLevel::ARM_SVE>(
const float* y,
size_t d,
size_t ny) {
fvec_L2sqr_ny<SIMDLevel::ARM_SVE>(distances_tmp_buffer, x, y, d, ny);

const size_t lanes = static_cast<size_t>(svcntw());

size_t nearest_idx = 0;
float min_dis = HUGE_VALF;

for (size_t i = 0; i < ny; i++) {
if (distances_tmp_buffer[i] < min_dis) {
min_dis = distances_tmp_buffer[i];
for (size_t i = 0; i < ny; ++i) {
const float* yi = y + i * d;
size_t j = 0;

svfloat32_t accv = svdup_n_f32(0.0f);

for (; j + lanes <= d; j += lanes) {
const svbool_t pg = svptrue_b32();
const svfloat32_t xv = svld1_f32(pg, x + j);
const svfloat32_t yv = svld1_f32(pg, yi + j);
const svfloat32_t diff = svsub_f32_x(pg, xv, yv);
accv = svmla_f32_x(pg, accv, diff, diff);
}

if (j < d) {
const svbool_t pg = svwhilelt_b32_u64(j, d);
const svfloat32_t xv = svld1_f32(pg, x + j);
const svfloat32_t yv = svld1_f32(pg, yi + j);
const svfloat32_t diff = svsub_f32_x(pg, xv, yv);
accv = svmla_f32_x(pg, accv, diff, diff);
}

const float dist = svaddv_f32(svptrue_b32(), accv);
distances_tmp_buffer[i] = dist;
if (dist < min_dis) {
min_dis = dist;
nearest_idx = i;
}
}

return nearest_idx;
}

FAISS_PRAGMA_IMPRECISE_FUNCTION_BEGIN
template <>
void fvec_L2sqr_ny_transposed<SIMDLevel::ARM_SVE>(
float* dis,
Expand All @@ -552,22 +619,33 @@ void fvec_L2sqr_ny_transposed<SIMDLevel::ARM_SVE>(
size_t d,
size_t d_offset,
size_t ny) {
float x_sqlen = 0;
FAISS_PRAGMA_IMPRECISE_LOOP
for (size_t j = 0; j < d; j++) {
x_sqlen += x[j] * x[j];
}

for (size_t i = 0; i < ny; i++) {
float dp = 0;
FAISS_PRAGMA_IMPRECISE_LOOP
for (size_t j = 0; j < d; j++) {
dp += x[j] * y[i + j * d_offset];
const size_t lanes = static_cast<size_t>(svcntw());
const float x_sq = fvec_norm_L2sqr(x, d);

for (size_t k = 0; k < ny; k += lanes) {
svbool_t pg = svwhilelt_b32_u64(k, ny);
svfloat32_t acc = svdup_n_f32(0.0f);

for (size_t j = 0; j < d; ++j) {
int32_t start_offset = static_cast<int32_t>(j * d_offset);
int32_t stride = 1;
svint32_t offset_vec = svindex_s32(start_offset, stride);

const float *ybase = &y[k];
svfloat32_t ychunk = svld1_gather_index(pg, ybase, offset_vec);

svfloat32_t xj = svdup_n_f32(x[j]);
acc = svmla_f32_x(pg, acc, xj, ychunk);
}
dis[i] = x_sqlen + y_sqlen[i] - 2 * dp;

svfloat32_t ysq = svld1_f32(pg, y_sqlen + k);
svfloat32_t two_acc = svmul_f32_x(pg, acc, svdup_n_f32(2.0f));
svfloat32_t sum = svadd_f32_x(pg, svdup_n_f32(x_sq), ysq);
svfloat32_t res = svsub_f32_x(pg, sum, two_acc);
svst1_f32(pg, dis + k, res);
}
}
FAISS_PRAGMA_IMPRECISE_FUNCTION_END

template <>
size_t fvec_L2sqr_ny_nearest_y_transposed<SIMDLevel::ARM_SVE>(
Expand All @@ -578,20 +656,55 @@ size_t fvec_L2sqr_ny_nearest_y_transposed<SIMDLevel::ARM_SVE>(
size_t d,
size_t d_offset,
size_t ny) {
fvec_L2sqr_ny_transposed<SIMDLevel::ARM_SVE>(
distances_tmp_buffer, x, y, y_sqlen, d, d_offset, ny);

size_t nearest_idx = 0;
float min_dis = HUGE_VALF;
const size_t lanes = svcntw();
const float x_sq = fvec_norm_L2sqr(x, d);

for (size_t i = 0; i < ny; i++) {
if (distances_tmp_buffer[i] < min_dis) {
min_dis = distances_tmp_buffer[i];
nearest_idx = i;
float current_min = HUGE_VALF;
size_t current_min_idx = 0;
svfloat32_t current_min_v = svdup_n_f32(HUGE_VALF);

// Max SVE vector length is 2048 bits → 64 fp32 lanes
float tmp_buf[64];

for (size_t k = 0; k < ny; k += lanes) {
svbool_t pg = svwhilelt_b32(k, ny);
svfloat32_t acc = svdup_n_f32(0.0f);

for (size_t j = 0; j < d; ++j) {
svfloat32_t ychunk = svld1_f32(pg, y + j * d_offset + k);
svfloat32_t xj = svdup_n_f32(x[j]);
acc = svmla_f32_x(pg, acc, xj, ychunk);
}

svfloat32_t ysq = svld1_f32(pg, y_sqlen + k);
svfloat32_t two_acc = svmul_f32_x(pg, acc, svdup_n_f32(2.0f));
svfloat32_t sum = svadd_f32_x(pg, svdup_n_f32(x_sq), ysq);
svfloat32_t res = svsub_f32_x(pg, sum, two_acc);

svst1_f32(pg, distances_tmp_buffer + k, res);

svbool_t less_mask = svcmplt_f32(pg, res, current_min_v);

if (svptest_any(pg, less_mask)) {
float vec_min = svminv_f32(pg, res);

current_min_v = svmin_f32_x(pg, current_min_v, svdup_n_f32(vec_min));

svst1_f32(pg, tmp_buf, res);

size_t cnt = (size_t) svcntw();
for (size_t lane = 0; lane < cnt && (k + lane) < ny; ++lane) {
if (tmp_buf[lane] == vec_min) {
current_min = vec_min;
current_min_idx = k + lane;
break;
}
}
}
}

return nearest_idx;
return current_min_idx;
}

template <>
Expand Down
Loading