@@ -6,8 +6,7 @@ use burn::nn::loss::Reduction;
66use burn:: tensor:: { Data , Shape , Tensor } ;
77use burn:: { data:: dataloader:: batcher:: Batcher , tensor:: backend:: Backend } ;
88
9- use crate :: dataset:: FSRSBatcher ;
10- use crate :: dataset:: { simple_weighted_fsrs_items, FSRSBatch } ;
9+ use crate :: dataset:: { recency_weighted_fsrs_items, FSRSBatch , FSRSBatcher } ;
1110use crate :: error:: Result ;
1211use crate :: model:: Model ;
1312use crate :: training:: BCELoss ;
@@ -210,7 +209,7 @@ impl<B: Backend> FSRS<B> {
210209 if items. is_empty ( ) {
211210 return Err ( FSRSError :: NotEnoughData ) ;
212211 }
213- let items = simple_weighted_fsrs_items ( items) ;
212+ let items = recency_weighted_fsrs_items ( items) ;
214213 let batcher = FSRSBatcher :: new ( self . device ( ) ) ;
215214 let mut all_retention = vec ! [ ] ;
216215 let mut all_labels = vec ! [ ] ;
@@ -232,10 +231,10 @@ impl<B: Backend> FSRS<B> {
232231 all_weights. push ( batch. weights ) ;
233232 izip ! ( chunk, pred, true_val) . for_each ( |( item, p, y) | {
234233 let bin = item. item . r_matrix_index ( ) ;
235- let ( pred, real, count ) = r_matrix. entry ( bin) . or_insert ( ( 0.0 , 0.0 , 0.0 ) ) ;
234+ let ( pred, real, weight ) = r_matrix. entry ( bin) . or_insert ( ( 0.0 , 0.0 , 0.0 ) ) ;
236235 * pred += p;
237236 * real += y;
238- * count += 1.0 ;
237+ * weight += item . weight ;
239238 } ) ;
240239 progress_info. current += chunk. len ( ) ;
241240 if !progress ( progress_info) {
@@ -244,13 +243,13 @@ impl<B: Backend> FSRS<B> {
244243 }
245244 let rmse = ( r_matrix
246245 . values ( )
247- . map ( |( pred, real, count ) | {
248- let pred = pred / count ;
249- let real = real / count ;
250- ( pred - real) . powi ( 2 ) * count
246+ . map ( |( pred, real, weight ) | {
247+ let pred = pred / weight ;
248+ let real = real / weight ;
249+ ( pred - real) . powi ( 2 ) * weight
251250 } )
252251 . sum :: < f32 > ( )
253- / r_matrix. values ( ) . map ( |( _, _, count ) | count ) . sum :: < f32 > ( ) )
252+ / r_matrix. values ( ) . map ( |( _, _, weight ) | weight ) . sum :: < f32 > ( ) )
254253 . sqrt ( ) ;
255254 let all_retention = Tensor :: cat ( all_retention, 0 ) ;
256255 let all_labels = Tensor :: cat ( all_labels, 0 ) . float ( ) ;
@@ -282,7 +281,7 @@ impl<B: Backend> FSRS<B> {
282281 if items. is_empty ( ) {
283282 return Err ( FSRSError :: NotEnoughData ) ;
284283 }
285- let items = simple_weighted_fsrs_items ( items) ;
284+ let items = recency_weighted_fsrs_items ( items) ;
286285 let batcher = FSRSBatcher :: new ( self . device ( ) ) ;
287286 let mut all_predictions_self = vec ! [ ] ;
288287 let mut all_predictions_other = vec ! [ ] ;
@@ -499,17 +498,17 @@ mod tests {
499498 ] ) ) ?;
500499 let metrics = fsrs. evaluate ( items. clone ( ) , |_| true ) . unwrap ( ) ;
501500
502- assert_approx_eq ( [ metrics. log_loss , metrics. rmse_bins ] , [ 0.211007 , 0.037216 ] ) ;
501+ assert_approx_eq ( [ metrics. log_loss , metrics. rmse_bins ] , [ 0.212817 , 0.034676 ] ) ;
503502
504503 let fsrs = FSRS :: new ( Some ( & [ ] ) ) ?;
505504 let metrics = fsrs. evaluate ( items. clone ( ) , |_| true ) . unwrap ( ) ;
506505
507- assert_approx_eq ( [ metrics. log_loss , metrics. rmse_bins ] , [ 0.216286 , 0.038692 ] ) ;
506+ assert_approx_eq ( [ metrics. log_loss , metrics. rmse_bins ] , [ 0.217251 , 0.036590 ] ) ;
508507
509508 let fsrs = FSRS :: new ( Some ( PARAMETERS ) ) ?;
510509 let metrics = fsrs. evaluate ( items. clone ( ) , |_| true ) . unwrap ( ) ;
511510
512- assert_approx_eq ( [ metrics. log_loss , metrics. rmse_bins ] , [ 0.203049 , 0.027558 ] ) ;
511+ assert_approx_eq ( [ metrics. log_loss , metrics. rmse_bins ] , [ 0.203552 , 0.025646 ] ) ;
513512
514513 let ( self_by_other, other_by_self) = fsrs
515514 . universal_metrics ( items. clone ( ) , & DEFAULT_PARAMETERS , |_| true )
0 commit comments