@@ -263,7 +263,7 @@ impl<B: Backend> FSRS<B> {
263263 let all_retention = Tensor :: cat ( all_retention, 0 ) ;
264264 let all_labels = Tensor :: cat ( all_labels, 0 ) . float ( ) ;
265265 let all_weights = Tensor :: cat ( all_weights, 0 ) ;
266- let loss = BCELoss :: new ( ) . forward ( all_retention, all_labels, all_weights, Reduction :: Mean ) ;
266+ let loss = BCELoss :: new ( ) . forward ( all_retention, all_labels, all_weights, Reduction :: Auto ) ;
267267 Ok ( ModelEvaluation {
268268 log_loss : loss. to_data ( ) . value [ 0 ] . elem ( ) ,
269269 rmse_bins : rmse,
@@ -502,22 +502,23 @@ mod tests {
502502 let items = [ pretrainset, trainset] . concat ( ) ;
503503
504504 let fsrs = FSRS :: new ( Some ( & [
505- 0.669 , 1.679 , 4.1355 , 9.862 , 7.9435 , 0.9379 , 1.0148 , 0.1588 , 1.3851 , 0.1248 , 0.8421 ,
506- 1.992 , 0.153 , 0.284 , 2.4282 , 0.2547 , 3.1847 , 0.2196 , 0.1906 ,
505+ 0.6032805 , 1.3376843 , 4.4167747 , 9.933699 , 7.654044 , 0.78219295 , 2.336606 , 0.001 ,
506+ 1.3264198 , 0.12967199 , 0.82880765 , 1.9360433 , 0.13298263 , 0.27427456 , 2.4304862 ,
507+ 0.10340813 , 3.108867 , 0.2114512 , 0.2826002 ,
507508 ] ) ) ?;
508509 let metrics = fsrs. evaluate ( items. clone ( ) , |_| true ) . unwrap ( ) ;
509510
510- assert_approx_eq ( [ metrics. log_loss , metrics. rmse_bins ] , [ 0.212817 , 0.040148 ] ) ;
511+ assert_approx_eq ( [ metrics. log_loss , metrics. rmse_bins ] , [ 0.206160 , 0.025809 ] ) ;
511512
512513 let fsrs = FSRS :: new ( Some ( & [ ] ) ) ?;
513514 let metrics = fsrs. evaluate ( items. clone ( ) , |_| true ) . unwrap ( ) ;
514515
515- assert_approx_eq ( [ metrics. log_loss , metrics. rmse_bins ] , [ 0.217251 , 0.041336 ] ) ;
516+ assert_approx_eq ( [ metrics. log_loss , metrics. rmse_bins ] , [ 0.223601 , 0.042738 ] ) ;
516517
517518 let fsrs = FSRS :: new ( Some ( PARAMETERS ) ) ?;
518519 let metrics = fsrs. evaluate ( items. clone ( ) , |_| true ) . unwrap ( ) ;
519520
520- assert_approx_eq ( [ metrics. log_loss , metrics. rmse_bins ] , [ 0.203552 , 0.029828 ] ) ;
521+ assert_approx_eq ( [ metrics. log_loss , metrics. rmse_bins ] , [ 0.208656 , 0.030946 ] ) ;
521522
522523 let ( self_by_other, other_by_self) = fsrs
523524 . universal_metrics ( items. clone ( ) , & DEFAULT_PARAMETERS , |_| true )
0 commit comments