Skip to content

Commit 85a2274

Browse files
committed
start kmeans
1 parent 90412bd commit 85a2274

File tree

5 files changed

+278
-0
lines changed

5 files changed

+278
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,179 @@
1+
#! /usr/bin/env -S nix develop --command runghc -Wall
2+
3+
{-# LANGUAGE BangPatterns #-}
4+
{-# LANGUAGE BlockArguments #-}
5+
{-# LANGUAGE DataKinds #-}
6+
{-# LANGUAGE GADTs #-}
7+
{-# OPTIONS_GHC -Wno-unused-top-binds #-}
8+
{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
9+
10+
import Control.Monad
11+
import Control.Monad.ST
12+
import Data.Finite
13+
import Data.Foldable
14+
import Data.Foldable.WithIndex
15+
import Data.Map (Map)
16+
import qualified Data.Map as M
17+
import Data.Proxy
18+
import Data.Set (Set)
19+
import qualified Data.Set as S
20+
import Data.Type.Equality
21+
import Data.Type.Ord
22+
import qualified Data.Vector.Mutable.Sized as MV
23+
import Data.Vector.Sized (Vector)
24+
import qualified Data.Vector.Sized as SV
25+
import qualified Data.Vector as S
26+
import GHC.TypeNats
27+
import Linear.Metric
28+
import Linear.V2
29+
import Linear.Vector
30+
import qualified System.Random.MWC as MWC
31+
import qualified System.Random.MWC.Distributions as MWC
32+
import System.Random.Stateful (StatefulGen (..))
33+
34+
-- | Bundle of two sized vectors that share the same length index @m@: one
-- holding values of the parameter type @p@, the other holding one 'Double'
-- per slot.
--
-- NOTE(review): presumably one slot per mixture component, with 'emContribs'
-- being that component's weight -- confirm once the algorithm body exists.
data EM m p = EM
35+
{ emParams :: SV.Vector m p
36+
, emContribs :: SV.Vector m Double
37+
}
38+
39+
-- | Work-in-progress driver. Per the signature it takes a predicate over an
-- 'Int' and two 'EM' values, a function @a -> p -> Double@, a function
-- @SV.Vector n a -> p@, a sized vector of @a@ values and a starting 'EM',
-- and returns an 'EM'. The starting 'EM' is passed straight to the local
-- @go@ together with a counter of 0.
--
-- NOTE(review): the equation binds only four names (@stop xs getE getM@), so
-- positionally @xs@ is the @a -> p -> Double@ function, @getE@ is the
-- @SV.Vector n a -> p@ function and @getM@ is the sized vector. The names
-- look shifted relative to the signature -- confirm the intended order.
expectationMaximization
40+
:: forall m n p a. ()
41+
=> (Int -> EM m p -> EM m p -> Bool)
42+
-> (a -> p -> Double)
43+
-> (SV.Vector n a -> p)
44+
-> SV.Vector n a
45+
-> EM m p
46+
-> EM m p
47+
expectationMaximization stop xs getE getM = go 0
48+
where
49+
go :: Int -> EM m p -> EM m p
50+
-- NOTE(review): @p'@ and @c'@ are not bound anywhere in the visible code,
-- and @n@, @em0@ and @stop@ are not used yet; @go@ does not call itself, so
-- no iteration takes place as written.
go n em0 = EM p' c'
51+
52+
-- initialClusters :: (Additive p, Fractional a, KnownNat k) => [p a] -> Vector k (p a)
53+
-- initialClusters pts = runST do
54+
-- sums <- MV.replicate zero
55+
-- counts <- MV.replicate 0
56+
-- ifor_ pts \i p -> do
57+
-- let i' = modulo (fromIntegral i)
58+
-- MV.modify sums (^+^ p) i'
59+
-- MV.modify counts (+ 1) i'
60+
-- V.generateM \i ->
61+
-- (^/) <$> MV.read sums i <*> (fromInteger <$> MV.read counts i)
62+
63+
-- moveClusters ::
64+
-- forall k p a.
65+
-- (Metric p, Floating a, Ord a, KnownNat k, 1 <= k) =>
66+
-- [p a] ->
67+
-- Vector k (p a) ->
68+
-- Vector k (p a)
69+
-- moveClusters pts origCentroids = runST do
70+
-- sums <- MV.replicate zero
71+
-- counts <- MV.replicate 0
72+
-- for_ pts \p -> do
73+
-- let closestIx = V.minIndex @a @(k - 1) (distance p <$> origCentroids)
74+
-- MV.modify sums (^+^ p) closestIx
75+
-- MV.modify counts (+ 1) closestIx
76+
-- V.generateM \i -> do
77+
-- n <- MV.read counts i
78+
-- if n == 0
79+
-- then pure $ origCentroids `V.index` i
80+
-- else (^/ fromInteger n) <$> MV.read sums i
81+
82+
-- kMeans ::
83+
-- forall k p a.
84+
-- (Metric p, Floating a, Ord a, Eq (p a), KnownNat k, 1 <= k) =>
85+
-- [p a] ->
86+
-- Vector k (p a)
87+
-- kMeans pts = go 0 (initialClusters pts)
88+
-- where
89+
-- go :: Int -> Vector k (p a) -> Vector k (p a)
90+
-- go !i !cs
91+
-- | cs == cs' || i > 100 = cs
92+
-- | otherwise = go (i + 1) cs'
93+
-- where
94+
-- cs' = moveClusters pts cs
95+
96+
-- kMeans' ::
97+
-- forall p a.
98+
-- (Metric p, Floating a, Ord a, Eq (p a)) =>
99+
-- Natural ->
100+
-- [p a] ->
101+
-- [p a]
102+
-- kMeans' k pts = case someNatVal k of
103+
-- SomeNat @k pk -> case cmpNat (Proxy @1) pk of
104+
-- LTI -> toList $ kMeans @k pts -- 1 < k, so 1 <= k is valid
105+
-- EQI -> toList $ kMeans @k pts -- 1 == k, so 1 <= k is valid
106+
-- GTI -> [] -- in this branch, 1 > k, so we cannot call kMeans
107+
108+
-- groupAndSum ::
109+
-- (Metric p, Floating a, Ord a, KnownNat (k + 1)) =>
110+
-- [p a] ->
111+
-- Vector (k + 1) (p a) ->
112+
-- Vector (k + 1) (p a, Integer)
113+
-- groupAndSum pts cs0 = runST do
114+
-- sums <- MV.replicate zero
115+
-- counts <- MV.replicate 0
116+
-- for_ pts \p -> do
117+
-- let closestIx = V.minIndex (distance p <$> cs0)
118+
-- MV.modify sums (^+^ p) closestIx
119+
-- MV.modify counts (+ 1) closestIx
120+
-- V.generateM \i ->
121+
-- (,) <$> MV.read sums i <*> MV.read counts i
122+
123+
-- applyClusters ::
124+
-- forall k p a.
125+
-- (Metric p, Floating a, Ord a, Ord (p a), KnownNat k, 1 <= k) =>
126+
-- [p a] ->
127+
-- Vector k (p a) ->
128+
-- Vector k (Set (p a))
129+
-- applyClusters pts cs = V.generate \i -> M.findWithDefault S.empty i pointsClosestTo
130+
-- where
131+
-- pointsClosestTo :: Map (Finite k) (Set (p a))
132+
-- pointsClosestTo =
133+
-- M.fromListWith
134+
-- (<>)
135+
-- [ (closestIx, S.singleton p)
136+
-- | p <- pts
137+
-- , let closestIx = V.minIndex @a @(k - 1) (distance p <$> cs)
138+
-- ]
139+
140+
-- generateSamples ::
141+
-- forall p g m.
142+
-- (Applicative p, Traversable p, StatefulGen g m) =>
143+
-- -- | number of points per cluster
144+
-- Int ->
145+
-- -- | number of clusters
146+
-- Int ->
147+
-- g ->
148+
-- m ([p Double], [p Double])
149+
-- generateSamples numPts numClusters g = do
150+
-- (centers, ptss) <-
151+
-- unzip <$> replicateM numClusters do
152+
-- -- generate the centroid uniformly in the box component-by-component
153+
-- center <- sequenceA $ pure @p $ MWC.uniformRM (0, boxSize) g
154+
-- -- generate numPts points...
155+
-- pts <-
156+
-- replicateM numPts $
157+
-- -- ...component-by-component, as a normal distribution around the center
158+
-- traverse (\c -> MWC.normal c 0.1 g) center
159+
-- pure (center, pts)
160+
-- pure (centers, concat ptss)
161+
-- where
162+
-- -- get the dimension by getting the length of a unit point
163+
-- dim = length (pure () :: p ())
164+
-- -- approximately scale the range of the numbers by the area that the
165+
-- -- clusters would take up
166+
-- boxSize = (fromIntegral numClusters ** recip (fromIntegral dim)) * 20
167+
168+
-- | Entry point. Currently only prints @hi@; the sample-generation and
-- k-means demo lines that follow are commented out.
main :: IO ()
169+
main = do
170+
putStrLn "hi"
171+
-- g <- MWC.createSystemRandom
172+
-- (centers, samps) <- generateSamples @V2 10 3 g
173+
-- putStrLn "* points"
174+
-- mapM_ print samps
175+
-- putStrLn "* actual centers"
176+
-- print centers
177+
-- putStrLn "* kmeans centers"
178+
-- print $ kMeans' 3 samps
179+

code-samples/expectation-maximization/flake.lock

+60
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
# Flake for the expectation-maximization code sample. For each default system
# it exposes a development shell whose only build input is a GHC (package set
# ghc981) bundled with the Haskell packages listed below.
{
2+
description = "expectation-maximization code sample";
3+
inputs = {
4+
nixpkgs.url = "github:NixOS/nixpkgs";
5+
flake-utils.url = "github:numtide/flake-utils";
6+
};
7+
outputs = { self, nixpkgs, flake-utils }:
8+
flake-utils.lib.eachDefaultSystem (system:
9+
let
10+
pkgs = import nixpkgs { inherit system; };
11+
in
12+
{
13+
# NOTE(review): recent Nix releases prefer the `devShells.default` output
# name over `devShell` -- confirm against the Nix version in use before
# renaming.
devShell = pkgs.mkShell {
14+
buildInputs = with pkgs; [
15+
(haskell.packages.ghc981.ghcWithPackages (p: with p; [
16+
finite-typelits
17+
ghc-typelits-natnormalise
18+
linear
19+
mwc-random
20+
statistics
21+
vector-sized
22+
]))
23+
];
24+
};
25+
}
26+
);
27+
}
28+
+10
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
---
2+
title: "Haskell Nuggets: Expectation-Maximization"
3+
categories: Haskell
4+
tags: haskell, machine learning, dependent types, functional programming
5+
create-time: 2025/03/21 22:34:37
6+
identifier: expectation-maximization
7+
slug: haskell-nuggets-expectation-maximization
8+
series: Haskell Nuggets
9+
---
10+

copy/entries/kmeans.md

+1
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ create-time: 2024/07/16 22:55:56
66
date: 2024/07/26 12:06:27
77
identifier: kmeans
88
slug: haskell-nuggets-kmeans
9+
series: Haskell Nuggets
910
---
1011

1112
AI is hot, so let's talk about some "classical machine learning" in Haskell

0 commit comments

Comments
 (0)