diff --git a/Project.toml b/Project.toml index c6e7d40a..02a8dbf9 100644 --- a/Project.toml +++ b/Project.toml @@ -18,6 +18,7 @@ PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" +TableDistances = "e5d66e97-8c70-46bb-8b66-04a2d73ad782" Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" TransformsBase = "28dd2a49-a57a-4bfb-84ca-1a49db9b96b8" Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d" @@ -37,6 +38,7 @@ PrettyTables = "2" Random = "1.9" Statistics = "1.9" StatsBase = "0.33, 0.34" +TableDistances = "1.0" Tables = "1.6" TransformsBase = "1.5" Unitful = "1.17" diff --git a/docs/src/transforms.md b/docs/src/transforms.md index f1d6c19b..b0c9bfbc 100644 --- a/docs/src/transforms.md +++ b/docs/src/transforms.md @@ -242,6 +242,12 @@ SDS ProjectionPursuit ``` +## KMedoids + +```@docs +KMedoids +``` + ## Closure ```@docs diff --git a/src/TableTransforms.jl b/src/TableTransforms.jl index 9d9805df..e624c1d3 100644 --- a/src/TableTransforms.jl +++ b/src/TableTransforms.jl @@ -6,12 +6,13 @@ module TableTransforms using Tables using Unitful -using Statistics using PrettyTables using AbstractTrees -using LinearAlgebra +using TableDistances using DataScienceTraits using CategoricalArrays +using LinearAlgebra +using Statistics using Random using CoDa @@ -90,6 +91,7 @@ export DRS, SDS, ProjectionPursuit, + KMedoids, Closure, Remainder, Compose, diff --git a/src/transforms.jl b/src/transforms.jl index b7b92de6..e0c6082a 100644 --- a/src/transforms.jl +++ b/src/transforms.jl @@ -286,6 +286,7 @@ include("transforms/quantile.jl") include("transforms/functional.jl") include("transforms/eigenanalysis.jl") include("transforms/projectionpursuit.jl") +include("transforms/kmedoids.jl") include("transforms/closure.jl") include("transforms/remainder.jl") include("transforms/compose.jl") diff --git a/src/transforms/kmedoids.jl b/src/transforms/kmedoids.jl new file mode 100644 index 00000000..6f1ca733 --- /dev/null +++ b/src/transforms/kmedoids.jl @@ -0,0 +1,142 @@ +# ------------------------------------------------------------------ +# Licensed under the MIT License. See LICENSE in the project root. +# ------------------------------------------------------------------ + +""" + KMedoids(k; tol=1e-4, maxiter=10, weights=nothing, rng=Random.default_rng()) + +Assign labels to rows of table using the `k`-medoids algorithm. + +The iterative algorithm is interrupted if the relative change on +the average distance to medoids is smaller than a tolerance `tol` +or if the number of iterations exceeds the maximum number of +iterations `maxiter`. + +Optionally, specify a dictionary of `weights` for each column to +affect the underlying table distance from TableDistances.jl, and +a random number generator `rng` to obtain reproducible results. + +## Examples + +```julia +KMedoids(3) +KMedoids(4, maxiter=20) +KMedoids(5, weights=Dict(:col1 => 1.0, :col2 => 2.0)) +``` + +## References + +* Kaufman, L. & Rousseeuw, P. J. 1990. [Partitioning Around Medoids (Program PAM)] + (https://onlinelibrary.wiley.com/doi/10.1002/9780470316801.ch2) + +* Kaufman, L. & Rousseeuw, P. J. 1991. [Finding Groups in Data: An Introduction to Cluster Analysis] + (https://www.jstor.org/stable/2532178) +""" +struct KMedoids{W,RNG} <: StatelessFeatureTransform + k::Int + tol::Float64 + maxiter::Int + weights::W + rng::RNG +end + +function KMedoids(k; tol=1e-4, maxiter=10, weights=nothing, rng=Random.default_rng()) + # sanity checks + _assert(k > 0, "number of clusters must be positive") + _assert(tol > 0, "tolerance on relative change must be positive") + _assert(maxiter > 0, "maximum number of iterations must be positive") + KMedoids(k, tol, maxiter, weights, rng) +end + +parameters(transform::KMedoids) = (; k=transform.k) + +function applyfeat(transform::KMedoids, feat, prep) + # retrieve parameters + k = transform.k + tol = transform.tol + maxiter = transform.maxiter + weights = transform.weights + rng = transform.rng + + # number of observations + nobs = _nrows(feat) + + # sanity checks + k > nobs && throw(ArgumentError("requested number of clusters > number of observations")) + + # normalize variables + stdfeat = feat |> StdFeats() + + # define table distance + td = TableDistance(normalize=false, weights=weights) + + # initialize medoids + medoids = sample(rng, 1:nobs, k, replace=false) + + # retrieve distance type + s = Tables.subset(stdfeat, 1:1) + D = eltype(pairwise(td, s)) + + # pre-allocate memory for labels and distances + labels = fill(0, nobs) + dists = fill(typemax(D), nobs) + + # main loop + iter = 0 + δcur = mean(dists) + while iter < maxiter + # update labels and medoids + _updatelabels!(td, stdfeat, medoids, labels, dists) + _updatemedoids!(td, stdfeat, medoids, labels) + + # average distance to medoids + δnew = mean(dists) + + # break upon convergence + abs(δnew - δcur) / δcur < tol && break + + # update and continue + δcur = δnew + iter += 1 + end + + newfeat = (; cluster=labels) |> Tables.materializer(feat) + + newfeat, nothing +end + +function _updatelabels!(td, table, medoids, labels, dists) + for (k, mₖ) in enumerate(medoids) + inds = 1:_nrows(table) + + X = Tables.subset(table, inds) + μ = Tables.subset(table, [mₖ]) + + δ = pairwise(td, X, μ) + + @inbounds for i in inds + if δ[i] < dists[i] + dists[i] = δ[i] + labels[i] = k + end + end + end +end + +function _updatemedoids!(td, table, medoids, labels) + for k in eachindex(medoids) + inds = findall(isequal(k), labels) + + X = Tables.subset(table, inds) + + j = _medoid(td, X) + + @inbounds medoids[k] = inds[j] + end +end + +function _medoid(td, table) + Δ = pairwise(td, table) + _, j = findmin(sum, eachcol(Δ)) + j +end diff --git a/test/transforms.jl b/test/transforms.jl index 6fd34f26..77461353 100644 --- a/test/transforms.jl +++ b/test/transforms.jl @@ -31,6 +31,7 @@ transformfiles = [ "functional.jl", "eigenanalysis.jl", "projectionpursuit.jl", + "kmedoids.jl", "closure.jl", "remainder.jl", "compose.jl", diff --git a/test/transforms/kmedoids.jl b/test/transforms/kmedoids.jl new file mode 100644 index 00000000..5291fdf6 --- /dev/null +++ b/test/transforms/kmedoids.jl @@ -0,0 +1,25 @@ +@testset "KMedoids" begin + @test !isrevertible(KMedoids(3)) + + @test TT.parameters(KMedoids(3)) == (; k=3) + + # basic test with continuous variables + a = [randn(100); 10 .+ randn(100)] + b = [randn(100); 10 .+ randn(100)] + t = Table(; a, b) + n = t |> KMedoids(2; rng) + i1 = findall(isequal(1), n.cluster) + i2 = findall(isequal(2), n.cluster) + @test mean(t.a[i1]) > 5 + @test mean(t.b[i1]) > 5 + @test mean(t.a[i2]) < 5 + @test mean(t.b[i2]) < 5 + + # test with mixed variables + a = [1, 2, 3] + b = [1.0, 2.0, 3.0] + c = ["a", "b", "c"] + t = Table(; a, b, c) + n = t |> KMedoids(3; rng) + @test sort(n.cluster) == [1, 2, 3] +end