@@ -6,7 +6,7 @@ use routee_compass_core::model::{
6
6
use smartcore:: {
7
7
ensemble:: random_forest_regressor:: RandomForestRegressor , linalg:: basic:: matrix:: DenseMatrix ,
8
8
} ;
9
- use std:: { borrow:: Cow , path:: Path } ;
9
+ use std:: { borrow:: Cow , fs :: File , path:: Path } ;
10
10
11
11
pub struct SmartcoreSpeedGradeModel {
12
12
rf : RandomForestRegressor < f64 , f64 , DenseMatrix < f64 > , Vec < f64 > > ,
@@ -27,7 +27,13 @@ impl PredictionModel for SmartcoreSpeedGradeModel {
27
27
let mut grade_value = Cow :: Owned ( grade) ;
28
28
speed_unit. convert ( & mut speed_value, & self . speed_unit ) ?;
29
29
grade_unit. convert ( & mut grade_value, & self . grade_unit ) ?;
30
- let x = DenseMatrix :: from_2d_vec ( & vec ! [ vec![ speed_value. as_f64( ) , grade_value. as_f64( ) ] ] ) ;
30
+ let x = DenseMatrix :: from_2d_vec ( & vec ! [ vec![ speed_value. as_f64( ) , grade_value. as_f64( ) ] ] )
31
+ . map_err ( |e| {
32
+ TraversalModelError :: TraversalModelFailure ( format ! (
33
+ "unable to set up prediction input vector: {}" ,
34
+ e
35
+ ) )
36
+ } ) ?;
31
37
let y = self . rf . predict ( & x) . map_err ( |e| {
32
38
TraversalModelError :: TraversalModelFailure ( format ! (
33
39
"failure running underlying Smartcore random forest energy prediction: {}" ,
@@ -47,22 +53,23 @@ impl SmartcoreSpeedGradeModel {
47
53
grade_unit : GradeUnit ,
48
54
energy_rate_unit : EnergyRateUnit ,
49
55
) -> Result < Self , TraversalModelError > {
50
- // Load random forest binary file
51
- let rf_binary = std:: fs:: read ( routee_model_path) . map_err ( |e| {
56
+ let mut file = File :: open ( routee_model_path) . map_err ( |e| {
52
57
TraversalModelError :: BuildError ( format ! (
53
- "failure reading smartcore binary text file {} due to {}" ,
54
- routee_model_path. as_ref( ) . to_str ( ) . unwrap_or_default ( ) ,
58
+ "failure opening file {}: {}" ,
59
+ routee_model_path. as_ref( ) . to_string_lossy ( ) ,
55
60
e
56
61
) )
57
62
} ) ?;
58
63
let rf: RandomForestRegressor < f64 , f64 , DenseMatrix < f64 > , Vec < f64 > > =
59
- bincode:: deserialize ( & rf_binary) . map_err ( |e| {
60
- TraversalModelError :: BuildError ( format ! (
61
- "failure deserializing smartcore model {} due to {}" ,
62
- routee_model_path. as_ref( ) . to_str( ) . unwrap_or_default( ) ,
63
- e
64
- ) )
65
- } ) ?;
64
+ bincode:: serde:: decode_from_std_read ( & mut file, bincode:: config:: legacy ( ) ) . map_err (
65
+ |e| {
66
+ TraversalModelError :: BuildError ( format ! (
67
+ "failure deserializing smartcore model {} due to {}" ,
68
+ routee_model_path. as_ref( ) . to_str( ) . unwrap_or_default( ) ,
69
+ e
70
+ ) )
71
+ } ,
72
+ ) ?;
66
73
Ok ( SmartcoreSpeedGradeModel {
67
74
rf,
68
75
speed_unit,
0 commit comments