Skip to content

Commit 84262f0

Browse files
alibeklfcmeta-codesync[bot]
authored andcommitted
Fix d_out check bug and add descriptive error messages to VectorTransform assertions (#5047)
Summary: Pull Request resolved: #5047 ## Summary This diff fixes a bug and improves error message quality in `VectorTransform.cpp`. ### Bug Fix (line 153) `VectorTransform::check_identical()` had a copy-paste bug where `d_in` was checked twice and `d_out` was never checked: ```cpp // Before (buggy): FAISS_THROW_IF_NOT(other.d_in == d_in && other.d_in == d_in); // After (fixed): FAISS_THROW_IF_NOT_MSG( other.d_in == d_in && other.d_out == d_out, "input and output dimensions must match"); ``` This meant two VectorTransforms with matching `d_in` but different `d_out` would incorrectly pass the identity check. This could lead to subtle bugs when comparing or serializing transform chains (e.g., in `IndexPreTransform`). ### Error Message Improvements All 28 bare `FAISS_THROW_IF_NOT()` calls in `VectorTransform.cpp` have been converted to `FAISS_THROW_IF_NOT_MSG()` with clear, actionable error messages. Previously, assertion failures would only show the raw C++ condition (e.g., `"Error: 'p > 0' failed"`), which is unhelpful for users. Now each assertion provides semantic context: - **Dynamic cast failures**: `"failed to cast to HadamardRotation"` instead of `"hr"` - **Dimension mismatches**: `"input and output dimensions must match when PCA is disabled"` instead of `"din == dout"` - **Training state errors**: `"CenteringTransform has not been trained"` instead of `"is_trained"` - **LAPACK errors**: `"LAPACK dgesvd workspace query failed"` instead of `"info == 0"` - **Parameter validation**: `"map entries must be -1 (unused) or valid input dimension indices"` instead of raw condition ### Affected classes - `VectorTransform` (base class) - `LinearTransform` - `HadamardRotation` - `PCAMatrix` - `ITQMatrix` - `ITQTransform` - `OPQMatrix` - `NormalizationTransform` - `CenteringTransform` - `RemapDimensionsTransform` ### Design decisions - Used `FAISS_THROW_IF_NOT_MSG` (not `FAISS_THROW_IF_NOT_FMT`) since all messages are static strings — no runtime formatting needed, keeping zero overhead. - Error messages follow existing Faiss patterns seen in `index_read.cpp` and other files. - Each message describes the semantic meaning of the condition, not just the code. Reviewed By: mnorris11 Differential Revision: D99674067 fbshipit-source-id: cf0fe9a8a7f047013011683d76221682d97beb6c
1 parent ef28c6b commit 84262f0

File tree

1 file changed

+63
-29
lines changed

1 file changed

+63
-29
lines changed

faiss/VectorTransform.cpp

Lines changed: 63 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,9 @@ void VectorTransform::reverse_transform(idx_t, const float*, float*) const {
150150
}
151151

152152
void VectorTransform::check_identical(const VectorTransform& other) const {
153-
FAISS_THROW_IF_NOT(other.d_in == d_in && other.d_in == d_in);
153+
FAISS_THROW_IF_NOT_MSG(
154+
other.d_in == d_in && other.d_out == d_out,
155+
"transforms must have matching d_in and d_out");
154156
}
155157

156158
/*********************************************
@@ -303,7 +305,9 @@ void LinearTransform::print_if_verbose(
303305
if (!verbose)
304306
return;
305307
printf("matrix %s: %d*%d [\n", name, n, d);
306-
FAISS_THROW_IF_NOT(mat.size() >= static_cast<size_t>(n) * d);
308+
FAISS_THROW_IF_NOT_MSG(
309+
mat.size() >= static_cast<size_t>(n) * d,
310+
"matrix size is too small for the given dimensions");
307311
for (int i = 0; i < n; i++) {
308312
for (int j = 0; j < d; j++) {
309313
printf("%10.5g ", mat[i * d + j]);
@@ -316,8 +320,10 @@ void LinearTransform::print_if_verbose(
316320
void LinearTransform::check_identical(const VectorTransform& other_in) const {
317321
VectorTransform::check_identical(other_in);
318322
auto other = dynamic_cast<const LinearTransform*>(&other_in);
319-
FAISS_THROW_IF_NOT(other);
320-
FAISS_THROW_IF_NOT(other->A == A && other->b == b);
323+
FAISS_THROW_IF_NOT_MSG(other, "failed to cast to LinearTransform");
324+
FAISS_THROW_IF_NOT_MSG(
325+
other->A == A && other->b == b,
326+
"LinearTransform matrix A and bias vector b must match");
321327
}
322328

323329
/*********************************************
@@ -390,7 +396,8 @@ static void generate_signs(
390396
std::vector<float>& s1,
391397
std::vector<float>& s2,
392398
std::vector<float>& s3) {
393-
FAISS_THROW_IF_NOT(p > 0);
399+
FAISS_THROW_IF_NOT_MSG(
400+
p > 0, "number of Hadamard factors p must be positive");
394401
SplitMix64RandomGenerator rng(seed);
395402
s1.resize(p);
396403
s2.resize(p);
@@ -426,9 +433,15 @@ void HadamardRotation::apply_noalloc(idx_t n, const float* x, float* xt) const {
426433

427434
size_t d = d_in;
428435
size_t p = d_out;
429-
FAISS_THROW_IF_NOT(signs1.size() == p);
430-
FAISS_THROW_IF_NOT(signs2.size() == p);
431-
FAISS_THROW_IF_NOT(signs3.size() == p);
436+
FAISS_THROW_IF_NOT_MSG(
437+
signs1.size() == p,
438+
"sign-flip vector 1 size must match output dimension");
439+
FAISS_THROW_IF_NOT_MSG(
440+
signs2.size() == p,
441+
"sign-flip vector 2 size must match output dimension");
442+
FAISS_THROW_IF_NOT_MSG(
443+
signs3.size() == p,
444+
"sign-flip vector 3 size must match output dimension");
432445

433446
// Each unnormalized FWHT scales norms by sqrt(p).
434447
// Three rounds scale by p^(3/2). Normalize once at the end.
@@ -468,10 +481,14 @@ void HadamardRotation::apply_noalloc(idx_t n, const float* x, float* xt) const {
468481

469482
void HadamardRotation::check_identical(const VectorTransform& other) const {
470483
auto* hr = dynamic_cast<const HadamardRotation*>(&other);
471-
FAISS_THROW_IF_NOT(hr);
472-
FAISS_THROW_IF_NOT(d_in == hr->d_in);
473-
FAISS_THROW_IF_NOT(d_out == hr->d_out);
474-
FAISS_THROW_IF_NOT(seed == hr->seed);
484+
FAISS_THROW_IF_NOT_MSG(hr, "failed to cast to HadamardRotation");
485+
FAISS_THROW_IF_NOT_MSG(
486+
d_in == hr->d_in, "HadamardRotation input dimensions must match");
487+
FAISS_THROW_IF_NOT_MSG(
488+
d_out == hr->d_out,
489+
"HadamardRotation output dimensions must match");
490+
FAISS_THROW_IF_NOT_MSG(
491+
seed == hr->seed, "HadamardRotation seeds must match");
475492
}
476493

477494
/*********************************************
@@ -731,7 +748,9 @@ void PCAMatrix::train(idx_t n, const float* x_in) {
731748
}
732749

733750
void PCAMatrix::copy_from(const PCAMatrix& other) {
734-
FAISS_THROW_IF_NOT(other.is_trained);
751+
FAISS_THROW_IF_NOT_MSG(
752+
other.is_trained,
753+
"source PCAMatrix must be trained before copying");
735754
mean = other.mean;
736755
eigenvalues = other.eigenvalues;
737756
PCAMat = other.PCAMat;
@@ -761,7 +780,9 @@ void PCAMatrix::prepare_Ab() {
761780
}
762781

763782
if (balanced_bins != 0) {
764-
FAISS_THROW_IF_NOT(d_out % balanced_bins == 0);
783+
FAISS_THROW_IF_NOT_MSG(
784+
d_out % balanced_bins == 0,
785+
"output dimension must be divisible by balanced_bins");
765786
int dsub = d_out / balanced_bins;
766787
std::vector<float> Ain;
767788
std::swap(A, Ain);
@@ -945,7 +966,8 @@ void ITQMatrix::train(idx_t n, const float* xf) {
945966
&lwork,
946967
&info);
947968

948-
FAISS_THROW_IF_NOT(info == 0);
969+
FAISS_THROW_IF_NOT_MSG(
970+
info == 0, "LAPACK dgesvd workspace query failed");
949971
lwork = size_t(lwork1);
950972
std::vector<double> work(lwork);
951973
dgesvd_("A",
@@ -1001,14 +1023,17 @@ ITQTransform::ITQTransform(int din, int dout, bool do_pca_in)
10011023
itq(dout),
10021024
pca_then_itq(din, dout, false) {
10031025
if (!do_pca_in) {
1004-
FAISS_THROW_IF_NOT(din == dout);
1026+
FAISS_THROW_IF_NOT_MSG(
1027+
din == dout,
1028+
"input and output dimensions must match when PCA is disabled");
10051029
}
10061030
max_train_per_dim = 10;
10071031
is_trained = false;
10081032
}
10091033

10101034
void ITQTransform::train(idx_t n, const float* x_in) {
1011-
FAISS_THROW_IF_NOT(!is_trained);
1035+
FAISS_THROW_IF_NOT_MSG(
1036+
!is_trained, "ITQTransform has already been trained");
10121037

10131038
size_t max_train_points = std::max(d_in * max_train_per_dim, 32768);
10141039
const float* x =
@@ -1100,9 +1125,10 @@ void ITQTransform::apply_noalloc(idx_t n, const float* x, float* xt) const {
11001125
void ITQTransform::check_identical(const VectorTransform& other_in) const {
11011126
VectorTransform::check_identical(other_in);
11021127
auto other = dynamic_cast<const ITQTransform*>(&other_in);
1103-
FAISS_THROW_IF_NOT(other);
1128+
FAISS_THROW_IF_NOT_MSG(other, "failed to cast to ITQTransform");
11041129
pca_then_itq.check_identical(other->pca_then_itq);
1105-
FAISS_THROW_IF_NOT(other->mean == mean);
1130+
FAISS_THROW_IF_NOT_MSG(
1131+
other->mean == mean, "ITQTransform mean vectors must match");
11061132
}
11071133

11081134
/*********************************************
@@ -1184,7 +1210,8 @@ void OPQMatrix::train(idx_t n, const float* x_in) {
11841210
// we use only the d * d2 upper part of the matrix
11851211
A.resize(d * d2);
11861212
} else {
1187-
FAISS_THROW_IF_NOT(A.size() == d * d2);
1213+
FAISS_THROW_IF_NOT_MSG(
1214+
A.size() == d * d2, "rotation matrix A has incorrect size");
11881215
rotation = A.data();
11891216
}
11901217

@@ -1360,8 +1387,9 @@ void NormalizationTransform::check_identical(
13601387
const VectorTransform& other_in) const {
13611388
VectorTransform::check_identical(other_in);
13621389
auto other = dynamic_cast<const NormalizationTransform*>(&other_in);
1363-
FAISS_THROW_IF_NOT(other);
1364-
FAISS_THROW_IF_NOT(other->norm == norm);
1390+
FAISS_THROW_IF_NOT_MSG(other, "failed to cast to NormalizationTransform");
1391+
FAISS_THROW_IF_NOT_MSG(
1392+
other->norm == norm, "normalization type must match");
13651393
}
13661394

13671395
/*********************************************
@@ -1389,7 +1417,8 @@ void CenteringTransform::train(idx_t n, const float* x) {
13891417

13901418
void CenteringTransform::apply_noalloc(idx_t n, const float* x, float* xt)
13911419
const {
1392-
FAISS_THROW_IF_NOT(is_trained);
1420+
FAISS_THROW_IF_NOT_MSG(
1421+
is_trained, "CenteringTransform has not been trained");
13931422

13941423
for (idx_t i = 0; i < n; i++) {
13951424
for (int j = 0; j < d_in; j++) {
@@ -1400,7 +1429,8 @@ void CenteringTransform::apply_noalloc(idx_t n, const float* x, float* xt)
14001429

14011430
void CenteringTransform::reverse_transform(idx_t n, const float* xt, float* x)
14021431
const {
1403-
FAISS_THROW_IF_NOT(is_trained);
1432+
FAISS_THROW_IF_NOT_MSG(
1433+
is_trained, "CenteringTransform has not been trained");
14041434

14051435
for (idx_t i = 0; i < n; i++) {
14061436
for (int j = 0; j < d_in; j++) {
@@ -1413,8 +1443,9 @@ void CenteringTransform::check_identical(
14131443
const VectorTransform& other_in) const {
14141444
VectorTransform::check_identical(other_in);
14151445
auto other = dynamic_cast<const CenteringTransform*>(&other_in);
1416-
FAISS_THROW_IF_NOT(other);
1417-
FAISS_THROW_IF_NOT(other->mean == mean);
1446+
FAISS_THROW_IF_NOT_MSG(other, "failed to cast to CenteringTransform");
1447+
FAISS_THROW_IF_NOT_MSG(
1448+
other->mean == mean, "CenteringTransform mean vectors must match");
14181449
}
14191450

14201451
/*********************************************
@@ -1429,7 +1460,9 @@ RemapDimensionsTransform::RemapDimensionsTransform(
14291460
map.resize(dout);
14301461
for (int i = 0; i < dout; i++) {
14311462
map[i] = map_in[i];
1432-
FAISS_THROW_IF_NOT(map[i] == -1 || (map[i] >= 0 && map[i] < din));
1463+
FAISS_THROW_IF_NOT_MSG(
1464+
map[i] == -1 || (map[i] >= 0 && map[i] < din),
1465+
"map entries must be -1 (unused) or valid input dimension indices");
14331466
}
14341467
}
14351468

@@ -1486,6 +1519,7 @@ void RemapDimensionsTransform::check_identical(
14861519
const VectorTransform& other_in) const {
14871520
VectorTransform::check_identical(other_in);
14881521
auto other = dynamic_cast<const RemapDimensionsTransform*>(&other_in);
1489-
FAISS_THROW_IF_NOT(other);
1490-
FAISS_THROW_IF_NOT(other->map == map);
1522+
FAISS_THROW_IF_NOT_MSG(other, "failed to cast to RemapDimensionsTransform");
1523+
FAISS_THROW_IF_NOT_MSG(
1524+
other->map == map, "RemapDimensionsTransform maps must match");
14911525
}

0 commit comments

Comments
 (0)