From 087c96e7c3c754f4b7ad923a93b2a1dacf65d620 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=BAlio=20Hoffimann?= Date: Thu, 27 Mar 2025 16:39:58 -0300 Subject: [PATCH 1/8] Initial implementation of KMedoids --- Project.toml | 2 + src/TableTransforms.jl | 6 +- src/transforms.jl | 1 + src/transforms/kmedoids.jl | 145 ++++++++++++++++++++++++++++++++++++ test/transforms.jl | 1 + test/transforms/kmedoids.jl | 4 + 6 files changed, 157 insertions(+), 2 deletions(-) create mode 100644 src/transforms/kmedoids.jl create mode 100644 test/transforms/kmedoids.jl diff --git a/Project.toml b/Project.toml index c6e7d40a..02a8dbf9 100644 --- a/Project.toml +++ b/Project.toml @@ -18,6 +18,7 @@ PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" +TableDistances = "e5d66e97-8c70-46bb-8b66-04a2d73ad782" Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" TransformsBase = "28dd2a49-a57a-4bfb-84ca-1a49db9b96b8" Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d" @@ -37,6 +38,7 @@ PrettyTables = "2" Random = "1.9" Statistics = "1.9" StatsBase = "0.33, 0.34" +TableDistances = "1.0" Tables = "1.6" TransformsBase = "1.5" Unitful = "1.17" diff --git a/src/TableTransforms.jl b/src/TableTransforms.jl index 9d9805df..e624c1d3 100644 --- a/src/TableTransforms.jl +++ b/src/TableTransforms.jl @@ -6,12 +6,13 @@ module TableTransforms using Tables using Unitful -using Statistics using PrettyTables using AbstractTrees -using LinearAlgebra +using TableDistances using DataScienceTraits using CategoricalArrays +using LinearAlgebra +using Statistics using Random using CoDa @@ -90,6 +91,7 @@ export DRS, SDS, ProjectionPursuit, + KMedoids, Closure, Remainder, Compose, diff --git a/src/transforms.jl b/src/transforms.jl index b7b92de6..e0c6082a 100644 --- a/src/transforms.jl +++ b/src/transforms.jl @@ -286,6 +286,7 @@ include("transforms/quantile.jl") include("transforms/functional.jl") include("transforms/eigenanalysis.jl") include("transforms/projectionpursuit.jl") +include("transforms/kmedoids.jl") include("transforms/closure.jl") include("transforms/remainder.jl") include("transforms/compose.jl") diff --git a/src/transforms/kmedoids.jl b/src/transforms/kmedoids.jl new file mode 100644 index 00000000..6e9dfb16 --- /dev/null +++ b/src/transforms/kmedoids.jl @@ -0,0 +1,145 @@ +# ------------------------------------------------------------------ +# Licensed under the MIT License. See LICENSE in the project root. +# ------------------------------------------------------------------ + +""" + KMedoids(k; tol=1e-4, maxiter=10, weights=nothing, rng=Random.default_rng()) + +Assign labels to rows of table using the `k`-medoids algorithm. + +The iterative algorithm is interrupted if the relative change of +the average dissimilarity between successive iterations is smaller +than a tolerance `tol` or if the number of iterations exceeds +the maximum number of iterations `maxiter`. + +Optionally, specify a dictionary of `weights` for each column to +affect the underlying table distance from TableDistances.jl, and +a random number generator `rng` to obtain reproducible results. + +## Examples + +```julia +KMedoids(3) +KMedoids(4, maxiter=20) +KMedoids(5, weights=Dict(:col1 => 1.0, :col2 => 2.0)) +``` + +## References + +* Kaufman, L. & Rousseeuw, P. J. 1990. [Partitioning Around Medoids (Program PAM)] + (https://onlinelibrary.wiley.com/doi/10.1002/9780470316801.ch2) + +* Kaufman, L. & Rousseeuw, P. J. 1991. [Finding Groups in Data: An Introduction to Cluster Analysis] + (https://www.jstor.org/stable/2532178) +""" +struct KMedoids{W,RNG} <: StatelessFeatureTransform + k::Int + tol::Float64 + maxiter::Int + weights::W + rng::RNG +end + +function KMedoids(k; tol=1e-4, maxiter=10, weights=nothing, rng=Random.default_rng()) + # sanity checks + @assert k > 0 "number of clusters must be positive" + @assert tol > 0 "tolerance on relative change must be positive" + @assert maxiter > 0 "maximum number of iterations must be positive" + KMedoids(k, tol, maxiter, weights, rng) +end + +parameters(transform::KMedoids) = (; k=transform.k) + +function applyfeat(transform::KMedoids, feat, prep) + # retrieve parameters + k = transform.k + tol = transform.tol + maxiter = transform.maxiter + weights = transform.weights + rng = transform.rng + + # number of observations + nobs = _nrow(feat) + + # sanity checks + k > nobs && throw(ArgumentError("requested number of clusters > number of observations")) + + # normalize variables + stdfeat = feat |> StdFeats() + + # define table distance + td = TableDistance(normalize=false, weights=weights) + + # initialize medoids + medoids = sample(rng, 1:nobs, k, replace=false) + + # pre-allocate memory for labels and distances + labels = fill(0, nobs) + dists = fill(Inf, nobs) + + # main loop + iter = 0 + δcur = mean(dists) + while iter < maxiter + # update labels and medoids + _updatelabels!(td, stdfeat, medoids, labels, dists) + _updatemedoids!(td, stdfeat, medoids, labels) + + # average dissimilarity + δnew = mean(dists) + + # break upon convergence + abs(δnew - δcur) / δcur < tol && break + + # update and continue + δcur = δnew + iter += 1 + end + + newfeat = (; cluster=labels) |> Tables.materializer(feat) + + newfeat, nothing +end + +function _updatelabels!(td, table, medoids, labels, dists) + for (k, mₖ) in enumerate(medoids) + inds = 1:_nrow(table) + + X = Tables.subset(table, inds) + μ = Tables.subset(table, [mₖ]) + + δ = pairwise(td, X, μ) + + @inbounds for i in inds + if δ[i] < dists[i] + dists[i] = δ[i] + labels[i] = k + end + end + end +end + +function _updatemedoids!(td, table, medoids, labels) + for k in eachindex(medoids) + inds = findall(isequal(k), labels) + + X = Tables.subset(table, inds) + + j = _medoid(td, X) + + @inbounds medoids[k] = inds[j] + end +end + +function _nrow(table) + cols = Tables.columns(table) + vars = Tables.columnnames(cols) + vals = Tables.getcolumn(cols, first(vars)) + length(vals) +end + +function _medoid(td, table) + Δ = pairwise(td, table) + _, j = findmin(sum, eachcol(Δ)) + j +end diff --git a/test/transforms.jl b/test/transforms.jl index 6fd34f26..77461353 100644 --- a/test/transforms.jl +++ b/test/transforms.jl @@ -31,6 +31,7 @@ transformfiles = [ "functional.jl", "eigenanalysis.jl", "projectionpursuit.jl", + "kmedoids.jl", "closure.jl", "remainder.jl", "compose.jl", diff --git a/test/transforms/kmedoids.jl b/test/transforms/kmedoids.jl new file mode 100644 index 00000000..c241b303 --- /dev/null +++ b/test/transforms/kmedoids.jl @@ -0,0 +1,4 @@ +@testset "KMedoids" begin + @test !isrevertible(KMedoids(3)) + @test TT.parameters(KMedoids(3)) == (k=3,) +end From 6643e90b5fa2a0497d1e8f7275f7f3e8957628f4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=BAlio=20Hoffimann?= Date: Thu, 27 Mar 2025 17:31:40 -0300 Subject: [PATCH 2/8] Add basic test for KMedoids --- test/transforms/kmedoids.jl | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/test/transforms/kmedoids.jl b/test/transforms/kmedoids.jl index c241b303..e264b7cc 100644 --- a/test/transforms/kmedoids.jl +++ b/test/transforms/kmedoids.jl @@ -1,4 +1,16 @@ @testset "KMedoids" begin @test !isrevertible(KMedoids(3)) @test TT.parameters(KMedoids(3)) == (k=3,) + + a = [randn(100); 10 .+ randn(100)] + b = [randn(100); 10 .+ randn(100)] + t = Table(; a, b) + + c = t |> KMedoids(2; rng) + i1 = findall(isequal(1), c.cluster) + i2 = findall(isequal(2), c.cluster) + @test mean(t.a[i1]) > 5 + @test mean(t.b[i1]) > 5 + @test mean(t.a[i2]) < 5 + @test mean(t.b[i2]) < 5 end From 9d5f3ff618390656af5dacff38953660df69d826 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=BAlio=20Hoffimann?= Date: Thu, 27 Mar 2025 17:31:50 -0300 Subject: [PATCH 3/8] Add KMedoids to docs --- docs/src/transforms.md | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/docs/src/transforms.md b/docs/src/transforms.md index f1d6c19b..b0c9bfbc 100644 --- a/docs/src/transforms.md +++ b/docs/src/transforms.md @@ -242,6 +242,12 @@ SDS ProjectionPursuit ``` +## KMedoids + +```@docs +KMedoids +``` + ## Closure ```@docs From a8c32c3128f60aa15ba5c468e381f2875d30111a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=BAlio=20Hoffimann?= Date: Thu, 27 Mar 2025 17:49:15 -0300 Subject: [PATCH 4/8] Add more tests for KMedoids --- test/transforms/kmedoids.jl | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/test/transforms/kmedoids.jl b/test/transforms/kmedoids.jl index e264b7cc..5291fdf6 100644 --- a/test/transforms/kmedoids.jl +++ b/test/transforms/kmedoids.jl @@ -1,16 +1,25 @@ @testset "KMedoids" begin @test !isrevertible(KMedoids(3)) - @test TT.parameters(KMedoids(3)) == (k=3,) + @test TT.parameters(KMedoids(3)) == (; k=3) + + # basic test with continuous variables a = [randn(100); 10 .+ randn(100)] b = [randn(100); 10 .+ randn(100)] t = Table(; a, b) - - c = t |> KMedoids(2; rng) - i1 = findall(isequal(1), c.cluster) - i2 = findall(isequal(2), c.cluster) + n = t |> KMedoids(2; rng) + i1 = findall(isequal(1), n.cluster) + i2 = findall(isequal(2), n.cluster) @test mean(t.a[i1]) > 5 @test mean(t.b[i1]) > 5 @test mean(t.a[i2]) < 5 @test mean(t.b[i2]) < 5 + + # test with mixed variables + a = [1, 2, 3] + b = [1.0, 2.0, 3.0] + c = ["a", "b", "c"] + t = Table(; a, b, c) + n = t |> KMedoids(3; rng) + @test sort(n.cluster) == [1, 2, 3] end From 08f3cc357f7c931b7ca3c92cb7221bee962e8aa5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=BAlio=20Hoffimann?= Date: Thu, 27 Mar 2025 17:51:23 -0300 Subject: [PATCH 5/8] Use existing _nrows utility --- src/transforms/kmedoids.jl | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/src/transforms/kmedoids.jl b/src/transforms/kmedoids.jl index 6e9dfb16..5411d28b 100644 --- a/src/transforms/kmedoids.jl +++ b/src/transforms/kmedoids.jl @@ -59,7 +59,7 @@ function applyfeat(transform::KMedoids, feat, prep) rng = transform.rng # number of observations - nobs = _nrow(feat) + nobs = _nrows(feat) # sanity checks k > nobs && throw(ArgumentError("requested number of clusters > number of observations")) @@ -103,7 +103,7 @@ end function _updatelabels!(td, table, medoids, labels, dists) for (k, mₖ) in enumerate(medoids) - inds = 1:_nrow(table) + inds = 1:_nrows(table) X = Tables.subset(table, inds) μ = Tables.subset(table, [mₖ]) @@ -131,13 +131,6 @@ function _updatemedoids!(td, table, medoids, labels) end end -function _nrow(table) - cols = Tables.columns(table) - vars = Tables.columnnames(cols) - vals = Tables.getcolumn(cols, first(vars)) - length(vals) -end - function _medoid(td, table) Δ = pairwise(td, table) _, j = findmin(sum, eachcol(Δ)) From b9f56fbd9874243fd44acc67779973be0f280105 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=BAlio=20Hoffimann?= Date: Thu, 27 Mar 2025 18:43:29 -0300 Subject: [PATCH 6/8] Use _assert utility function --- src/transforms/kmedoids.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/transforms/kmedoids.jl b/src/transforms/kmedoids.jl index 5411d28b..382667b9 100644 --- a/src/transforms/kmedoids.jl +++ b/src/transforms/kmedoids.jl @@ -42,9 +42,9 @@ end function KMedoids(k; tol=1e-4, maxiter=10, weights=nothing, rng=Random.default_rng()) # sanity checks - @assert k > 0 "number of clusters must be positive" - @assert tol > 0 "tolerance on relative change must be positive" - @assert maxiter > 0 "maximum number of iterations must be positive" + _assert(k > 0, "number of clusters must be positive") + _assert(tol > 0, "tolerance on relative change must be positive") + _assert(maxiter > 0, "maximum number of iterations must be positive") KMedoids(k, tol, maxiter, weights, rng) end From 260c1561f6fff9460603486265d5bb97091009eb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=BAlio=20Hoffimann?= Date: Thu, 27 Mar 2025 20:36:24 -0300 Subject: [PATCH 7/8] Retrieve distance type --- src/transforms/kmedoids.jl | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/transforms/kmedoids.jl b/src/transforms/kmedoids.jl index 382667b9..b3eecca4 100644 --- a/src/transforms/kmedoids.jl +++ b/src/transforms/kmedoids.jl @@ -73,9 +73,13 @@ function applyfeat(transform::KMedoids, feat, prep) # initialize medoids medoids = sample(rng, 1:nobs, k, replace=false) + # retrieve distance type + row = Tables.subset(stdfeat, 1:1) + D = eltype(pairwise(td, row)) + # pre-allocate memory for labels and distances labels = fill(0, nobs) - dists = fill(Inf, nobs) + dists = fill(typemax(D), nobs) # main loop iter = 0 From 8ad13238659414fee0ddf32c29d9d61796a4ef02 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=BAlio=20Hoffimann?= Date: Thu, 27 Mar 2025 20:39:55 -0300 Subject: [PATCH 8/8] Minor adjustments --- src/transforms/kmedoids.jl | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/transforms/kmedoids.jl b/src/transforms/kmedoids.jl index b3eecca4..6f1ca733 100644 --- a/src/transforms/kmedoids.jl +++ b/src/transforms/kmedoids.jl @@ -7,10 +7,10 @@ Assign labels to rows of table using the `k`-medoids algorithm. -The iterative algorithm is interrupted if the relative change of -the average dissimilarity between successive iterations is smaller -than a tolerance `tol` or if the number of iterations exceeds -the maximum number of iterations `maxiter`. +The iterative algorithm is interrupted if the relative change on +the average distance to medoids is smaller than a tolerance `tol` +or if the number of iterations exceeds the maximum number of +iterations `maxiter`. Optionally, specify a dictionary of `weights` for each column to affect the underlying table distance from TableDistances.jl, and @@ -74,8 +74,8 @@ function applyfeat(transform::KMedoids, feat, prep) medoids = sample(rng, 1:nobs, k, replace=false) # retrieve distance type - row = Tables.subset(stdfeat, 1:1) - D = eltype(pairwise(td, row)) + s = Tables.subset(stdfeat, 1:1) + D = eltype(pairwise(td, s)) # pre-allocate memory for labels and distances labels = fill(0, nobs) @@ -89,7 +89,7 @@ function applyfeat(transform::KMedoids, feat, prep) _updatelabels!(td, stdfeat, medoids, labels, dists) _updatemedoids!(td, stdfeat, medoids, labels) - # average dissimilarity + # average distance to medoids δnew = mean(dists) # break upon convergence