Skip to content

Commit 1d5a750

Browse files
authored
feat: Time per epoch and ETA logging when silent=false (#64)
* add time logger
* return vec on fit
* use Set
* use logistic regression
1 parent a3d719e commit 1d5a750

File tree

5 files changed

+58
-24
lines changed

5 files changed

+58
-24
lines changed

crates/core/src/cpu/backend.rs

+37-6
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
use std::collections::HashMap;
2+
use std::time::Instant;
23

34
use ndarray::{ArrayD, ArrayViewD, IxDyn};
45
use safetensors::{serialize, SafeTensors};
@@ -110,7 +111,10 @@ impl Backend {
110111
match layers {
111112
Some(layer_indices) => {
112113
for layer_index in layer_indices {
113-
let layer = self.layers.get_mut(layer_index).expect(&format!("Layer #{} does not exist.", layer_index));
114+
let layer = self
115+
.layers
116+
.get_mut(layer_index)
117+
.expect(&format!("Layer #{} does not exist.", layer_index));
114118
inputs = layer.forward_propagate(inputs, training);
115119
}
116120
}
@@ -141,6 +145,10 @@ impl Backend {
141145
let mut disappointments = 0;
142146
let mut best_net = self.save();
143147
let mut cost = 0f32;
148+
let mut time: u128;
149+
let mut total_time = 0u128;
150+
let start = Instant::now();
151+
let total_iter = epochs * datasets.len();
144152
while epoch < epochs {
145153
let mut total = 0.0;
146154
for (i, dataset) in datasets.iter().enumerate() {
@@ -152,7 +160,19 @@ impl Backend {
152160
let minibatch = outputs.dim()[0];
153161
if !self.silent && ((i + 1) * minibatch) % batches == 0 {
154162
cost = total / (batches) as f32;
155-
let msg = format!("Epoch={}, Dataset={}, Cost={}", epoch, i * minibatch, cost);
163+
time = start.elapsed().as_millis() - total_time;
164+
total_time += time;
165+
let current_iter = epoch * datasets.len() + i;
166+
let msg = format!(
167+
"Epoch={}, Dataset={}, Cost={}, Time={}s, ETA={}s",
168+
epoch,
169+
i * minibatch,
170+
cost,
171+
(time as f32) / 1000.0,
172+
(((total_time as f32) / current_iter as f32)
173+
* (total_iter - current_iter) as f32)
174+
/ 1000.0
175+
);
156176
(self.logger.log)(msg);
157177
total = 0.0;
158178
}
@@ -165,17 +185,28 @@ impl Backend {
165185
disappointments = 0;
166186
best_cost = cost;
167187
best_net = self.save();
168-
} else {
188+
} else {
169189
disappointments += 1;
170190
if !self.silent {
171-
println!("Patience counter: {} disappointing epochs out of {}.", disappointments, self.patience);
191+
println!(
192+
"Patience counter: {} disappointing epochs out of {}.",
193+
disappointments, self.patience
194+
);
172195
}
173196
}
174197
if disappointments >= self.patience {
175198
if !self.silent {
176-
println!("No improvement for {} epochs. Stopping early at cost={}", disappointments, best_cost);
199+
println!(
200+
"No improvement for {} epochs. Stopping early at cost={}",
201+
disappointments, best_cost
202+
);
177203
}
178-
let net = Self::load(&best_net, Logger { log: |x| println!("{}", x) });
204+
let net = Self::load(
205+
&best_net,
206+
Logger {
207+
log: |x| println!("{}", x),
208+
},
209+
);
179210
self.layers = net.layers;
180211
break;
181212
}

deno.lock

+4
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

examples/classification/spam.ts

+13-16
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,12 @@ import {
1414
// Import helpers for metrics
1515
import {
1616
ClassificationReport,
17-
CountVectorizer,
18-
SplitTokenizer,
19-
TfIdfTransformer,
17+
TextCleaner,
18+
TextVectorizer,
2019
// Split the dataset
2120
useSplit,
2221
} from "../../packages/utilities/mod.ts";
22+
import { SigmoidLayer } from "../../mod.ts";
2323

2424
// Define classes
2525
const ymap = ["spam", "ham"];
@@ -32,25 +32,21 @@ const data = parse(_data);
3232
const x = data.map((msg) => msg[1]);
3333

3434
// Get the classes
35-
const y = data.map((msg) => (ymap.indexOf(msg[0]) === 0 ? -1 : 1));
35+
const y = data.map((msg) => (ymap.indexOf(msg[0]) === 0 ? 0 : 1));
3636

3737
// Split the dataset for training and testing
3838
const [train, test] = useSplit({ ratio: [7, 3], shuffle: true }, x, y);
3939

4040
// Vectorize the text messages
4141

42-
const tokenizer = new SplitTokenizer({
43-
skipWords: "english",
44-
standardize: { lowercase: true },
45-
}).fit(train[0]);
42+
const textCleaner = new TextCleaner({ lowercase: true });
4643

47-
const vec = new CountVectorizer(tokenizer.vocabulary.size);
44+
train[0] = textCleaner.clean(train[0])
4845

49-
const x_vec = vec.transform(tokenizer.transform(train[0]), "f32")
46+
const vec = new TextVectorizer("tfidf").fit(train[0]);
5047

51-
const tfidf = new TfIdfTransformer();
48+
const x_vec = vec.transform(train[0], "f32");
5249

53-
const x_tfidf = tfidf.fit(x_vec).transform(x_vec)
5450

5551
// Setup the CPU backend for Netsaur
5652
await setupBackend(CPU);
@@ -73,14 +69,15 @@ const net = new Sequential({
7369
// A dense layer with 1 neuron
7470
DenseLayer({ size: [1] }),
7571
// A sigmoid activation layer
72+
SigmoidLayer()
7673
],
7774

7875
// We are using Log Loss for finding cost
79-
cost: Cost.Hinge,
76+
cost: Cost.BinCrossEntropy,
8077
optimizer: NadamOptimizer(),
8178
});
8279

83-
const inputs = tensor(x_tfidf);
80+
const inputs = tensor(x_vec);
8481

8582
const time = performance.now();
8683
// Train the network
@@ -99,10 +96,10 @@ net.train(
9996

10097
console.log(`training time: ${performance.now() - time}ms`);
10198

102-
const x_vec_test = tfidf.transform(vec.transform(tokenizer.transform(test[0]), "f32"));
99+
const x_vec_test = vec.transform(test[0], "f32");
103100

104101
// Calculate metrics
105102
const res = await net.predict(tensor(x_vec_test));
106-
const y1 = res.data.map((i) => (i < 0 ? -1 : 1));
103+
const y1 = res.data.map((i) => (i < 0.5 ? 0 : 1));
107104
const cMatrix = new ClassificationReport(test[1], y1);
108105
console.log("Confusion Matrix: ", cMatrix);

packages/utilities/src/text/vectorizer.ts

+2-1
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ export class TextVectorizer {
1313
this.mode = mode;
1414
this.mapper = new DiscreteMapper();
1515
}
16-
fit(document: string | string[]) {
16+
fit(document: string | string[]): TextVectorizer {
1717
this.mapper.fit(
1818
(Array.isArray(document) ? document.join(" ") : document).split(" ")
1919
);
@@ -27,6 +27,7 @@ export class TextVectorizer {
2727
this.transformer.fit(this.encoder.transform(tokens, "f32"));
2828
}
2929
}
30+
return this;
3031
}
3132
transform<DT extends DataType>(
3233
document: string | string[],

packages/utilities/src/utils/array/unique.ts

+2-1
Original file line numberDiff line numberDiff line change
@@ -6,5 +6,6 @@
66
*/
77
export function useUnique<T>(arr: ArrayLike<T>): T[] {
88
const array = Array.from(arr);
9-
return array.filter((x, i) => array.indexOf(x) === i);
9+
return [...new Set(array)]
10+
// return array.filter((x, i) => array.indexOf(x) === i);
1011
}

0 commit comments

Comments
 (0)