@@ -69,6 +69,14 @@ pub fn next_interval(stability: f32, desired_retention: f32) -> f32 {
6969 stability / FACTOR as f32 * ( desired_retention. powf ( 1.0 / DECAY as f32 ) - 1.0 )
7070}
7171
72+ #[ derive( Default ) ]
73+ struct RMatrixValue {
74+ predicted : f32 ,
75+ actual : f32 ,
76+ count : f32 ,
77+ weight : f32 ,
78+ }
79+
7280impl < B : Backend > FSRS < B > {
7381 /// Calculate the current memory state for a given card's history of reviews.
7482 /// In the case of truncated reviews, [starting_state] can be set to the value of
@@ -219,7 +227,7 @@ impl<B: Backend> FSRS<B> {
219227 total : items. len ( ) ,
220228 } ;
221229 let model = self . model ( ) ;
222- let mut r_matrix: HashMap < ( u32 , u32 , u32 ) , ( f32 , f32 , f32 , f32 ) > = HashMap :: new ( ) ;
230+ let mut r_matrix: HashMap < ( u32 , u32 , u32 ) , RMatrixValue > = HashMap :: new ( ) ;
223231
224232 for chunk in items. chunks ( 512 ) {
225233 let batch = batcher. batch ( chunk. to_vec ( ) ) ;
@@ -231,12 +239,11 @@ impl<B: Backend> FSRS<B> {
231239 all_weights. push ( batch. weights ) ;
232240 izip ! ( chunk, pred, true_val) . for_each ( |( item, p, y) | {
233241 let bin = item. item . r_matrix_index ( ) ;
234- let ( pred, real, count, weight) =
235- r_matrix. entry ( bin) . or_insert ( ( 0.0 , 0.0 , 0.0 , 0.0 ) ) ;
236- * pred += p;
237- * real += y;
238- * count += 1.0 ;
239- * weight += item. weight ;
242+ let value = r_matrix. entry ( bin) . or_default ( ) ;
243+ value. predicted += p;
244+ value. actual += y;
245+ value. count += 1.0 ;
246+ value. weight += item. weight ;
240247 } ) ;
241248 progress_info. current += chunk. len ( ) ;
242249 if !progress ( progress_info) {
@@ -245,16 +252,13 @@ impl<B: Backend> FSRS<B> {
245252 }
246253 let rmse = ( r_matrix
247254 . values ( )
248- . map ( |( pred , real , count , weight ) | {
249- let pred = pred / count;
250- let real = real / count;
251- ( pred - real) . powi ( 2 ) * weight
255+ . map ( |v | {
256+ let pred = v . predicted / v . count ;
257+ let real = v . actual / v . count ;
258+ ( pred - real) . powi ( 2 ) * v . weight
252259 } )
253260 . sum :: < f32 > ( )
254- / r_matrix
255- . values ( )
256- . map ( |( _, _, _, weight) | weight)
257- . sum :: < f32 > ( ) )
261+ / r_matrix. values ( ) . map ( |v| v. weight ) . sum :: < f32 > ( ) )
258262 . sqrt ( ) ;
259263 let all_retention = Tensor :: cat ( all_retention, 0 ) ;
260264 let all_labels = Tensor :: cat ( all_labels, 0 ) . float ( ) ;
0 commit comments