Skip to content

Commit 23ec785

Browse files
Add Euclidean label assignment regression test
1 parent 409f810 commit 23ec785

File tree

2 files changed

+51
-6
lines changed

2 files changed

+51
-6
lines changed

src/kmeans.rs

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -672,10 +672,33 @@ pub fn minibatch_fit(
672672

673673
#[cfg(test)]
674674
mod tests {
675-
use super::{MiniBatchState, minibatch_partial_fit};
675+
use super::{MiniBatchState, assign_labels, minibatch_partial_fit};
676676
use crate::errors::ClustorError;
677677
use crate::metrics::Metric;
678678

679+
#[test]
680+
fn assign_labels_euclidean_assigns_each_sample_independently() {
681+
let data = vec![0.0, 0.0, 9.0, 9.0];
682+
let centers = vec![0.0, 0.0, 10.0, 10.0];
683+
let mut labels = vec![usize::MAX; 2];
684+
685+
let inertia = assign_labels(
686+
&data,
687+
None,
688+
&centers,
689+
None,
690+
None,
691+
2,
692+
2,
693+
2,
694+
Metric::Euclidean,
695+
&mut labels,
696+
);
697+
698+
assert_eq!(labels, vec![0, 1]);
699+
assert!((inertia - 2.0).abs() < 1e-12);
700+
}
701+
679702
#[test]
680703
fn minibatch_partial_fit_keeps_cosine_norms_in_sync_with_center_updates() {
681704
let mut state = MiniBatchState {

src/metrics.rs

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,13 @@ pub enum Metric {
1212

1313
impl Metric {
1414
pub fn parse(s: &str) -> Option<Self> {
15-
match s.to_ascii_lowercase().as_str() {
16-
"euclidean" | "l2" => Some(Metric::Euclidean),
17-
"cosine" => Some(Metric::Cosine),
18-
_ => None,
15+
let bytes = s.as_bytes();
16+
if bytes.eq_ignore_ascii_case(b"l2") || bytes.eq_ignore_ascii_case(b"euclidean") {
17+
Some(Metric::Euclidean)
18+
} else if bytes.eq_ignore_ascii_case(b"cosine") {
19+
Some(Metric::Cosine)
20+
} else {
21+
None
1922
}
2023
}
2124
}
@@ -91,7 +94,26 @@ pub fn normalize_in_place(v: &mut [f64]) {
9194

9295
#[cfg(test)]
9396
mod tests {
94-
use super::cosine_distance;
97+
use super::{Metric, cosine_distance};
98+
99+
#[test]
100+
fn metric_parse_accepts_aliases_case_insensitively() {
101+
assert_eq!(Metric::parse("EUCLIDEAN"), Some(Metric::Euclidean));
102+
assert_eq!(Metric::parse("L2"), Some(Metric::Euclidean));
103+
assert_eq!(Metric::parse("CoSiNe"), Some(Metric::Cosine));
104+
}
105+
106+
#[test]
107+
fn metric_parse_rejects_unknown_values() {
108+
assert_eq!(Metric::parse("manhattan"), None);
109+
assert_eq!(Metric::parse(""), None);
110+
assert_eq!(Metric::parse(" euclidean"), None);
111+
}
112+
113+
#[test]
114+
fn metric_parse_rejects_non_ascii_confusables() {
115+
assert_eq!(Metric::parse("cоsine"), None); // Cyrillic small o
116+
}
95117

96118
#[test]
97119
fn cosine_distance_handles_zero_vectors() {

0 commit comments

Comments
 (0)