Skip to content

Commit 00284bf

Browse files
fix: update internal cache after cluster feature reduction
1 parent 8277be6 commit 00284bf

1 file changed

Lines changed: 8 additions & 1 deletion

File tree

src/hprobes/probe.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -488,7 +488,10 @@ def fit(
488488
cluster_reps, self._cluster_info_ = _select_cluster_representatives(X_train)
489489
X_train = X_train[:, cluster_reps]
490490
X_val = X_val[:, cluster_reps]
491-
# Update top_k_idx to point to cluster representatives within original indexing
491+
self._X_train_cache = X_train
492+
self._y_train_cache = y_train
493+
self._X_val = X_val
494+
self._y_val = y_val
492495
self._top_k_cluster_reps = self._top_k_idx[cluster_reps]
493496
self._top_k_idx_original = self._top_k_idx.copy()
494497
self._top_k_idx = self._top_k_cluster_reps
@@ -742,6 +745,10 @@ def fit_from_responses(
742745
cluster_reps, self._cluster_info_ = _select_cluster_representatives(X_train)
743746
X_train = X_train[:, cluster_reps]
744747
X_val = X_val[:, cluster_reps]
748+
self._X_val = X_val
749+
self._y_val = y_val
750+
self._X_train_cache = X_train
751+
self._y_train_cache = y_train
745752
self._top_k_cluster_reps = self._top_k_idx[cluster_reps]
746753
self._top_k_idx_original = self._top_k_idx.copy()
747754
self._top_k_idx = self._top_k_cluster_reps

0 commit comments

Comments
 (0)