Skip to content

Commit 38d9a86

Browse files
committed
improve recency weighting
1 parent 10e461c commit 38d9a86

File tree

3 files changed

+10
-9
lines changed

3 files changed

+10
-9
lines changed

src/dataset.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -286,7 +286,7 @@ pub(crate) fn recency_weighted_fsrs_items(items: Vec<FSRSItem>) -> Vec<WeightedF
286286
.into_iter()
287287
.enumerate()
288288
.map(|(idx, item)| WeightedFSRSItem {
289-
weight: idx as f32 / length + 0.5,
289+
weight: 0.25 + 0.75 * (idx as f32 / length).powi(3),
290290
item,
291291
})
292292
.collect()

src/inference.rs

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -263,7 +263,7 @@ impl<B: Backend> FSRS<B> {
263263
let all_retention = Tensor::cat(all_retention, 0);
264264
let all_labels = Tensor::cat(all_labels, 0).float();
265265
let all_weights = Tensor::cat(all_weights, 0);
266-
let loss = BCELoss::new().forward(all_retention, all_labels, all_weights, Reduction::Mean);
266+
let loss = BCELoss::new().forward(all_retention, all_labels, all_weights, Reduction::Auto);
267267
Ok(ModelEvaluation {
268268
log_loss: loss.to_data().value[0].elem(),
269269
rmse_bins: rmse,
@@ -502,22 +502,23 @@ mod tests {
502502
let items = [pretrainset, trainset].concat();
503503

504504
let fsrs = FSRS::new(Some(&[
505-
0.669, 1.679, 4.1355, 9.862, 7.9435, 0.9379, 1.0148, 0.1588, 1.3851, 0.1248, 0.8421,
506-
1.992, 0.153, 0.284, 2.4282, 0.2547, 3.1847, 0.2196, 0.1906,
505+
0.6032805, 1.3376843, 4.4167747, 9.933699, 7.654044, 0.78219295, 2.336606, 0.001,
506+
1.3264198, 0.12967199, 0.82880765, 1.9360433, 0.13298263, 0.27427456, 2.4304862,
507+
0.10340813, 3.108867, 0.2114512, 0.2826002,
507508
]))?;
508509
let metrics = fsrs.evaluate(items.clone(), |_| true).unwrap();
509510

510-
assert_approx_eq([metrics.log_loss, metrics.rmse_bins], [0.212817, 0.040148]);
511+
assert_approx_eq([metrics.log_loss, metrics.rmse_bins], [0.206160, 0.025809]);
511512

512513
let fsrs = FSRS::new(Some(&[]))?;
513514
let metrics = fsrs.evaluate(items.clone(), |_| true).unwrap();
514515

515-
assert_approx_eq([metrics.log_loss, metrics.rmse_bins], [0.217251, 0.041336]);
516+
assert_approx_eq([metrics.log_loss, metrics.rmse_bins], [0.223601, 0.042738]);
516517

517518
let fsrs = FSRS::new(Some(PARAMETERS))?;
518519
let metrics = fsrs.evaluate(items.clone(), |_| true).unwrap();
519520

520-
assert_approx_eq([metrics.log_loss, metrics.rmse_bins], [0.203552, 0.029828]);
521+
assert_approx_eq([metrics.log_loss, metrics.rmse_bins], [0.208656, 0.030946]);
521522

522523
let (self_by_other, other_by_self) = fsrs
523524
.universal_metrics(items.clone(), &DEFAULT_PARAMETERS, |_| true)

src/training.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,12 +45,12 @@ impl<B: Backend> BCELoss<B> {
4545
) -> Tensor<B, 1> {
4646
let loss = (labels.clone() * retentions.clone().log()
4747
+ (-labels + 1) * (-retentions + 1).log())
48-
* weights;
48+
* weights.clone();
4949
// info!("loss: {}", &loss);
5050
match mean {
5151
Reduction::Mean => loss.mean().neg(),
5252
Reduction::Sum => loss.sum().neg(),
53-
Reduction::Auto => loss.neg(),
53+
Reduction::Auto => (loss.sum() / weights.sum()).neg(),
5454
}
5555
}
5656
}

0 commit comments

Comments
 (0)