diff --git a/faiss/Clustering.cpp b/faiss/Clustering.cpp
index 918558808a..a3101a712b 100644
--- a/faiss/Clustering.cpp
+++ b/faiss/Clustering.cpp
@@ -569,7 +569,19 @@ void Clustering::train_encoded(
             if (i > 0) {
                 float prev_obj =
                         iteration_stats[iteration_stats.size() - 2].obj;
-                if (obj == prev_obj) {
+
+                // Relative decrease of the objective since the previous
+                // iteration. A zero previous objective cannot serve as a
+                // denominator, so it maps to a sentinel that never
+                // satisfies the threshold test below.
+                const double change = (prev_obj == 0)
+                        ? std::numeric_limits<double>::max()
+                        : (prev_obj - obj) / prev_obj;
+
+                // The exact-equality test keeps the pre-threshold behavior
+                // (including obj == prev_obj == 0) when the threshold is 0.
+                if (obj == prev_obj ||
+                    (change >= 0 && change <= early_stop_threshold)) {
                     if (verbose) {
                         printf("\n Converged at iteration %d: "
                                "objective did not change\n",
diff --git a/faiss/Clustering.h b/faiss/Clustering.h
index f5de76d38d..acb501bce8 100644
--- a/faiss/Clustering.h
+++ b/faiss/Clustering.h
@@ -69,6 +69,15 @@ struct ClusteringParameters {
     /// Only used when init_method = AFK_MC2.
     /// Longer chains give better approximation but are slower.
     uint16_t afkmc2_chain_length = 50;
+
+    /// Relative early-stop threshold; the expected range is [0, 1].
+    /// Training stops once the relative decrease of the objective between
+    /// two consecutive iterations, (prev - cur) / prev, falls within
+    /// [0, early_stop_threshold], or the objective is exactly unchanged.
+    /// The default of 0 keeps the original Faiss behavior: the training
+    /// process stops early only if the objective is unchanged from the
+    /// previous iteration.
+    double early_stop_threshold = 0.0;
 };
 
 struct ClusteringIterationStats {