11use std:: collections:: HashMap ;
2- use std:: time:: Instant ;
32
43use ndarray:: { ArrayD , ArrayViewD , IxDyn } ;
54use safetensors:: { serialize, SafeTensors } ;
65
76use crate :: {
87 to_arr, ActivationCPULayer , BackendConfig , BatchNorm1DCPULayer , BatchNorm2DCPULayer ,
9- BatchNormTensors , CPUCost , CPULayer , CPUOptimizer , CPUScheduler , Conv2DCPULayer , ConvTensors ,
10- ConvTranspose2DCPULayer , Dataset , DenseCPULayer , DenseTensors , Dropout1DCPULayer ,
11- Dropout2DCPULayer , FlattenCPULayer , GetTensor , Layer , Logger , Pool2DCPULayer , SoftmaxCPULayer ,
12- Tensor , Tensors ,
8+ BatchNormTensors , CPUCost , CPULayer , CPUOptimizer , CPUPostProcessor , CPUScheduler ,
9+ Conv2DCPULayer , ConvTensors , ConvTranspose2DCPULayer , Dataset , DenseCPULayer , DenseTensors ,
10+ Dropout1DCPULayer , Dropout2DCPULayer , FlattenCPULayer , GetTensor , Layer , Logger ,
11+ Pool2DCPULayer , PostProcessor , SoftmaxCPULayer , Tensor , Tensors , Timer ,
1312} ;
1413
1514pub struct Backend {
@@ -23,10 +22,16 @@ pub struct Backend {
2322 pub optimizer : CPUOptimizer ,
2423 pub scheduler : CPUScheduler ,
2524 pub logger : Logger ,
25+ pub timer : Timer ,
2626}
2727
2828impl Backend {
29- pub fn new ( config : BackendConfig , logger : Logger , mut tensors : Option < Vec < Tensors > > ) -> Self {
29+ pub fn new (
30+ config : BackendConfig ,
31+ logger : Logger ,
32+ timer : Timer ,
33+ mut tensors : Option < Vec < Tensors > > ,
34+ ) -> Self {
3035 let mut layers = Vec :: new ( ) ;
3136 let mut size = config. size . clone ( ) ;
3237 for layer in config. layers . iter ( ) {
@@ -99,6 +104,7 @@ impl Backend {
99104 optimizer,
100105 scheduler,
101106 size,
107+ timer,
102108 }
103109 }
104110
@@ -147,7 +153,7 @@ impl Backend {
147153 let mut cost = 0f32 ;
148154 let mut time: u128 ;
149155 let mut total_time = 0u128 ;
150- let start = Instant :: now ( ) ;
156+ let start = ( self . timer . now ) ( ) ;
151157 let total_iter = epochs * datasets. len ( ) ;
152158 while epoch < epochs {
153159 let mut total = 0.0 ;
@@ -160,11 +166,11 @@ impl Backend {
160166 let minibatch = outputs. dim ( ) [ 0 ] ;
161167 if !self . silent && ( ( i + 1 ) * minibatch) % batches == 0 {
162168 cost = total / ( batches) as f32 ;
163- time = start . elapsed ( ) . as_millis ( ) - total_time;
169+ time = ( ( self . timer . now ) ( ) - start ) - total_time;
164170 total_time += time;
165171 let current_iter = epoch * datasets. len ( ) + i;
166172 let msg = format ! (
167- "Epoch={}, Dataset={}, Cost={}, Time={}s, ETA={}s" ,
173+ "Epoch={}, Dataset={}, Cost={}, Time={:.3 }s, ETA={:.3 }s" ,
168174 epoch,
169175 i * minibatch,
170176 cost,
@@ -188,25 +194,20 @@ impl Backend {
188194 } else {
189195 disappointments += 1 ;
190196 if !self . silent {
191- println ! (
197+ ( self . logger . log ) ( format ! (
192198 "Patience counter: {} disappointing epochs out of {}." ,
193199 disappointments, self . patience
194- ) ;
200+ ) ) ;
195201 }
196202 }
197203 if disappointments >= self . patience {
198204 if !self . silent {
199- println ! (
205+ ( self . logger . log ) ( format ! (
200206 "No improvement for {} epochs. Stopping early at cost={}" ,
201207 disappointments, best_cost
202- ) ;
208+ ) ) ;
203209 }
204- let net = Self :: load (
205- & best_net,
206- Logger {
207- log : |x| println ! ( "{}" , x) ,
208- } ,
209- ) ;
210+ let net = Self :: load ( & best_net, self . logger . clone ( ) , self . timer . clone ( ) ) ;
210211 self . layers = net. layers ;
211212 break ;
212213 }
@@ -215,11 +216,18 @@ impl Backend {
215216 }
216217 }
217218
218- pub fn predict ( & mut self , data : ArrayD < f32 > , layers : Option < Vec < usize > > ) -> ArrayD < f32 > {
219+ pub fn predict (
220+ & mut self ,
221+ data : ArrayD < f32 > ,
222+ postprocess : PostProcessor ,
223+ layers : Option < Vec < usize > > ,
224+ ) -> ArrayD < f32 > {
225+ let processor = CPUPostProcessor :: from ( & postprocess) ;
219226 for layer in & mut self . layers {
220227 layer. reset ( 1 ) ;
221228 }
222- self . forward_propagate ( data, false , layers)
229+ let res = self . forward_propagate ( data, false , layers) ;
230+ processor. process ( res)
223231 }
224232
225233 pub fn save ( & self ) -> Vec < u8 > {
@@ -272,7 +280,7 @@ impl Backend {
272280 serialize ( tensors, & Some ( metadata) ) . unwrap ( )
273281 }
274282
275- pub fn load ( buffer : & [ u8 ] , logger : Logger ) -> Self {
283+ pub fn load ( buffer : & [ u8 ] , logger : Logger , timer : Timer ) -> Self {
276284 let tensors = SafeTensors :: deserialize ( buffer) . unwrap ( ) ;
277285 let ( _, metadata) = SafeTensors :: read_metadata ( buffer) . unwrap ( ) ;
278286 let data = metadata. metadata ( ) . as_ref ( ) . unwrap ( ) ;
@@ -304,6 +312,6 @@ impl Backend {
304312 } ;
305313 }
306314
307- Backend :: new ( config, logger, Some ( layers) )
315+ Backend :: new ( config, logger, timer , Some ( layers) )
308316 }
309317}
0 commit comments