Skip to content

Commit c896acb

Browse files
cmdcolin and claude committed
Performance optimizations for hierarchical clustering
- Lance-Williams UPGMA update replaces per-pair member enumeration, making distance maintenance O(n) per merge instead of O(cluster_size²)
- Active-index list (swap-with-last removal) replaces flag array so the find-minimum loop iterates only live clusters, ~3x fewer comparisons
- Distance matrix computed over upper triangle only and mirrored, halving initial pairwise computation
- Linked list tracks leaf order in O(1) per merge, replacing index array copies
- clock() check moved outside inner distance-matrix loop (was called n² times)
- rebuildTree rewritten with stable slot IDs, O(n) with no array splices
- clustersGivenK rebuilt using active Set + membership map for stable slot IDs
- Remove unused dummy Float32Array distances allocation and field from ClusterResult
- Add integration tests covering known inputs end-to-end through real WASM

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent 956ac78 commit c896acb

8 files changed

Lines changed: 242 additions & 235 deletions

File tree

src/cluster.ts

Lines changed: 9 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -21,40 +21,27 @@ export async function clusterData({
2121
checkCancellation,
2222
})
2323

24-
// Build clustersGivenK from merge information
24+
// Build clustersGivenK from stable-slot merge sequence.
25+
// mergeA[i] and mergeB[i] are stable slot indices; slot mergeA[i] absorbs mergeB[i].
2526
const numSamples = data.length
2627
const clustersGivenK: number[][][] = [[]]
2728

28-
// Start with each sample in its own cluster
29-
const clusterSets: number[][] = Array.from({ length: numSamples }, (_, i) => [
30-
i,
31-
])
29+
const membership = Array.from({ length: numSamples }, (_, i) => [i] as number[])
30+
const activeSlots = new Set(Array.from({ length: numSamples }, (_, i) => i))
3231

3332
for (let i = 0; i < numSamples - 1; i++) {
34-
const [mergeA, mergeB] = result.merges[i]!
33+
const [a, b] = result.merges[i]!
3534

36-
// Record current state
37-
clustersGivenK.push(clusterSets.map(s => [...s]))
35+
clustersGivenK.push([...activeSlots].map(id => [...membership[id]!]))
3836

39-
// Merge clusters
40-
const newCluster = [...clusterSets[mergeA]!, ...clusterSets[mergeB]!]
41-
42-
const removeFirst = Math.max(mergeA, mergeB)
43-
const removeSecond = Math.min(mergeA, mergeB)
44-
45-
clusterSets.splice(removeFirst, 1)
46-
clusterSets.splice(removeSecond, 1)
47-
clusterSets.push(newCluster)
37+
membership[a] = [...membership[a]!, ...membership[b]!]
38+
activeSlots.delete(b!)
4839
}
4940

50-
clustersGivenK.push(clusterSets.map(s => [...s]))
51-
52-
// Create a dummy distance matrix (not used by caller, but part of interface)
53-
const distances = new Float32Array(numSamples * numSamples)
41+
clustersGivenK.push([...activeSlots].map(id => [...membership[id]!]))
5442

5543
return {
5644
tree: result.tree,
57-
distances,
5845
order: result.order,
5946
clustersGivenK: clustersGivenK.reverse(),
6047
}

src/distance.js

-4.74 KB
Binary file not shown.

