-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathcli.js
More file actions
executable file
·111 lines (94 loc) · 2.98 KB
/
cli.js
File metadata and controls
executable file
·111 lines (94 loc) · 2.98 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
#!/usr/bin/env node
// Name under which this process is reported (e.g. in process listings).
process.title = 'mailsac_spam_trainer'
const fs = require('fs')
const sqlite3 = require('sqlite3')
const Classifier = require('./lib/classifier')
const program = require('commander')
const pkg = require('./package.json')
// Single classifier instance shared by every command defined in this file.
const classifier = new Classifier()
// While training, the model is saved each time this many rows have been trained.
const saveEvery = 100
// Text of the file most recently read by the `predict` command.
let fileContents = ''
// sqlite database handle; assigned when the `train` command runs.
let db
/**
 * Load the shared classifier's state from `model`, logging a line before
 * and after the load.
 *
 * @param {string} model - model location, handed unchanged to `classifier.load`
 */
function loadClassifier (model) {
  console.log('loading classifier', model)

  classifier.load(model)

  console.log('classifier loaded')
}
/**
 * Save the shared classifier's state to `model`, logging a line before
 * and after the save.
 *
 * @param {string} model - model location, handed unchanged to `classifier.save`
 */
function saveClassifier (model) {
  console.log('saving classifier')

  classifier.save(model)

  console.log('saved classifier')
}
program.version(pkg.version)

// `train <model>`: feed every row of a sqlite table to the classifier.
//
// Each row contributes an optional `Subject: ...` header line followed by its
// `text`; `row.spam` decides the label (0 = ham, 1 = spam). The model is saved
// every `saveEvery` trained rows and once more after the database is closed.
program.command('train <model>')
  .description('Train the classifier model')
  .option('--db [filepath]', 'Path to the sqlite database containing the training emails.')
  .option('--table [dbtable]',
    'The name of the table in the sqlite database. Must have the following fields: subject, text, spam (0 = ham, 1 = spam)')
  .action((model, options) => {
    // Both options are declared with an optional value, so they may arrive as
    // `undefined` (flag omitted) or `true` (flag given without a value).
    // Check before touching the model file so a bad invocation leaves nothing behind.
    if (typeof options.db !== 'string' || typeof options.table !== 'string') {
      console.error('train requires both --db <filepath> and --table <dbtable>')
      process.exit(1)
    }

    if (model && !fs.existsSync(model)) {
      console.log('creating model from scratch since it does not exist')
      // Write the untrained classifier first so the load below has a file to read.
      classifier.save(model)
    }
    loadClassifier(model)

    console.log('loading database from', options.db)
    db = new sqlite3.Database(options.db)
    db.serialize(() => {
      console.log('training classifier on table', options.table)
      let counter = 0
      // NOTE: the table name is interpolated straight into the SQL text, so
      // --table must come from a trusted operator.
      db.each(`SELECT * FROM ${options.table}`, (err, row) => {
        if (err) {
          console.error(err)
          return
        }

        // Rebuild a minimal message: optional subject header, then the body text.
        let message = ''
        if (row.subject) {
          message += `Subject: ${row.subject}\n`
        }
        if (row.text) {
          message += row.text
        }

        if (row.spam === 0) {
          console.log(counter, 'Training ham:', row.subject)
          classifier.trainHam(message)
        } else if (row.spam === 1) {
          console.log(counter, 'Training spam:', row.subject)
          classifier.trainSpam(message)
        } else {
          // Unlabelled or malformed rows are reported and not counted.
          console.error('bad hamOrSpamInt', counter, row)
          return
        }

        counter++
        // Periodic checkpoint; counter is at least 1 here.
        if (counter % saveEvery === 0) {
          saveClassifier(model)
        }
      })
    })

    // The final save happens in the close callback; a failed close exits
    // with status 1 without saving.
    db.close((err) => {
      if (err) {
        console.error('failed closing db', err)
        process.exit(1)
      }
      saveClassifier(model)
    })
  })
// `predict <model> <files...>`: load the model, then print one prediction
// line per file. Each line shows the raw score returned by the classifier
// and whether it is greater than `classifier.pValueSpamMinimum`.
program.command('predict <model> <files...>')
  .description('Make a prediction using the contents of one or more email text files')
  .action((model, files) => {
    console.log('Testing prediction', { model, files })
    loadClassifier(model)

    for (const filePath of files) {
      // The module-level `fileContents` is overwritten with each file's text.
      fileContents = fs.readFileSync(filePath, 'utf-8')
      const score = classifier.predict(fileContents)
      const aboveMinimum = score > classifier.pValueSpamMinimum
      console.log(`Prediction for ${filePath}: ${score} ${aboveMinimum}`)
    }
  })
// Parse the command line, which dispatches to the matching command's action.
program.parse(process.argv)

// Nothing left in program.args after parsing: print the usage text and exit.
if (program.args.length === 0) {
  program.help()
  process.exit()
}