-
-
Notifications
You must be signed in to change notification settings - Fork 1k
Add standard deviation (std) to KMeans. #5013
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| +1 −1 | testsuite/meta/clustering/kmeans.dat |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -130,6 +130,45 @@ void KMeansBase::compute_cluster_variances() | |
| } | ||
| } | ||
|
|
||
| SGMatrix<float64_t> KMeansBase::compute_std_dev() const | ||
| { | ||
| require(cluster_centers.size() > 0, "KMeans is not trained!"); | ||
|
|
||
| SGMatrix<float64_t> points = distance->get_rhs() | ||
| ->as<DenseFeatures<float64_t>>() | ||
| ->get_feature_matrix(); | ||
| SGVector<float64_t> cluster_assignments = const_cast<KMeansBase*>(this) | ||
| ->apply() | ||
| ->as<MulticlassLabels>() | ||
| ->get_labels(); | ||
|
|
||
| SGVector<int32_t> counts(k); | ||
| SGMatrix<float64_t> means = cluster_centers.clone(); | ||
| SGMatrix<float64_t> squares_sums(dimensions, k); | ||
|
|
||
| for (int32_t point_number : range(cluster_assignments.vlen)) | ||
| { | ||
| auto cluster_number = | ||
| static_cast<int32_t>(cluster_assignments[point_number]); | ||
| const auto& point = points.get_column(point_number); | ||
| auto& count = counts[cluster_number]; | ||
| auto mean = means.get_column(cluster_number); | ||
gf712 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| auto squares_sum = squares_sums.get_column(cluster_number); | ||
|
|
||
| count += 1; | ||
| auto delta1 = linalg::add(point, mean, 1., -1.); | ||
| linalg::add(mean, linalg::scale(delta1, 1. / count), mean); | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can we compute the mean differently @karlnapf? The issue is just that you are doing a division in a loop.. Is there no geometric series, or something else, that could be cheaper to compute?
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this is from https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm @vinnik-dmitry07 instead of implementing this online algorithm in here, could you move this somewhere else, so that it is also usable from other parts of the code? I envision an updater being instantiated and then repeatedly being called, a bit like an iterator
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sure, we would just have a class with a update(datapoint) public class function. Then can have some derived classes for mean and variance (unless you want to combine them in a single class). Not sure what other algos would be useful to have online update for?
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yes definitely useful! The class should support both mean and variance but each should be optional (one wants one without the other one) |
||
| auto delta2 = linalg::add(point, mean, 1., -1.); | ||
| linalg::add( | ||
| squares_sum, linalg::element_prod(delta1, delta2), squares_sum); | ||
| } | ||
|
|
||
| linalg::scale(squares_sums, squares_sums, 1. / (points.num_cols - 1)); | ||
| for (float64_t& x : squares_sums) | ||
| x = std::sqrt(x); | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. is there no linalg for this? Maybe rewrite as std::transform, to be more idiomatic
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ++ for linalg... could send a separate pr for adding an elementwise sqrt. We had a few elementwise operations added recently, you could check those
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not sure if there is power function and could just have exponent 0.5?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. thanks for that other PR |
||
| return squares_sums; | ||
| } | ||
|
|
||
| void KMeansBase::initialize_training(const std::shared_ptr<Features>& data) | ||
| { | ||
| require(distance, "Distance is not provided"); | ||
|
|
@@ -153,7 +192,6 @@ void KMeansBase::initialize_training(const std::shared_ptr<Features>& data) | |
| require(lhs, "Lhs features of distance not provided"); | ||
| int32_t lhs_size=lhs->get_num_vectors(); | ||
| dimensions=lhs->get_num_features(); | ||
| const int32_t centers_size=dimensions*k; | ||
|
|
||
| require(lhs_size>0, "Lhs features should not be empty"); | ||
| require(dimensions>0, "Lhs features should have more than zero dimensions"); | ||
|
|
@@ -318,6 +356,7 @@ void KMeansBase::init() | |
| &use_kmeanspp, "kmeanspp", "Whether to use kmeans++", | ||
| ParameterProperties::HYPER | ParameterProperties::SETTING); | ||
| watch_method("cluster_centers", &KMeansBase::get_cluster_centers); | ||
| watch_method("std_dev", &KMeansBase::compute_std_dev); | ||
| SG_ADD( | ||
| &initial_centers, "initial_centers", "Initial centers", | ||
| ParameterProperties::HYPER); | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
please no
const_cast. @karlnapf this is probably more of an indication that we need to redesign KMeans no?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
probably, but can't we avoid the cast otherwise?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we have to either make this function non const or make apply const
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Making it nonconst breaks watch_method. I tried to make apply const but it causes continuous changes in the derived classes of DistanceMachine. I was not sure that these changes will not bring something bad and decided to use the most obvious solution.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What do you mean it breaks? I think it should work...
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
run is void, no?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don’t think so. Check in SGObject
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hmm, you’re right. Not sure then :D
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
void run(std::string_view name) const noexcept(false)Did I miss something?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
no, I just got confused, sorry! what @karlnapf suggested should work though! :)