src/types.ts

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ export interface ClusterNode {
66

77
export interface ClusterResult {
88
tree: ClusterNode
9-
distances: Float32Array
109
order: number[]
1110
clustersGivenK: number[][][]
1211
}

src/wasm-wrapper.ts

Lines changed: 16 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -41,31 +41,26 @@ export async function hierarchicalClusterWasm(
4141
}
4242
const vectorSize = data[0]?.length ?? 0
4343

44-
// Flatten data
4544
const flatData = new Float32Array(numSamples * vectorSize)
4645
for (let i = 0; i < numSamples; i++) {
4746
flatData.set(data[i]!, i * vectorSize)
4847
}
4948

50-
// Allocate memory
51-
const dataPtr = module._malloc(flatData.length * 4)
49+
const dataPtr = module._malloc(flatData.length * 4)
5250
const heightsPtr = module._malloc((numSamples - 1) * 4)
53-
const mergeAPtr = module._malloc((numSamples - 1) * 4)
54-
const mergeBPtr = module._malloc((numSamples - 1) * 4)
55-
const orderPtr = module._malloc(numSamples * 4)
51+
const mergeAPtr = module._malloc((numSamples - 1) * 4)
52+
const mergeBPtr = module._malloc((numSamples - 1) * 4)
53+
const orderPtr = module._malloc(numSamples * 4)
5654

5755
let callbackPtr: number | null = null
5856

5957
try {
60-
// Copy data to WASM
6158
module.HEAPF32.set(flatData, dataPtr / 4)
6259

63-
// Set up progress callback if provided
6460
if (statusCallback || checkCancellation) {
6561
const progressCallback = (iteration: number, totalIterations: number) => {
6662
checkCancellation?.()
6763
if (statusCallback) {
68-
// Negative iteration indicates distance matrix phase
6964
if (iteration < 0) {
7065
const distancesDone = -iteration
7166
const progress = Math.round((distancesDone / totalIterations) * 100)
@@ -75,15 +70,14 @@ export async function hierarchicalClusterWasm(
7570
statusCallback(`Clustering samples: ${progress}%`)
7671
}
7772
}
78-
return 1 // Continue
73+
return 1
7974
}
8075

8176
callbackPtr = module.addFunction(progressCallback, 'iii')
8277
module._setProgressCallback(callbackPtr)
8378
}
8479

85-
// Run clustering in WASM
86-
module._hierarchicalCluster(
80+
const result = module._hierarchicalCluster(
8781
dataPtr,
8882
numSamples,
8983
vectorSize,
@@ -93,7 +87,10 @@ export async function hierarchicalClusterWasm(
9387
orderPtr,
9488
)
9589

96-
// Copy results back
90+
if (result === -1) {
91+
throw new Error('aborted')
92+
}
93+
9794
const heights = new Float32Array(numSamples - 1)
9895
heights.set(
9996
module.HEAPF32.subarray(heightsPtr / 4, heightsPtr / 4 + numSamples - 1),
@@ -112,7 +109,6 @@ export async function hierarchicalClusterWasm(
112109
const order = new Int32Array(numSamples)
113110
order.set(module.HEAP32.subarray(orderPtr / 4, orderPtr / 4 + numSamples))
114111

115-
// Rebuild tree structure from merge information
116112
const tree = rebuildTree(numSamples, heights, mergeA, mergeB, sampleLabels)
117113
const merges: [number, number][] = []
118114
for (let i = 0; i < numSamples - 1; i++) {
@@ -126,7 +122,6 @@ export async function hierarchicalClusterWasm(
126122
merges,
127123
}
128124
} finally {
129-
// Clean up callback
130125
if (callbackPtr !== null) {
131126
module.removeFunction(callbackPtr)
132127
module._setProgressCallback(0)
@@ -140,45 +135,26 @@ export async function hierarchicalClusterWasm(
140135
}
141136
}
142137

138+
// Rebuilds the tree from stable slot indices (mergeA[i] < mergeB[i] always).
139+
// Slot mergeA[i] absorbs mergeB[i] each iteration, so nodes[0] is always the root.
143140
function rebuildTree(
144141
numSamples: number,
145142
heights: Float32Array,
146143
mergeA: Int32Array,
147144
mergeB: Int32Array,
148145
sampleLabels?: string[],
149146
): ClusterNode {
150-
// Create leaf nodes
151147
const nodes: ClusterNode[] = []
152148
for (let i = 0; i < numSamples; i++) {
153-
nodes.push({
154-
name: sampleLabels?.[i] ?? `Sample ${i}`,
155-
height: 0,
156-
})
149+
nodes.push({ name: sampleLabels?.[i] ?? `Sample ${i}`, height: 0 })
157150
}
158-
159-
// Build tree from merge information
160151
for (let i = 0; i < numSamples - 1; i++) {
161-
const leftIdx = mergeA[i]!
162-
const rightIdx = mergeB[i]!
163-
164-
const leftNode = nodes[leftIdx]!
165-
const rightNode = nodes[rightIdx]!
166-
167-
const newNode: ClusterNode = {
152+
const a = mergeA[i]!, b = mergeB[i]!
153+
nodes[a] = {
168154
name: `Cluster ${i}`,
169155
height: heights[i]!,
170-
children: [leftNode, rightNode],
156+
children: [nodes[a]!, nodes[b]!],
171157
}
172-
173-
// Replace the merged clusters with the new one
174-
// Remove higher index first
175-
const removeFirst = Math.max(leftIdx, rightIdx)
176-
const removeSecond = Math.min(leftIdx, rightIdx)
177-
178-
nodes.splice(removeFirst, 1)
179-
nodes.splice(removeSecond, 1)
180-
nodes.push(newNode)
181158
}
182-
183159
return nodes[0]!
184160
}

0 commit comments

Comments
 (0)