Skip to content

Commit a811fd1

Browse files
authored
feat: Add postprocessing (and fix WASM) (#65)
* add postprocessing for sign and step * fix wasm
1 parent 1d5a750 commit a811fd1

File tree

22 files changed

+273
-93
lines changed

22 files changed

+273
-93
lines changed

crates/core/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,4 +16,4 @@ safetensors = { workspace = true }
1616
[target.'cfg(target_arch = "wasm32")'.dependencies]
1717
wasm-bindgen = "0.2.92"
1818
getrandom = { version = "0.2", features = ["js"] }
19-
js-sys = "0.3.69"
19+
js-sys = "0.3.69"

crates/core/src/cpu/backend.rs

Lines changed: 31 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,14 @@
11
use std::collections::HashMap;
2-
use std::time::Instant;
32

43
use ndarray::{ArrayD, ArrayViewD, IxDyn};
54
use safetensors::{serialize, SafeTensors};
65

76
use crate::{
87
to_arr, ActivationCPULayer, BackendConfig, BatchNorm1DCPULayer, BatchNorm2DCPULayer,
9-
BatchNormTensors, CPUCost, CPULayer, CPUOptimizer, CPUScheduler, Conv2DCPULayer, ConvTensors,
10-
ConvTranspose2DCPULayer, Dataset, DenseCPULayer, DenseTensors, Dropout1DCPULayer,
11-
Dropout2DCPULayer, FlattenCPULayer, GetTensor, Layer, Logger, Pool2DCPULayer, SoftmaxCPULayer,
12-
Tensor, Tensors,
8+
BatchNormTensors, CPUCost, CPULayer, CPUOptimizer, CPUPostProcessor, CPUScheduler,
9+
Conv2DCPULayer, ConvTensors, ConvTranspose2DCPULayer, Dataset, DenseCPULayer, DenseTensors,
10+
Dropout1DCPULayer, Dropout2DCPULayer, FlattenCPULayer, GetTensor, Layer, Logger,
11+
Pool2DCPULayer, PostProcessor, SoftmaxCPULayer, Tensor, Tensors, Timer,
1312
};
1413

1514
pub struct Backend {
@@ -23,10 +22,16 @@ pub struct Backend {
2322
pub optimizer: CPUOptimizer,
2423
pub scheduler: CPUScheduler,
2524
pub logger: Logger,
25+
pub timer: Timer,
2626
}
2727

2828
impl Backend {
29-
pub fn new(config: BackendConfig, logger: Logger, mut tensors: Option<Vec<Tensors>>) -> Self {
29+
pub fn new(
30+
config: BackendConfig,
31+
logger: Logger,
32+
timer: Timer,
33+
mut tensors: Option<Vec<Tensors>>,
34+
) -> Self {
3035
let mut layers = Vec::new();
3136
let mut size = config.size.clone();
3237
for layer in config.layers.iter() {
@@ -99,6 +104,7 @@ impl Backend {
99104
optimizer,
100105
scheduler,
101106
size,
107+
timer,
102108
}
103109
}
104110

@@ -147,7 +153,7 @@ impl Backend {
147153
let mut cost = 0f32;
148154
let mut time: u128;
149155
let mut total_time = 0u128;
150-
let start = Instant::now();
156+
let start = (self.timer.now)();
151157
let total_iter = epochs * datasets.len();
152158
while epoch < epochs {
153159
let mut total = 0.0;
@@ -160,11 +166,11 @@ impl Backend {
160166
let minibatch = outputs.dim()[0];
161167
if !self.silent && ((i + 1) * minibatch) % batches == 0 {
162168
cost = total / (batches) as f32;
163-
time = start.elapsed().as_millis() - total_time;
169+
time = ((self.timer.now)() - start) - total_time;
164170
total_time += time;
165171
let current_iter = epoch * datasets.len() + i;
166172
let msg = format!(
167-
"Epoch={}, Dataset={}, Cost={}, Time={}s, ETA={}s",
173+
"Epoch={}, Dataset={}, Cost={}, Time={:.3}s, ETA={:.3}s",
168174
epoch,
169175
i * minibatch,
170176
cost,
@@ -188,25 +194,20 @@ impl Backend {
188194
} else {
189195
disappointments += 1;
190196
if !self.silent {
191-
println!(
197+
(self.logger.log)(format!(
192198
"Patience counter: {} disappointing epochs out of {}.",
193199
disappointments, self.patience
194-
);
200+
));
195201
}
196202
}
197203
if disappointments >= self.patience {
198204
if !self.silent {
199-
println!(
205+
(self.logger.log)(format!(
200206
"No improvement for {} epochs. Stopping early at cost={}",
201207
disappointments, best_cost
202-
);
208+
));
203209
}
204-
let net = Self::load(
205-
&best_net,
206-
Logger {
207-
log: |x| println!("{}", x),
208-
},
209-
);
210+
let net = Self::load(&best_net, self.logger.clone(), self.timer.clone());
210211
self.layers = net.layers;
211212
break;
212213
}
@@ -215,11 +216,18 @@ impl Backend {
215216
}
216217
}
217218

218-
pub fn predict(&mut self, data: ArrayD<f32>, layers: Option<Vec<usize>>) -> ArrayD<f32> {
219+
pub fn predict(
220+
&mut self,
221+
data: ArrayD<f32>,
222+
postprocess: PostProcessor,
223+
layers: Option<Vec<usize>>,
224+
) -> ArrayD<f32> {
225+
let processor = CPUPostProcessor::from(&postprocess);
219226
for layer in &mut self.layers {
220227
layer.reset(1);
221228
}
222-
self.forward_propagate(data, false, layers)
229+
let res = self.forward_propagate(data, false, layers);
230+
processor.process(res)
223231
}
224232

225233
pub fn save(&self) -> Vec<u8> {
@@ -272,7 +280,7 @@ impl Backend {
272280
serialize(tensors, &Some(metadata)).unwrap()
273281
}
274282

275-
pub fn load(buffer: &[u8], logger: Logger) -> Self {
283+
pub fn load(buffer: &[u8], logger: Logger, timer: Timer) -> Self {
276284
let tensors = SafeTensors::deserialize(buffer).unwrap();
277285
let (_, metadata) = SafeTensors::read_metadata(buffer).unwrap();
278286
let data = metadata.metadata().as_ref().unwrap();
@@ -304,6 +312,6 @@ impl Backend {
304312
};
305313
}
306314

307-
Backend::new(config, logger, Some(layers))
315+
Backend::new(config, logger, timer, Some(layers))
308316
}
309317
}

crates/core/src/cpu/mod.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ mod layers;
66
mod optimizers;
77
mod schedulers;
88
mod regularizer;
9+
mod postprocessing;
910

1011
pub use activation::*;
1112
pub use backend::*;
@@ -14,4 +15,5 @@ pub use init::*;
1415
pub use layers::*;
1516
pub use optimizers::*;
1617
pub use schedulers::*;
17-
pub use regularizer::*;
18+
pub use regularizer::*;
19+
pub use postprocessing::*;
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
use ndarray::ArrayD;
2+
use crate::PostProcessor;
3+
4+
mod step;
5+
use step::CPUStepFunction;
6+
7+
/// CPU-side post-processing strategy applied to network predictions.
pub enum CPUPostProcessor {
    /// Pass predictions through unchanged.
    None,
    /// Map each output element to its sign via `f32::signum`.
    Sign,
    /// Quantize each output element with a configurable step function.
    Step(CPUStepFunction),
}
12+
13+
impl CPUPostProcessor {
14+
pub fn from(processor: &PostProcessor) -> Self {
15+
match processor {
16+
PostProcessor::None => CPUPostProcessor::None,
17+
PostProcessor::Sign => CPUPostProcessor::Sign,
18+
PostProcessor::Step(config) => CPUPostProcessor::Step(CPUStepFunction::new(config)),
19+
}
20+
}
21+
pub fn process(&self, x: ArrayD<f32>) -> ArrayD<f32> {
22+
match self {
23+
CPUPostProcessor::None => x,
24+
CPUPostProcessor::Sign => x.map(|y| y.signum()),
25+
CPUPostProcessor::Step(processor) => x.map(|y| processor.step(*y)),
26+
}
27+
}
28+
}
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
use crate::StepFunctionConfig;
2+
3+
/// Piecewise-constant step function: maps an input to the value of the first
/// threshold bucket it falls below.
pub struct CPUStepFunction {
    // Bucket cut points; an input below thresholds[i] maps to values[i].
    // NOTE(review): `step` assumes these are sorted ascending — confirm the
    // config is validated upstream.
    thresholds: Vec<f32>,
    // Per-bucket outputs; `step` falls back to the last entry for inputs not
    // below any threshold, which implies values should hold one more entry
    // than thresholds — TODO confirm this invariant is enforced.
    values: Vec<f32>
}
7+
impl CPUStepFunction {
8+
pub fn new(config: &StepFunctionConfig) -> Self {
9+
return Self {
10+
thresholds: config.thresholds.clone(),
11+
values: config.values.clone()
12+
}
13+
}
14+
pub fn step(&self, x: f32) -> f32 {
15+
for (i, &threshold) in self.thresholds.iter().enumerate() {
16+
if x < threshold {
17+
return self.values[i];
18+
}
19+
}
20+
return self.values.last().unwrap().clone()
21+
}
22+
}

crates/core/src/ffi.rs

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
use std::slice::{from_raw_parts, from_raw_parts_mut};
2+
use std::time::{SystemTime, UNIX_EPOCH};
23

34
use crate::{
4-
decode_array, decode_json, length, Backend, Dataset, Logger, PredictOptions, TrainOptions,
5-
RESOURCES,
5+
decode_array, decode_json, length, Backend, Dataset, Logger, PredictOptions, Timer,
6+
TrainOptions, RESOURCES,
67
};
78

89
type AllocBufferFn = extern "C" fn(usize) -> *mut u8;
@@ -11,10 +12,17 @@ fn log(string: String) {
1112
println!("{}", string)
1213
}
1314

15+
/// Current wall-clock time in milliseconds since the Unix epoch, used as the
/// native `Timer` source.
fn now() -> u128 {
    let since_epoch = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .expect("Your system is behind the Unix Epoch");
    since_epoch.as_millis()
}
21+
1422
#[no_mangle]
1523
pub extern "C" fn ffi_backend_create(ptr: *const u8, len: usize, alloc: AllocBufferFn) -> usize {
1624
let config = decode_json(ptr, len);
17-
let net_backend = Backend::new(config, Logger { log }, None);
25+
let net_backend = Backend::new(config, Logger { log }, Timer { now }, None);
1826
let buf: Vec<u8> = net_backend
1927
.size
2028
.iter()
@@ -75,7 +83,7 @@ pub extern "C" fn ffi_backend_predict(
7583

7684
RESOURCES.with(|cell| {
7785
let mut backend = cell.backend.borrow_mut();
78-
let res = backend[id].predict(inputs, options.layers);
86+
let res = backend[id].predict(inputs, options.post_process, options.layers);
7987
outputs.copy_from_slice(res.as_slice().unwrap());
8088
});
8189
}
@@ -98,7 +106,7 @@ pub extern "C" fn ffi_backend_load(
98106
alloc: AllocBufferFn,
99107
) -> usize {
100108
let buffer = unsafe { from_raw_parts(file_ptr, file_len) };
101-
let net_backend = Backend::load(buffer, Logger { log });
109+
let net_backend = Backend::load(buffer, Logger { log }, Timer { now });
102110
let buf: Vec<u8> = net_backend.size.iter().map(|x| *x as u8).collect();
103111
let size_ptr = alloc(buf.len());
104112
let output_shape = unsafe { from_raw_parts_mut(size_ptr, buf.len()) };

crates/core/src/types.rs

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,21 @@ pub enum Scheduler {
195195
OneCycle(OneCycleScheduler),
196196
}
197197

198+
/// Serialized configuration for the step post-processor.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct StepFunctionConfig {
    // Bucket cut points; an input below thresholds[i] yields values[i].
    pub thresholds: Vec<f32>,
    // Per-bucket outputs; the last entry is used for inputs above all thresholds.
    pub values: Vec<f32>,
}
203+
204+
/// Post-processing applied to predictions, selected from JSON as an
/// externally-configured `{"type": "...", "config": ...}` object
/// (variant names lowercased by serde).
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(tag = "type", content = "config")]
#[serde(rename_all = "lowercase")]
pub enum PostProcessor {
    /// No post-processing.
    None,
    /// Element-wise sign of each output.
    Sign,
    /// Piecewise-constant step function over configured thresholds.
    Step(StepFunctionConfig),
}
212+
198213
#[derive(Serialize, Deserialize, Debug, Clone)]
199214
#[serde(rename_all = "camelCase")]
200215
pub struct TrainOptions {
@@ -212,6 +227,7 @@ pub struct PredictOptions {
212227
pub input_shape: Vec<usize>,
213228
pub output_shape: Vec<usize>,
214229
pub layers: Option<Vec<usize>>,
230+
pub post_process: PostProcessor,
215231
}
216232

217233
#[derive(Serialize, Deserialize, Debug, Clone)]

crates/core/src/util.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,16 @@ use ndarray::ArrayD;
44
use safetensors::tensor::TensorView;
55
use serde::Deserialize;
66

7+
/// Pluggable logging sink; `Clone` so a backend can keep its own copy
/// (e.g. when reloading the best network during early stopping).
#[derive(Clone)]
pub struct Logger {
    // Plain function pointer (not a closure), so cloning is trivial.
    pub log: fn(string: String) -> (),
}
1011

12+
/// Pluggable clock abstraction so timing works on both native targets
/// (`SystemTime`-based) and WASM (`Date.now`-based), where `std::time::Instant`
/// is unavailable.
#[derive(Clone)]
pub struct Timer {
    // Millisecond timestamp source; a plain fn pointer keeps the struct
    // trivially clonable.
    pub now: fn() -> u128,
}
16+
1117
/// Total number of elements implied by a tensor shape (product of all
/// dimensions). An empty shape yields 1, matching the fold's identity and the
/// usual scalar convention.
pub fn length(shape: Vec<usize>) -> usize {
    // `product()` is the idiomatic form of fold(1, |i, x| i * x).
    shape.iter().product()
}

crates/core/src/wasm.rs

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,39 @@
11
use js_sys::{Array, Float32Array, Uint8Array};
22
use ndarray::ArrayD;
3-
43
use wasm_bindgen::{prelude::wasm_bindgen, JsValue};
54

6-
use crate::{Backend, Dataset, Logger, PredictOptions, TrainOptions, RESOURCES};
5+
use crate::{Backend, Dataset, Logger, PredictOptions, Timer, TrainOptions, RESOURCES};
76

87
#[wasm_bindgen]
98
extern "C" {
109
#[wasm_bindgen(js_namespace = console)]
1110
fn log(s: &str);
11+
#[wasm_bindgen(js_namespace = Date)]
12+
fn now() -> f64;
13+
1214
}
1315

1416
/// Adapter matching the `Logger` fn-pointer signature; forwards to the
/// wasm-bindgen `console.log` binding declared above.
fn console_log(string: String) {
    log(string.as_str())
}
1719

20+
/// Adapter matching the `Timer` fn-pointer signature; truncates the f64
/// millisecond timestamp from the `Date.now` binding to u128 (fractional
/// milliseconds are discarded).
fn performance_now() -> u128 {
    now() as u128
}
23+
1824
#[wasm_bindgen]
1925
pub fn wasm_backend_create(config: String, shape: Array) -> usize {
2026
let config = serde_json::from_str(&config).unwrap();
2127
let mut len = 0;
2228
let logger = Logger { log: console_log };
23-
let net_backend = Backend::new(config, logger, None);
29+
let net_backend = Backend::new(
30+
config,
31+
logger,
32+
Timer {
33+
now: performance_now,
34+
},
35+
None,
36+
);
2437
shape.set_length(net_backend.size.len() as u32);
2538
for (i, s) in net_backend.size.iter().enumerate() {
2639
shape.set(i as u32, JsValue::from(*s))
@@ -37,7 +50,6 @@ pub fn wasm_backend_create(config: String, shape: Array) -> usize {
3750
#[wasm_bindgen]
3851
pub fn wasm_backend_train(id: usize, buffers: Vec<Float32Array>, options: String) {
3952
let options: TrainOptions = serde_json::from_str(&options).unwrap();
40-
4153
let mut datasets = Vec::new();
4254
for i in 0..options.datasets {
4355
let input = buffers[i * 2].to_vec();
@@ -47,7 +59,6 @@ pub fn wasm_backend_train(id: usize, buffers: Vec<Float32Array>, options: String
4759
outputs: ArrayD::from_shape_vec(options.output_shape.clone(), output).unwrap(),
4860
});
4961
}
50-
5162
RESOURCES.with(|cell| {
5263
let mut backend = cell.backend.borrow_mut();
5364
backend[id].train(datasets, options.epochs, options.batches, options.rate)
@@ -59,11 +70,12 @@ pub fn wasm_backend_predict(id: usize, buffer: Float32Array, options: String) ->
5970
let options: PredictOptions = serde_json::from_str(&options).unwrap();
6071
let inputs = ArrayD::from_shape_vec(options.input_shape, buffer.to_vec()).unwrap();
6172

62-
let res = ArrayD::zeros(options.output_shape);
73+
let mut res = ArrayD::zeros(options.output_shape.clone());
6374

6475
RESOURCES.with(|cell| {
6576
let mut backend = cell.backend.borrow_mut();
66-
let _res = backend[id].predict(inputs, options.layers);
77+
let _res = backend[id].predict(inputs, options.post_process, options.layers);
78+
res.assign(&ArrayD::from_shape_vec(options.output_shape, _res.as_slice().unwrap().to_vec()).unwrap());
6779
});
6880
Float32Array::from(res.as_slice().unwrap())
6981
}
@@ -82,7 +94,10 @@ pub fn wasm_backend_save(id: usize) -> Uint8Array {
8294
pub fn wasm_backend_load(buffer: Uint8Array, shape: Array) -> usize {
8395
let mut len = 0;
8496
let logger = Logger { log: console_log };
85-
let net_backend = Backend::load(buffer.to_vec().as_slice(), logger);
97+
let timer = Timer {
98+
now: performance_now,
99+
};
100+
let net_backend = Backend::load(buffer.to_vec().as_slice(), logger, timer);
86101
shape.set_length(net_backend.size.len() as u32);
87102
for (i, s) in net_backend.size.iter().enumerate() {
88103
shape.set(i as u32, JsValue::from(*s))

0 commit comments

Comments
 (0)