@@ -219,7 +219,7 @@ impl<B: Backend> FSRS<B> {
219219 total : items. len ( ) ,
220220 } ;
221221 let model = self . model ( ) ;
222- let mut r_matrix: HashMap < ( u32 , u32 , u32 ) , ( f32 , f32 , f32 ) > = HashMap :: new ( ) ;
222+ let mut r_matrix: HashMap < ( u32 , u32 , u32 ) , ( f32 , f32 , f32 , f32 ) > = HashMap :: new ( ) ;
223223
224224 for chunk in items. chunks ( 512 ) {
225225 let batch = batcher. batch ( chunk. to_vec ( ) ) ;
@@ -231,9 +231,11 @@ impl<B: Backend> FSRS<B> {
231231 all_weights. push ( batch. weights ) ;
232232 izip ! ( chunk, pred, true_val) . for_each ( |( item, p, y) | {
233233 let bin = item. item . r_matrix_index ( ) ;
234- let ( pred, real, weight) = r_matrix. entry ( bin) . or_insert ( ( 0.0 , 0.0 , 0.0 ) ) ;
234+ let ( pred, real, count, weight) =
235+ r_matrix. entry ( bin) . or_insert ( ( 0.0 , 0.0 , 0.0 , 0.0 ) ) ;
235236 * pred += p;
236237 * real += y;
238+ * count += 1.0 ;
237239 * weight += item. weight ;
238240 } ) ;
239241 progress_info. current += chunk. len ( ) ;
@@ -243,13 +245,16 @@ impl<B: Backend> FSRS<B> {
243245 }
244246 let rmse = ( r_matrix
245247 . values ( )
246- . map ( |( pred, real, weight) | {
247- let pred = pred / weight ;
248- let real = real / weight ;
248+ . map ( |( pred, real, count , weight) | {
249+ let pred = pred / count ;
250+ let real = real / count ;
249251 ( pred - real) . powi ( 2 ) * weight
250252 } )
251253 . sum :: < f32 > ( )
252- / r_matrix. values ( ) . map ( |( _, _, weight) | weight) . sum :: < f32 > ( ) )
254+ / r_matrix
255+ . values ( )
256+ . map ( |( _, _, _, weight) | weight)
257+ . sum :: < f32 > ( ) )
253258 . sqrt ( ) ;
254259 let all_retention = Tensor :: cat ( all_retention, 0 ) ;
255260 let all_labels = Tensor :: cat ( all_labels, 0 ) . float ( ) ;
@@ -498,17 +503,17 @@ mod tests {
498503 ] ) ) ?;
499504 let metrics = fsrs. evaluate ( items. clone ( ) , |_| true ) . unwrap ( ) ;
500505
501- assert_approx_eq ( [ metrics. log_loss , metrics. rmse_bins ] , [ 0.212817 , 0.034676 ] ) ;
506+ assert_approx_eq ( [ metrics. log_loss , metrics. rmse_bins ] , [ 0.212817 , 0.040148 ] ) ;
502507
503508 let fsrs = FSRS :: new ( Some ( & [ ] ) ) ?;
504509 let metrics = fsrs. evaluate ( items. clone ( ) , |_| true ) . unwrap ( ) ;
505510
506- assert_approx_eq ( [ metrics. log_loss , metrics. rmse_bins ] , [ 0.217251 , 0.036590 ] ) ;
511+ assert_approx_eq ( [ metrics. log_loss , metrics. rmse_bins ] , [ 0.217251 , 0.041336 ] ) ;
507512
508513 let fsrs = FSRS :: new ( Some ( PARAMETERS ) ) ?;
509514 let metrics = fsrs. evaluate ( items. clone ( ) , |_| true ) . unwrap ( ) ;
510515
511- assert_approx_eq ( [ metrics. log_loss , metrics. rmse_bins ] , [ 0.203552 , 0.025646 ] ) ;
516+ assert_approx_eq ( [ metrics. log_loss , metrics. rmse_bins ] , [ 0.203552 , 0.029828 ] ) ;
512517
513518 let ( self_by_other, other_by_self) = fsrs
514519 . universal_metrics ( items. clone ( ) , & DEFAULT_PARAMETERS , |_| true )
0 commit comments