Skip to content

Commit 7e2bfc1

Browse files
Merge pull request #14 from Kanaries/codex/implement-meanshift-algorithm-with-tests
Add MeanShift clustering
2 parents 9c73bd5 + 9e17e86 commit 7e2bfc1

3 files changed

Lines changed: 118 additions & 1 deletion

File tree

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
import { MeanShift } from '../meanShift';
2+
3+
/**
 * Re-encodes cluster labels as consecutive integers in order of first
 * appearance, so two labelings that induce the same partition compare equal.
 *
 * Example: `[5, 5, 2, 9, 2]` becomes `[0, 0, 1, 2, 1]`.
 *
 * @param labels Raw cluster labels, one per sample.
 * @returns A new array of the same length with normalized labels.
 */
function reAssignLabel (labels: number[]): number[] {
    const encoder = new Map<number, number>();
    const ans: number[] = [];
    for (const label of labels) {
        let code = encoder.get(label);
        if (code === undefined) {
            // First time this label is seen: the next free code is the map size.
            code = encoder.size;
            encoder.set(label, code);
        }
        ans.push(code);
    }
    return ans;
}
15+
16+
// Construction with default arguments yields a usable instance with no
// centroids until fitPredict has been called.
test('init', () => {
    const ms = new MeanShift();
    expect(ms).toBeInstanceOf(MeanShift);
    expect(ms.getCentroids()).toEqual([]);
});
20+
21+
// Two tight, well-separated groups of three points each; with bandwidth 1
// every point only sees its own group, so exactly two clusters are expected.
test('meanshift simple clusters', () => {
    const X = [
        [0, 0],
        [0.1, 0],
        [-0.1, 0],
        [3, 3],
        [3.1, 3],
        [3, 3.1]
    ];
    const ms = new MeanShift(1);
    const labels = reAssignLabel(ms.fitPredict(X));
    const expected = [0, 0, 0, 1, 1, 1];
    expect(labels).toEqual(expected);
    expect(ms.getCentroids()).toHaveLength(2);
});

src/clusters/index.ts

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import { KMeans } from './kmeans';
22
import { DBScan } from './dbscan';
3+
import { MeanShift } from './meanShift';
34
import { HDBScan } from './hdbscan';
45

5-
export { KMeans, DBScan, HDBScan };
6+
export { KMeans, DBScan, HDBScan, MeanShift };

src/clusters/meanShift.ts

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
import { ClusterBase } from '../base/cluster';
2+
import { Distance } from '../metrics';
3+
4+
/**
 * Mean-shift clustering with a uniform window.
 *
 * Every sample starts as its own candidate center. Each candidate is moved
 * repeatedly to the unweighted mean of all samples lying within `bandwidth`
 * of it, until no coordinate moves by more than a small tolerance or
 * `max_iter` passes have run. Converged candidates closer than
 * `bandwidth / 2` to an already accepted center are merged into it, and each
 * sample is finally labelled with the index of its nearest accepted center.
 */
export class MeanShift extends ClusterBase {
    private bandwidth: number;
    private centers: number[][];
    private max_iter: number;
    private distance: Distance.IDistance;

    /**
     * @param bandwidth Radius of the neighborhood used when averaging.
     * @param max_iter Upper bound on the number of shifting passes.
     * @param distanceType Key passed to `Distance.useDistance` to pick the metric.
     *   NOTE(review): the default is spelled 'euclidiean'; presumably this matches
     *   the key declared by `Distance.IDistanceType` — confirm before "fixing" it.
     */
    public constructor(bandwidth: number = 1, max_iter: number = 300, distanceType: Distance.IDistanceType = 'euclidiean') {
        super();
        this.bandwidth = bandwidth;
        this.centers = [];
        this.max_iter = max_iter;
        this.distance = Distance.useDistance(distanceType);
    }

    /**
     * Computes the mean of all samples within `bandwidth` of `point`.
     *
     * @returns A new array holding the mean, or `point` itself (same reference)
     *   when no sample lies inside the window.
     */
    private shiftPoint(point: number[], samplesX: number[][]): number[] {
        const neighbors = samplesX.filter(p => this.distance(p, point) <= this.bandwidth);
        if (neighbors.length === 0) return point;
        const dim = point.length;
        const mean = new Array<number>(dim).fill(0);
        for (const neighbor of neighbors) {
            for (let j = 0; j < dim; j++) {
                mean[j] += neighbor[j];
            }
        }
        for (let j = 0; j < dim; j++) {
            mean[j] /= neighbors.length;
        }
        return mean;
    }

    /**
     * Fits the model on `samplesX` and returns one cluster label per sample.
     * The discovered centers are stored and exposed through `getCentroids`.
     *
     * @param samplesX Samples as rows of coordinates; the input is not mutated.
     * @returns Labels indexing into the array returned by `getCentroids`.
     */
    public fitPredict(samplesX: number[][]): number[] {
        // Per-coordinate movement below which a center is considered settled.
        const tolerance = 1e-3;
        // Work on copies so the caller's rows are never modified.
        const centers = samplesX.map(p => [...p]);
        for (let iter = 0; iter < this.max_iter; iter++) {
            let moved = false;
            for (let i = 0; i < centers.length; i++) {
                const newCenter = this.shiftPoint(centers[i], samplesX);
                for (let j = 0; j < newCenter.length; j++) {
                    if (Math.abs(newCenter[j] - centers[i][j]) > tolerance) moved = true;
                }
                centers[i] = newCenter;
            }
            if (!moved) break;
        }

        // Merge converged candidates: the first candidate seen in a region
        // becomes its representative; later ones within bandwidth / 2 are dropped.
        const uniqueCenters: number[][] = [];
        for (const center of centers) {
            const isDuplicate = uniqueCenters.some(
                accepted => this.distance(center, accepted) <= this.bandwidth / 2
            );
            if (!isDuplicate) uniqueCenters.push(center);
        }

        // Assign each original sample to its nearest accepted center. Ties keep
        // the lowest index because only a strictly smaller distance replaces it.
        const labels: number[] = [];
        for (let i = 0; i < samplesX.length; i++) {
            let nearest = 0;
            let nearestDis = Infinity;
            for (let j = 0; j < uniqueCenters.length; j++) {
                const dis = this.distance(samplesX[i], uniqueCenters[j]);
                if (dis < nearestDis) {
                    nearestDis = dis;
                    nearest = j;
                }
            }
            labels.push(nearest);
        }

        this.centers = uniqueCenters;
        return labels;
    }

    /**
     * Returns the centers found by the most recent `fitPredict` call
     * (an empty array before any fit). The internal array is returned
     * directly, not a copy.
     */
    public getCentroids(): number[][] {
        return this.centers;
    }
}

0 commit comments

Comments
 (0)