Skip to content

Commit e0665e8

Browse files
committed
fix rmse
1 parent aeac505 commit e0665e8

File tree

1 file changed

+14
-9
lines changed

1 file changed

+14
-9
lines changed

src/inference.rs

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,7 @@ impl<B: Backend> FSRS<B> {
219219
total: items.len(),
220220
};
221221
let model = self.model();
222-
let mut r_matrix: HashMap<(u32, u32, u32), (f32, f32, f32)> = HashMap::new();
222+
let mut r_matrix: HashMap<(u32, u32, u32), (f32, f32, f32, f32)> = HashMap::new();
223223

224224
for chunk in items.chunks(512) {
225225
let batch = batcher.batch(chunk.to_vec());
@@ -231,9 +231,11 @@ impl<B: Backend> FSRS<B> {
231231
all_weights.push(batch.weights);
232232
izip!(chunk, pred, true_val).for_each(|(item, p, y)| {
233233
let bin = item.item.r_matrix_index();
234-
let (pred, real, weight) = r_matrix.entry(bin).or_insert((0.0, 0.0, 0.0));
234+
let (pred, real, count, weight) =
235+
r_matrix.entry(bin).or_insert((0.0, 0.0, 0.0, 0.0));
235236
*pred += p;
236237
*real += y;
238+
*count += 1.0;
237239
*weight += item.weight;
238240
});
239241
progress_info.current += chunk.len();
@@ -243,13 +245,16 @@ impl<B: Backend> FSRS<B> {
243245
}
244246
let rmse = (r_matrix
245247
.values()
246-
.map(|(pred, real, weight)| {
247-
let pred = pred / weight;
248-
let real = real / weight;
248+
.map(|(pred, real, count, weight)| {
249+
let pred = pred / count;
250+
let real = real / count;
249251
(pred - real).powi(2) * weight
250252
})
251253
.sum::<f32>()
252-
/ r_matrix.values().map(|(_, _, weight)| weight).sum::<f32>())
254+
/ r_matrix
255+
.values()
256+
.map(|(_, _, _, weight)| weight)
257+
.sum::<f32>())
253258
.sqrt();
254259
let all_retention = Tensor::cat(all_retention, 0);
255260
let all_labels = Tensor::cat(all_labels, 0).float();
@@ -498,17 +503,17 @@ mod tests {
498503
]))?;
499504
let metrics = fsrs.evaluate(items.clone(), |_| true).unwrap();
500505

501-
assert_approx_eq([metrics.log_loss, metrics.rmse_bins], [0.212817, 0.034676]);
506+
assert_approx_eq([metrics.log_loss, metrics.rmse_bins], [0.212817, 0.040148]);
502507

503508
let fsrs = FSRS::new(Some(&[]))?;
504509
let metrics = fsrs.evaluate(items.clone(), |_| true).unwrap();
505510

506-
assert_approx_eq([metrics.log_loss, metrics.rmse_bins], [0.217251, 0.036590]);
511+
assert_approx_eq([metrics.log_loss, metrics.rmse_bins], [0.217251, 0.041336]);
507512

508513
let fsrs = FSRS::new(Some(PARAMETERS))?;
509514
let metrics = fsrs.evaluate(items.clone(), |_| true).unwrap();
510515

511-
assert_approx_eq([metrics.log_loss, metrics.rmse_bins], [0.203552, 0.025646]);
516+
assert_approx_eq([metrics.log_loss, metrics.rmse_bins], [0.203552, 0.029828]);
512517

513518
let (self_by_other, other_by_self) = fsrs
514519
.universal_metrics(items.clone(), &DEFAULT_PARAMETERS, |_| true)

0 commit comments

Comments
 (0)