Skip to content

Commit a6e0da8

Browse files
committed
Improve embeddings comparison demo: better analogies, clustering, and visualizations
- Wrap single words in context sentences for sentence embedding models - Add explicit analogy candidate mappings for common patterns - Upgrade k-means to 50 iterations with k-means++ initialization - Add silhouette score for categorization quality metric - Zoom radar chart and scatter plot axes to highlight model differences - Add explanatory text to sidebar panels (leaderboard, performance, tradeoff)
1 parent b24edfd commit a6e0da8

File tree

4 files changed

+140
-58
lines changed

4 files changed

+140
-58
lines changed

demos/embeddings-comparison/css/embeddings-comparison.css

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -491,11 +491,20 @@
491491
}
492492

493493
.panel h3 {
494-
margin: 0 0 15px 0;
494+
margin: 0 0 8px 0;
495495
font-size: 1.1em;
496496
color: var(--primary-color);
497497
}
498498

499+
.panel-desc {
500+
font-size: 0.8em;
501+
color: var(--text-secondary);
502+
margin: 0 0 15px 0;
503+
line-height: 1.5;
504+
padding-bottom: 10px;
505+
border-bottom: 1px solid var(--border-color);
506+
}
507+
499508
/* Leaderboard */
500509
.leaderboard-item {
501510
display: flex;

demos/embeddings-comparison/index.html

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ <h3>Semantic Similarity</h3>
142142
<!-- Analogy Task -->
143143
<div id="analogy-task" class="task-panel">
144144
<h3>Word Analogies</h3>
145-
<p>Test reasoning: "A is to B as C is to ?"</p>
145+
<p>Test reasoning: "A is to B as C is to ?" Note: These sentence embedding models are optimized for sentences, not individual words. Classic analogies (like Word2Vec's king-man+woman=queen) may not work as reliably. Try different examples to see which relationships these models capture best.</p>
146146

147147
<div class="test-inputs">
148148
<div class="input-group">
@@ -189,7 +189,7 @@ <h3>Word Analogies</h3>
189189
<!-- Categorization Task -->
190190
<div id="categorization-task" class="task-panel">
191191
<h3>Topic Categorization</h3>
192-
<p>Test how well models group similar items using k-means clustering.</p>
192+
<p>Test how well models group similar items using k-means clustering. The algorithm uses k-means++ initialization for better results. Cluster quality is measured by silhouette score (higher = better separation).</p>
193193

194194
<div class="test-inputs">
195195
<div class="input-group">
@@ -358,18 +358,21 @@ <h4>Similarity Computation</h4>
358358
<!-- Leaderboard -->
359359
<div class="panel leaderboard-panel">
360360
<h3>Leaderboard</h3>
361+
<p class="panel-desc">Rankings based on the most recent test. Higher scores indicate better semantic understanding for the given task.</p>
361362
<div id="leaderboard"></div>
362363
</div>
363364

364365
<!-- Performance Radar Chart -->
365366
<div class="panel radar-panel">
366367
<h3>Performance Overview</h3>
368+
<p class="panel-desc">Compares models across three dimensions: Quality (similarity/accuracy), Speed (inference time), and Consistency. Axes zoom to highlight differences.</p>
367369
<div id="radar-chart"></div>
368370
</div>
369371

370372
<!-- Speed vs Quality -->
371373
<div class="panel tradeoff-panel">
372374
<h3>Speed vs Quality</h3>
375+
<p class="panel-desc">The fundamental tradeoff: larger models (upper-right) offer better quality but slower inference. Choose based on your latency requirements.</p>
373376
<div id="tradeoff-chart"></div>
374377
</div>
375378

demos/embeddings-comparison/js/benchmark-tasks.js

Lines changed: 84 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -41,11 +41,18 @@ export class BenchmarkTasks {
4141
candidates = this.generateAnalogyCandidates(wordA, wordB, wordC);
4242
}
4343

44+
// Sentence-based analogy approach: embed words in context sentences
45+
// This works better for sentence embedding models than raw word vectors
46+
const contextTemplate = (word) => `The word "${word}" represents a concept.`;
47+
4448
for (const modelId of modelIds) {
49+
const startTime = performance.now();
50+
51+
// Embed words in sentence context for better representations
4552
const [embA, embB, embC] = await Promise.all([
46-
this.modelsManager.embed(modelId, wordA),
47-
this.modelsManager.embed(modelId, wordB),
48-
this.modelsManager.embed(modelId, wordC)
53+
this.modelsManager.embed(modelId, contextTemplate(wordA)),
54+
this.modelsManager.embed(modelId, contextTemplate(wordB)),
55+
this.modelsManager.embed(modelId, contextTemplate(wordC))
4956
]);
5057

5158
// Calculate analogy vector: B - A + C
@@ -58,7 +65,7 @@ export class BenchmarkTasks {
5865
// Find best match among candidates
5966
const candidateResults = [];
6067
for (const candidate of candidates) {
61-
const embD = await this.modelsManager.embed(modelId, candidate);
68+
const embD = await this.modelsManager.embed(modelId, contextTemplate(candidate));
6269
const similarity = this.modelsManager.cosineSimilarity(
6370
analogyVec,
6471
embD.embedding
@@ -67,14 +74,15 @@ export class BenchmarkTasks {
6774
}
6875

6976
candidateResults.sort((a, b) => b.similarity - a.similarity);
77+
const endTime = performance.now();
7078

7179
results.push({
7280
modelId,
7381
modelName: this.modelsManager.getModelConfig(modelId).name,
7482
prediction: candidateResults[0].word,
7583
confidence: candidateResults[0].similarity,
7684
allCandidates: candidateResults.slice(0, 5),
77-
time: embA.time + embB.time + embC.time
85+
time: endTime - startTime
7886
});
7987
}
8088

@@ -95,40 +103,47 @@ export class BenchmarkTasks {
95103
const lower = (w) => w.toLowerCase();
96104
const a = lower(wordA), b = lower(wordB), c = lower(wordC);
97105

98-
const candidateSets = {
99-
royalty: ['woman', 'girl', 'female', 'lady', 'princess', 'duchess', 'empress'],
100-
capitals: ['England', 'UK', 'Britain', 'Germany', 'Spain', 'Italy', 'Japan', 'China', 'Canada', 'Australia'],
101-
grammar: ['worse', 'worst', 'badly', 'poorly', 'terrible', 'awful'],
102-
tense: ['ran', 'walked', 'jumped', 'swam', 'flew', 'drove', 'ate', 'slept'],
103-
countries: ['French', 'German', 'Spanish', 'Italian', 'Japanese', 'Chinese', 'British', 'American'],
104-
profession: ['actress', 'waitress', 'hostess', 'stewardess', 'heroine', 'woman'],
105-
size: ['tiny', 'small', 'little', 'huge', 'giant', 'massive', 'enormous'],
106-
emotion: ['sad', 'angry', 'scared', 'excited', 'nervous', 'calm', 'joyful']
106+
const analogyMap = {
107+
'king:queen:man': ['woman', 'lady', 'female', 'girl', 'wife', 'mother', 'queen', 'princess'],
108+
'actor:actress:waiter': ['waitress', 'hostess', 'woman', 'female', 'lady', 'stewardess', 'maid'],
109+
'hero:heroine:prince': ['princess', 'lady', 'queen', 'duchess', 'woman', 'girl', 'female'],
110+
'paris:france:london': ['England', 'Britain', 'UK', 'United Kingdom', 'British', 'London', 'Europe'],
111+
'tokyo:japan:berlin': ['Germany', 'German', 'Deutschland', 'Europe', 'Berlin', 'Austria'],
112+
'rome:italy:madrid': ['Spain', 'Spanish', 'Espana', 'Europe', 'Portugal', 'Madrid'],
113+
'good:better:bad': ['worse', 'worst', 'terrible', 'awful', 'poor', 'inferior', 'bad'],
114+
'big:bigger:small': ['smaller', 'tinier', 'little', 'tiny', 'minor', 'lesser', 'small'],
115+
'walk:walked:run': ['ran', 'running', 'runs', 'sprinted', 'jogged', 'run', 'raced'],
116+
'dog:puppy:cat': ['kitten', 'kitty', 'baby cat', 'young cat', 'cub', 'feline', 'cat'],
117+
'wood:tree:paper': ['pulp', 'plant', 'fiber', 'bamboo', 'reed', 'tree', 'forest'],
118+
'hammer:nail:screwdriver': ['screw', 'bolt', 'fastener', 'nut', 'nail', 'pin', 'rivet']
107119
};
108120

121+
const key = `${a}:${b}:${c}`;
122+
if (analogyMap[key]) {
123+
return analogyMap[key];
124+
}
125+
109126
if ((a === 'king' && b === 'queen') || (a === 'man' && b === 'woman') ||
110127
(a === 'boy' && b === 'girl') || (a === 'father' && b === 'mother')) {
111-
return candidateSets.royalty;
128+
return ['woman', 'lady', 'female', 'girl', 'wife', 'mother', 'queen', 'princess', 'daughter'];
112129
}
113130

114-
if (['paris', 'london', 'berlin', 'tokyo', 'rome', 'madrid'].includes(a) ||
115-
['france', 'england', 'germany', 'japan', 'italy', 'spain'].includes(a)) {
116-
return candidateSets.capitals;
131+
if (['paris', 'london', 'berlin', 'tokyo', 'rome', 'madrid'].includes(a)) {
132+
return ['England', 'Britain', 'UK', 'Germany', 'Spain', 'Italy', 'Japan', 'France', 'Europe'];
117133
}
118134

119-
if (['good', 'bad', 'big', 'small', 'fast', 'slow'].includes(a) &&
120-
['better', 'worse', 'bigger', 'smaller', 'faster', 'slower'].includes(b)) {
121-
return candidateSets.grammar;
135+
if (['good', 'bad', 'big', 'small', 'fast', 'slow'].includes(a)) {
136+
return ['worse', 'smaller', 'slower', 'faster', 'bigger', 'better', 'terrible', 'great'];
122137
}
123138

124139
if (['walk', 'run', 'swim', 'fly', 'drive', 'eat'].includes(a)) {
125-
return candidateSets.tense;
140+
return ['ran', 'walked', 'swam', 'flew', 'drove', 'ate', 'slept', 'jumped'];
126141
}
127142

128143
return [
129-
wordB, `${wordB}s`, `${wordC}er`, `${wordC}ing`,
130-
'woman', 'man', 'person', 'thing', 'place',
131-
'good', 'bad', 'big', 'small', 'new', 'old'
144+
'woman', 'man', 'person', 'thing', 'place', 'time',
145+
'good', 'bad', 'big', 'small', 'new', 'old',
146+
wordB, wordC
132147
];
133148
}
134149

@@ -234,21 +249,13 @@ export class BenchmarkTasks {
234249
return validPoints > 0 ? totalScore / validPoints : 0;
235250
}
236251

237-
kMeansClusteringWithCentroids(embeddings, k, maxIters = 10) {
252+
kMeansClusteringWithCentroids(embeddings, k, maxIters = 50) {
238253
const n = embeddings.length;
239254
const dim = embeddings[0].length;
240255

241-
const centroids = [];
242-
const indices = new Set();
243-
while (centroids.length < k) {
244-
const idx = Math.floor(Math.random() * n);
245-
if (!indices.has(idx)) {
246-
centroids.push([...embeddings[idx]]);
247-
indices.add(idx);
248-
}
249-
}
250-
256+
const centroids = this.kMeansPlusPlusInit(embeddings, k);
251257
let assignments = Array(n).fill(0);
258+
let prevAssignments = null;
252259

253260
for (let iter = 0; iter < maxIters; iter++) {
254261
for (let i = 0; i < n; i++) {
@@ -266,6 +273,11 @@ export class BenchmarkTasks {
266273
assignments[i] = bestCluster;
267274
}
268275

276+
if (prevAssignments && assignments.every((a, i) => a === prevAssignments[i])) {
277+
break;
278+
}
279+
prevAssignments = [...assignments];
280+
269281
const clusterSums = Array.from({ length: k }, () => Array(dim).fill(0));
270282
const clusterCounts = Array(k).fill(0);
271283

@@ -289,6 +301,42 @@ export class BenchmarkTasks {
289301
return { assignments, centroids };
290302
}
291303

304+
kMeansPlusPlusInit(embeddings, k) {
305+
const n = embeddings.length;
306+
const centroids = [];
307+
308+
const firstIdx = Math.floor(Math.random() * n);
309+
centroids.push([...embeddings[firstIdx]]);
310+
311+
while (centroids.length < k) {
312+
const distances = embeddings.map(emb => {
313+
let minDist = Infinity;
314+
for (const centroid of centroids) {
315+
const dist = this.euclideanDistance(emb, centroid);
316+
if (dist < minDist) minDist = dist;
317+
}
318+
return minDist * minDist;
319+
});
320+
321+
const totalDist = distances.reduce((a, b) => a + b, 0);
322+
let threshold = Math.random() * totalDist;
323+
324+
for (let i = 0; i < n; i++) {
325+
threshold -= distances[i];
326+
if (threshold <= 0) {
327+
centroids.push([...embeddings[i]]);
328+
break;
329+
}
330+
}
331+
332+
if (centroids.length === centroids.length - 1) {
333+
centroids.push([...embeddings[Math.floor(Math.random() * n)]]);
334+
}
335+
}
336+
337+
return centroids;
338+
}
339+
292340
euclideanDistance(vecA, vecB) {
293341
let sum = 0;
294342
for (let i = 0; i < vecA.length; i++) {

demos/embeddings-comparison/js/visualization.js

Lines changed: 41 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -122,32 +122,47 @@ export class Visualization {
122122
const theme = this.getPlotlyTheme();
123123
const colors = theme.colorway;
124124

125-
const traces = results.map((result, idx) => {
125+
const allValues = [];
126+
const traceData = results.map((result, idx) => {
126127
const quality = result.similarity || result.confidence || result.avgSimilarity || 0.5;
127-
const speed = result.time ? Math.max(0, 1 - (result.time / 1000)) : 0.5;
128+
const maxTime = Math.max(...results.map(r => r.time || 100));
129+
const speed = result.time ? Math.max(0, 1 - (result.time / maxTime)) : 0.5;
128130
const consistency = quality;
129-
130-
return {
131-
type: 'scatterpolar',
132-
r: [quality, speed, consistency, quality],
133-
theta: ['Quality', 'Speed', 'Consistency', 'Quality'],
134-
fill: 'toself',
135-
name: result.modelName,
136-
opacity: 0.6,
137-
line: { color: colors[idx % colors.length], width: 2 },
138-
fillcolor: colors[idx % colors.length].replace(')', ', 0.3)').replace('rgb', 'rgba'),
139-
marker: { color: colors[idx % colors.length] }
140-
};
131+
allValues.push(quality, speed, consistency);
132+
return { quality, speed, consistency, result, idx };
141133
});
142134

135+
const minVal = Math.min(...allValues);
136+
const maxVal = Math.max(...allValues);
137+
const padding = (maxVal - minVal) * 0.15 || 0.1;
138+
const rangeMin = Math.max(0, minVal - padding);
139+
const rangeMax = Math.min(1, maxVal + padding);
140+
141+
const traces = traceData.map(({ quality, speed, consistency, result, idx }) => ({
142+
type: 'scatterpolar',
143+
r: [quality, speed, consistency, quality],
144+
theta: ['Quality', 'Speed', 'Consistency', 'Quality'],
145+
fill: 'toself',
146+
name: result.modelName,
147+
opacity: 0.6,
148+
line: { color: colors[idx % colors.length], width: 2 },
149+
fillcolor: colors[idx % colors.length].replace(')', ', 0.3)').replace('rgb', 'rgba'),
150+
marker: { color: colors[idx % colors.length] }
151+
}));
152+
153+
const tickCount = 4;
154+
const tickStep = (rangeMax - rangeMin) / tickCount;
155+
const tickvals = Array.from({ length: tickCount + 1 }, (_, i) => rangeMin + i * tickStep);
156+
const ticktext = tickvals.map(v => v.toFixed(2));
157+
143158
const layout = {
144159
polar: {
145160
bgcolor: theme.polar.bgcolor,
146161
radialaxis: {
147162
visible: true,
148-
range: [0, 1],
149-
tickvals: [0.25, 0.5, 0.75, 1],
150-
ticktext: ['0.25', '0.50', '0.75', '1.00'],
163+
range: [rangeMin, rangeMax],
164+
tickvals: tickvals,
165+
ticktext: ticktext,
151166
...theme.polar.radialaxis
152167
},
153168
angularaxis: {
@@ -180,9 +195,16 @@ export class Visualization {
180195
const theme = this.getPlotlyTheme();
181196
const colors = theme.colorway;
182197

198+
const yValues = results.map(r => r.similarity || r.confidence || r.avgSimilarity || 0);
199+
const minY = Math.min(...yValues);
200+
const maxY = Math.max(...yValues);
201+
const paddingY = (maxY - minY) * 0.2 || 0.05;
202+
const yMin = Math.max(0, minY - paddingY);
203+
const yMax = Math.min(1.05, maxY + paddingY);
204+
183205
const trace = {
184206
x: results.map(r => r.time),
185-
y: results.map(r => r.similarity || r.confidence || r.avgSimilarity || 0),
207+
y: yValues,
186208
mode: 'markers+text',
187209
type: 'scatter',
188210
text: results.map(r => r.modelName.split('-')[0]),
@@ -208,7 +230,7 @@ export class Visualization {
208230
},
209231
yaxis: {
210232
title: { text: 'Quality Score', font: theme.yaxis.titlefont },
211-
range: [0, 1.1],
233+
range: [yMin, yMax],
212234
...theme.yaxis,
213235
automargin: true
214236
},

0 commit comments

Comments
 (0)