Skip to content

Commit aeac505

Browse files
committed
apply recency weighting to evaluation
1 parent f1c1371 commit aeac505

File tree

4 files changed

+20
-20
lines changed

4 files changed

+20
-20
lines changed

src/batch_shuffle.rs

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -110,7 +110,7 @@ mod tests {
110110
use super::*;
111111
use crate::{
112112
convertor_tests::anki21_sample_file_converted_to_fsrs,
113-
dataset::{prepare_training_data, simple_weighted_fsrs_items},
113+
dataset::{constant_weighted_fsrs_items, prepare_training_data},
114114
};
115115

116116
#[test]
@@ -120,7 +120,7 @@ mod tests {
120120
.sorted_by_cached_key(|item| item.reviews.len())
121121
.collect();
122122
let (_pre_train_set, train_set) = prepare_training_data(train_set);
123-
let dataset = FSRSDataset::from(simple_weighted_fsrs_items(train_set));
123+
let dataset = FSRSDataset::from(constant_weighted_fsrs_items(train_set));
124124
let batch_size = 512;
125125
let seed = 114514;
126126
let device = NdArrayDevice::Cpu;

src/convertor_tests.rs

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,5 @@
11
use crate::convertor_tests::RevlogReviewKind::*;
2-
use crate::dataset::{simple_weighted_fsrs_items, FSRSBatcher};
2+
use crate::dataset::{constant_weighted_fsrs_items, FSRSBatcher};
33
use crate::dataset::{FSRSItem, FSRSReview};
44
use crate::optimal_retention::{RevlogEntry, RevlogReviewKind};
55
use crate::test_helpers::NdArrayAutodiff;
@@ -388,7 +388,7 @@ fn conversion_works() {
388388
]
389389
);
390390

391-
let mut weighted_fsrs_items = simple_weighted_fsrs_items(fsrs_items);
391+
let mut weighted_fsrs_items = constant_weighted_fsrs_items(fsrs_items);
392392

393393
let device = NdArrayDevice::Cpu;
394394
let batcher = FSRSBatcher::<NdArrayAutodiff>::new(device);

src/dataset.rs

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -271,7 +271,8 @@ pub(crate) fn sort_items_by_review_length(items: Vec<WeightedFSRSItem>) -> Vec<W
271271
items
272272
}
273273

274-
pub(crate) fn simple_weighted_fsrs_items(items: Vec<FSRSItem>) -> Vec<WeightedFSRSItem> {
274+
#[cfg(test)]
275+
pub(crate) fn constant_weighted_fsrs_items(items: Vec<FSRSItem>) -> Vec<WeightedFSRSItem> {
275276
items
276277
.into_iter()
277278
.map(|item| WeightedFSRSItem { weight: 1.0, item })
@@ -300,7 +301,7 @@ mod tests {
300301
fn from_anki() {
301302
use burn::data::dataloader::Dataset;
302303

303-
let dataset = FSRSDataset::from(sort_items_by_review_length(simple_weighted_fsrs_items(
304+
let dataset = FSRSDataset::from(sort_items_by_review_length(constant_weighted_fsrs_items(
304305
anki21_sample_file_converted_to_fsrs(),
305306
)));
306307
assert_eq!(

src/inference.rs

Lines changed: 13 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -6,8 +6,7 @@ use burn::nn::loss::Reduction;
66
use burn::tensor::{Data, Shape, Tensor};
77
use burn::{data::dataloader::batcher::Batcher, tensor::backend::Backend};
88

9-
use crate::dataset::FSRSBatcher;
10-
use crate::dataset::{simple_weighted_fsrs_items, FSRSBatch};
9+
use crate::dataset::{recency_weighted_fsrs_items, FSRSBatch, FSRSBatcher};
1110
use crate::error::Result;
1211
use crate::model::Model;
1312
use crate::training::BCELoss;
@@ -210,7 +209,7 @@ impl<B: Backend> FSRS<B> {
210209
if items.is_empty() {
211210
return Err(FSRSError::NotEnoughData);
212211
}
213-
let items = simple_weighted_fsrs_items(items);
212+
let items = recency_weighted_fsrs_items(items);
214213
let batcher = FSRSBatcher::new(self.device());
215214
let mut all_retention = vec![];
216215
let mut all_labels = vec![];
@@ -232,10 +231,10 @@ impl<B: Backend> FSRS<B> {
232231
all_weights.push(batch.weights);
233232
izip!(chunk, pred, true_val).for_each(|(item, p, y)| {
234233
let bin = item.item.r_matrix_index();
235-
let (pred, real, count) = r_matrix.entry(bin).or_insert((0.0, 0.0, 0.0));
234+
let (pred, real, weight) = r_matrix.entry(bin).or_insert((0.0, 0.0, 0.0));
236235
*pred += p;
237236
*real += y;
238-
*count += 1.0;
237+
*weight += item.weight;
239238
});
240239
progress_info.current += chunk.len();
241240
if !progress(progress_info) {
@@ -244,13 +243,13 @@ impl<B: Backend> FSRS<B> {
244243
}
245244
let rmse = (r_matrix
246245
.values()
247-
.map(|(pred, real, count)| {
248-
let pred = pred / count;
249-
let real = real / count;
250-
(pred - real).powi(2) * count
246+
.map(|(pred, real, weight)| {
247+
let pred = pred / weight;
248+
let real = real / weight;
249+
(pred - real).powi(2) * weight
251250
})
252251
.sum::<f32>()
253-
/ r_matrix.values().map(|(_, _, count)| count).sum::<f32>())
252+
/ r_matrix.values().map(|(_, _, weight)| weight).sum::<f32>())
254253
.sqrt();
255254
let all_retention = Tensor::cat(all_retention, 0);
256255
let all_labels = Tensor::cat(all_labels, 0).float();
@@ -282,7 +281,7 @@ impl<B: Backend> FSRS<B> {
282281
if items.is_empty() {
283282
return Err(FSRSError::NotEnoughData);
284283
}
285-
let items = simple_weighted_fsrs_items(items);
284+
let items = recency_weighted_fsrs_items(items);
286285
let batcher = FSRSBatcher::new(self.device());
287286
let mut all_predictions_self = vec![];
288287
let mut all_predictions_other = vec![];
@@ -499,17 +498,17 @@ mod tests {
499498
]))?;
500499
let metrics = fsrs.evaluate(items.clone(), |_| true).unwrap();
501500

502-
assert_approx_eq([metrics.log_loss, metrics.rmse_bins], [0.211007, 0.037216]);
501+
assert_approx_eq([metrics.log_loss, metrics.rmse_bins], [0.212817, 0.034676]);
503502

504503
let fsrs = FSRS::new(Some(&[]))?;
505504
let metrics = fsrs.evaluate(items.clone(), |_| true).unwrap();
506505

507-
assert_approx_eq([metrics.log_loss, metrics.rmse_bins], [0.216286, 0.038692]);
506+
assert_approx_eq([metrics.log_loss, metrics.rmse_bins], [0.217251, 0.036590]);
508507

509508
let fsrs = FSRS::new(Some(PARAMETERS))?;
510509
let metrics = fsrs.evaluate(items.clone(), |_| true).unwrap();
511510

512-
assert_approx_eq([metrics.log_loss, metrics.rmse_bins], [0.203049, 0.027558]);
511+
assert_approx_eq([metrics.log_loss, metrics.rmse_bins], [0.203552, 0.025646]);
513512

514513
let (self_by_other, other_by_self) = fsrs
515514
.universal_metrics(items.clone(), &DEFAULT_PARAMETERS, |_| true)

0 commit comments

Comments
 (0)