|
| 1 | +#! /usr/bin/env -S nix develop --command runghc -Wall |
| 2 | + |
| 3 | +{-# LANGUAGE BangPatterns #-} |
| 4 | +{-# LANGUAGE BlockArguments #-} |
| 5 | +{-# LANGUAGE DataKinds #-} |
| 6 | +{-# LANGUAGE GADTs #-} |
| 7 | +{-# OPTIONS_GHC -Wno-unused-top-binds #-} |
| 8 | +{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} |
| 9 | + |
| 10 | +import Control.Monad |
| 11 | +import Control.Monad.ST |
| 12 | +import Data.Finite |
| 13 | +import Data.Foldable |
| 14 | +import Data.Foldable.WithIndex |
| 15 | +import Data.Map (Map) |
| 16 | +import qualified Data.Map as M |
| 17 | +import Data.Proxy |
| 18 | +import Data.Set (Set) |
| 19 | +import qualified Data.Set as S |
| 20 | +import Data.Type.Equality |
| 21 | +import Data.Type.Ord |
| 22 | +import qualified Data.Vector.Mutable.Sized as MV |
| 23 | +import Data.Vector.Sized (Vector) |
| 24 | +import qualified Data.Vector.Sized as SV |
| 25 | +import qualified Data.Vector as S |
| 26 | +import GHC.TypeNats |
| 27 | +import Linear.Metric |
| 28 | +import Linear.V2 |
| 29 | +import Linear.Vector |
| 30 | +import qualified System.Random.MWC as MWC |
| 31 | +import qualified System.Random.MWC.Distributions as MWC |
| 32 | +import System.Random.Stateful (StatefulGen (..)) |
| 33 | + |
| 34 | +data EM m p = EM {- Record pairing, for each of @m@ slots, one value of type @p@ with one 'Double'. -} |
| 35 | + { emParams :: SV.Vector m p {- sized vector holding @m@ values of type @p@ -} |
| 36 | + , emContribs :: SV.Vector m Double {- sized vector holding @m@ 'Double's; presumably per-component weights -- TODO confirm -} |
| 37 | + } |
| 38 | + |
| 39 | +expectationMaximization {- Work-in-progress driver: hands the initial 'EM' to the local loop @go@, starting at counter 0. -} |
| 40 | + :: forall m n p a. () |
| 41 | + => (Int -> EM m p -> EM m p -> Bool) {- stop: predicate over a counter and two 'EM' values; presumably (iteration, previous, next) -- TODO confirm -} |
| 42 | + -> (a -> p -> Double) {- getE: scores one sample against one parameter; presumably the E-step -- TODO confirm -} |
| 43 | + -> (SV.Vector n a -> p) {- getM: produces a parameter from the sample vector; presumably the M-step -- TODO confirm -} |
| 44 | + -> SV.Vector n a {- xs: the @n@ input samples -} |
| 45 | + -> EM m p {- initial state, received by @go 0@ -} |
| 46 | + -> EM m p |
| 47 | +expectationMaximization stop getE getM xs = go 0 {- binders follow the signature's argument order: predicate, scoring function, fitting function, samples -} |
| 48 | + where |
| 49 | + go :: Int -> EM m p -> EM m p |
| 50 | + go n em0 = EM p' c' {- NOTE(review): p' and c' are not bound anywhere in view, and stop/getE/getM/xs/n/em0 are still unused; the update step remains to be written -} |
| 51 | + |
| 52 | +-- initialClusters :: (Additive p, Fractional a, KnownNat k) => [p a] -> Vector k (p a) |
| 53 | +-- initialClusters pts = runST do |
| 54 | +-- sums <- MV.replicate zero |
| 55 | +-- counts <- MV.replicate 0 |
| 56 | +-- ifor_ pts \i p -> do |
| 57 | +-- let i' = modulo (fromIntegral i) |
| 58 | +-- MV.modify sums (^+^ p) i' |
| 59 | +-- MV.modify counts (+ 1) i' |
| 60 | +-- V.generateM \i -> |
| 61 | +-- (^/) <$> MV.read sums i <*> (fromInteger <$> MV.read counts i) |
| 62 | + |
| 63 | +-- moveClusters :: |
| 64 | +-- forall k p a. |
| 65 | +-- (Metric p, Floating a, Ord a, KnownNat k, 1 <= k) => |
| 66 | +-- [p a] -> |
| 67 | +-- Vector k (p a) -> |
| 68 | +-- Vector k (p a) |
| 69 | +-- moveClusters pts origCentroids = runST do |
| 70 | +-- sums <- MV.replicate zero |
| 71 | +-- counts <- MV.replicate 0 |
| 72 | +-- for_ pts \p -> do |
| 73 | +-- let closestIx = V.minIndex @a @(k - 1) (distance p <$> origCentroids) |
| 74 | +-- MV.modify sums (^+^ p) closestIx |
| 75 | +-- MV.modify counts (+ 1) closestIx |
| 76 | +-- V.generateM \i -> do |
| 77 | +-- n <- MV.read counts i |
| 78 | +-- if n == 0 |
| 79 | +-- then pure $ origCentroids `V.index` i |
| 80 | +-- else (^/ fromInteger n) <$> MV.read sums i |
| 81 | + |
| 82 | +-- kMeans :: |
| 83 | +-- forall k p a. |
| 84 | +-- (Metric p, Floating a, Ord a, Eq (p a), KnownNat k, 1 <= k) => |
| 85 | +-- [p a] -> |
| 86 | +-- Vector k (p a) |
| 87 | +-- kMeans pts = go 0 (initialClusters pts) |
| 88 | +-- where |
| 89 | +-- go :: Int -> Vector k (p a) -> Vector k (p a) |
| 90 | +-- go !i !cs |
| 91 | +-- | cs == cs' || i > 100 = cs |
| 92 | +-- | otherwise = go (i + 1) cs' |
| 93 | +-- where |
| 94 | +-- cs' = moveClusters pts cs |
| 95 | + |
| 96 | +-- kMeans' :: |
| 97 | +-- forall p a. |
| 98 | +-- (Metric p, Floating a, Ord a, Eq (p a)) => |
| 99 | +-- Natural -> |
| 100 | +-- [p a] -> |
| 101 | +-- [p a] |
| 102 | +-- kMeans' k pts = case someNatVal k of |
| 103 | +-- SomeNat @k pk -> case cmpNat (Proxy @1) pk of |
| 104 | +-- LTI -> toList $ kMeans @k pts -- 1 < k, so 1 <= k is valid |
| 105 | +-- EQI -> toList $ kMeans @k pts -- 1 == k, so 1 <= k is valid |
| 106 | +-- GTI -> [] -- in this branch, 1 > k, so we cannot call kMeans |
| 107 | + |
| 108 | +-- groupAndSum :: |
| 109 | +-- (Metric p, Floating a, Ord a, KnownNat (k + 1)) => |
| 110 | +-- [p a] -> |
| 111 | +-- Vector (k + 1) (p a) -> |
| 112 | +-- Vector (k + 1) (p a, Integer) |
| 113 | +-- groupAndSum pts cs0 = runST do |
| 114 | +-- sums <- MV.replicate zero |
| 115 | +-- counts <- MV.replicate 0 |
| 116 | +-- for_ pts \p -> do |
| 117 | +-- let closestIx = V.minIndex (distance p <$> cs0) |
| 118 | +-- MV.modify sums (^+^ p) closestIx |
| 119 | +-- MV.modify counts (+ 1) closestIx |
| 120 | +-- V.generateM \i -> |
| 121 | +-- (,) <$> MV.read sums i <*> MV.read counts i |
| 122 | + |
| 123 | +-- applyClusters :: |
| 124 | +-- forall k p a. |
| 125 | +-- (Metric p, Floating a, Ord a, Ord (p a), KnownNat k, 1 <= k) => |
| 126 | +-- [p a] -> |
| 127 | +-- Vector k (p a) -> |
| 128 | +-- Vector k (Set (p a)) |
| 129 | +-- applyClusters pts cs = V.generate \i -> M.findWithDefault S.empty i pointsClosestTo |
| 130 | +-- where |
| 131 | +-- pointsClosestTo :: Map (Finite k) (Set (p a)) |
| 132 | +-- pointsClosestTo = |
| 133 | +-- M.fromListWith |
| 134 | +-- (<>) |
| 135 | +-- [ (closestIx, S.singleton p) |
| 136 | +-- | p <- pts |
| 137 | +-- , let closestIx = V.minIndex @a @(k - 1) (distance p <$> cs) |
| 138 | +-- ] |
| 139 | + |
| 140 | +-- generateSamples :: |
| 141 | +-- forall p g m. |
| 142 | +-- (Applicative p, Traversable p, StatefulGen g m) => |
| 143 | +-- -- | number of points per cluster |
| 144 | +-- Int -> |
| 145 | +-- -- | number of clusters |
| 146 | +-- Int -> |
| 147 | +-- g -> |
| 148 | +-- m ([p Double], [p Double]) |
| 149 | +-- generateSamples numPts numClusters g = do |
| 150 | +-- (centers, ptss) <- |
| 151 | +-- unzip <$> replicateM numClusters do |
| 152 | +-- -- generate the centroid uniformly in the box component-by-component |
| 153 | +-- center <- sequenceA $ pure @p $ MWC.uniformRM (0, boxSize) g |
| 154 | +-- -- generate numPts points... |
| 155 | +-- pts <- |
| 156 | +-- replicateM numPts $ |
| 157 | +-- -- .. component-by-component, as normal distribution around the center |
| 158 | +-- traverse (\c -> MWC.normal c 0.1 g) center |
| 159 | +-- pure (center, pts) |
| 160 | +-- pure (centers, concat ptss) |
| 161 | +-- where |
| 162 | +-- -- get the dimension by getting the length of a unit point |
| 163 | +-- dim = length (pure () :: p ()) |
| 164 | +-- -- approximately scale the range of the numbers by the area that the |
| 165 | +-- -- clusters would take up |
| 166 | +-- boxSize = (fromIntegral numClusters ** recip (fromIntegral dim)) * 20 |
| 167 | + |
| 168 | +main :: IO () {- Program entry point. -} |
| 169 | +main = do |
| 170 | + putStrLn "hi" {- only live statement; the remaining statements of this block are commented out -} |
| 171 | + -- g <- MWC.createSystemRandom |
| 172 | + -- (centers, samps) <- generateSamples @V2 10 3 g |
| 173 | + -- putStrLn "* points" |
| 174 | + -- mapM_ print samps |
| 175 | + -- putStrLn "* actual centers" |
| 176 | + -- print centers |
| 177 | + -- putStrLn "* kmeans centers" |
| 178 | + -- print $ kMeans' 3 samps |
| 179 | + |
0 commit comments