-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathfixed-embedding-generation.js
More file actions
132 lines (108 loc) · 4.82 KB
/
fixed-embedding-generation.js
File metadata and controls
132 lines (108 loc) · 4.82 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
const fs = require('fs');
const path = require('path');
const Papa = require('papaparse');
const { pipeline } = require('@xenova/transformers');
/**
 * Reads paper metadata from a CSV file, generates a normalized sentence
 * embedding for each paper's title + abstract using a local MiniLM model,
 * and writes per-batch checkpoint files plus a combined output file.
 *
 * Side effects:
 *   - reads  ./src/data/unique_papers.csv
 *   - writes ./embeddings_output/papers_batch_<N>.json (checkpoints)
 *   - writes ./papers_with_embeddings.json (full results, in the cwd)
 *
 * @returns {Promise<void>}
 */
async function generateEmbeddings() {
  // Euclidean magnitude of a numeric vector; used to verify normalization.
  const magnitudeOf = (vec) => Math.sqrt(vec.reduce((sum, val) => sum + val * val, 0));

  // Create output directory; recursive avoids the existsSync/mkdir race
  // and is a no-op when the directory already exists.
  const outputDir = './embeddings_output';
  fs.mkdirSync(outputDir, { recursive: true });

  // Log versions for debugging
  console.log(`Node.js version: ${process.version}`);
  console.log(`@xenova/transformers version: ${require('@xenova/transformers/package.json').version}`);

  // Read and parse the CSV of papers.
  console.log("Reading CSV file...");
  const csvText = fs.readFileSync('./src/data/unique_papers.csv', 'utf-8');
  const papers = Papa.parse(csvText, {
    header: true,
    dynamicTyping: true,
    skipEmptyLines: true
  }).data;
  console.log(`Found ${papers.length} papers`);

  // Load model
  console.log("Loading transformer model...");
  const extractor = await pipeline('feature-extraction', 'Xenova/all-MiniLM-L6-v2');

  // Sanity-check the model on a sample text: with normalize:true the
  // resulting embedding should be unit-length.
  const testText = "This is a test to verify the embedding model works correctly";
  const testResult = await extractor(testText, { pooling: 'mean', normalize: true });
  const testMagnitude = magnitudeOf(Array.from(testResult.data));
  console.log(`Test embedding magnitude: ${testMagnitude}`);
  if (Math.abs(testMagnitude - 1.0) > 0.01) {
    console.warn("WARNING: Test embedding is not properly normalized! Expected magnitude close to 1.0");
  } else {
    console.log("Test embedding is correctly normalized");
  }

  // Process papers in small batches to avoid memory issues.
  const batchSize = 20;
  const processedPapers = [];
  for (let i = 0; i < papers.length; i += batchSize) {
    const batch = papers.slice(i, i + batchSize);
    const batchNumber = i / batchSize + 1;
    console.log(`Processing batch ${batchNumber}/${Math.ceil(papers.length / batchSize)}`);

    // Per-batch lookup so the checkpoint save below is O(batch size) instead
    // of scanning all processed papers for every row (was O(n^2) overall).
    const batchEmbeddings = new Map();

    for (const paper of batch) {
      // Combine title and abstract as the model input text.
      const text = `${paper.title || ''} ${paper.abstract || ''}`.trim();
      if (!text) {
        console.log(`Warning: Paper ${paper.id} has no title or abstract`);
        continue;
      }
      try {
        // Generate embedding with explicit normalization.
        const result = await extractor(text, {
          pooling: 'mean',
          normalize: true
        });
        let embedding = Array.from(result.data);
        const magnitude = magnitudeOf(embedding);
        // Guard against a zero vector before dividing (the old code would
        // produce NaNs); otherwise renormalize anything the model failed
        // to normalize.
        if (magnitude > 0 && Math.abs(magnitude - 1.0) > 0.01) {
          console.warn(`Warning: Embedding for paper ${paper.id} not normalized (mag=${magnitude.toFixed(4)}), normalizing manually`);
          embedding = embedding.map(val => val / magnitude);
          console.log(`After manual normalization: ${magnitudeOf(embedding).toFixed(4)}`);
        }
        batchEmbeddings.set(paper.id, embedding);
        processedPapers.push({ ...paper, embedding });
      } catch (error) {
        // Best-effort: log and keep going so one bad row doesn't abort the run.
        console.error(`Error processing paper ${paper.id}:`, error);
      }
    }

    // Save intermediate results to avoid losing progress.
    const batchFile = path.join(outputDir, `papers_batch_${batchNumber}.json`);
    fs.writeFileSync(batchFile, JSON.stringify(batch.map(p => ({
      id: p.id,
      title: p.title,
      embedding: batchEmbeddings.get(p.id) ?? null
    }))));
    console.log(`Saved batch to ${batchFile}`);
  }

  // Save complete results (path kept in the cwd for backward compatibility
  // with existing consumers of papers_with_embeddings.json).
  console.log("Saving all embeddings...");
  fs.writeFileSync('papers_with_embeddings.json', JSON.stringify(processedPapers));
  console.log("Processing complete! Processed " + processedPapers.length + " papers");

  // Final verification: spot-check a few random embeddings for unit length.
  const sampleSize = Math.min(5, processedPapers.length);
  console.log(`\nVerifying ${sampleSize} random embeddings:`);
  for (let i = 0; i < sampleSize; i++) {
    const randomIndex = Math.floor(Math.random() * processedPapers.length);
    const paper = processedPapers[randomIndex];
    const magnitude = magnitudeOf(paper.embedding);
    // Title may be missing or non-string (dynamicTyping); coerce before slicing
    // so the report line cannot throw after a full successful run.
    console.log(`Paper "${String(paper.title ?? '').substring(0, 30)}..." - Magnitude: ${magnitude.toFixed(4)}`);
  }
}
// Execute the main function. On failure, log the error AND set a non-zero
// exit code so shells/CI can detect that the run did not complete
// (previously the script exited 0 even after an unhandled error).
generateEmbeddings().catch(error => {
  console.error("Error occurred:", error);
  process.exitCode = 1;
});