Skip to content

Commit 62d187e

Browse files
committed
Auto-merge for v2.0.0
2 parents b5e8ddc + b06352d commit 62d187e

File tree

11 files changed

+328
-137
lines changed

11 files changed

+328
-137
lines changed

.github/workflows/check.sh

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,7 @@
22

33
set -eux -o pipefail
44

5-
cargo fmt --check || (
6-
echo
7-
echo "Please run 'cargo fmt' to format the code."
8-
exit 1
9-
)
5+
cargo fmt --check
106

117
cargo clippy -- -Dwarnings
128

@@ -15,5 +11,7 @@ pushd tests/data/
1511
wget https://github.com/open-spaced-repetition/fsrs-optimizer-burn/files/12394182/collection.anki21.zip
1612
unzip *.zip
1713

14+
RUSTDOCFLAGS="-D warnings" cargo doc --release
15+
1816
cargo install cargo-llvm-cov --locked
1917
SKIP_TRAINING=1 cargo llvm-cov --release

Cargo.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[package]
22
name = "fsrs"
3-
version = "1.4.9"
3+
version = "2.0.0"
44
authors = ["Open Spaced Repetition"]
55
categories = ["algorithms", "science"]
66
edition = "2021"

examples/optimize.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
1818
println!("Default parameters: {:?}", DEFAULT_PARAMETERS);
1919

2020
// Optimize the FSRS model using the created items
21-
let optimized_parameters = fsrs.compute_parameters(fsrs_items, None)?;
21+
let optimized_parameters = fsrs.compute_parameters(fsrs_items, None, false)?;
2222

2323
println!("Optimized parameters: {:?}", optimized_parameters);
2424

src/batch_shuffle.rs

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -105,17 +105,22 @@ mod tests {
105105
backend::{ndarray::NdArrayDevice, NdArray},
106106
tensor::Shape,
107107
};
108+
use itertools::Itertools;
108109

109110
use super::*;
110111
use crate::{
111-
convertor_tests::anki21_sample_file_converted_to_fsrs, dataset::prepare_training_data,
112+
convertor_tests::anki21_sample_file_converted_to_fsrs,
113+
dataset::{constant_weighted_fsrs_items, prepare_training_data},
112114
};
113115

114116
#[test]
115117
fn test_simple_dataloader() {
116-
let train_set = anki21_sample_file_converted_to_fsrs();
118+
let train_set = anki21_sample_file_converted_to_fsrs()
119+
.into_iter()
120+
.sorted_by_cached_key(|item| item.reviews.len())
121+
.collect();
117122
let (_pre_train_set, train_set) = prepare_training_data(train_set);
118-
let dataset = FSRSDataset::from(train_set);
123+
let dataset = FSRSDataset::from(constant_weighted_fsrs_items(train_set));
119124
let batch_size = 512;
120125
let seed = 114514;
121126
let device = NdArrayDevice::Cpu;

src/convertor_tests.rs

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
use crate::convertor_tests::RevlogReviewKind::*;
2-
use crate::dataset::FSRSBatcher;
2+
use crate::dataset::{constant_weighted_fsrs_items, FSRSBatcher};
33
use crate::dataset::{FSRSItem, FSRSReview};
44
// use crate::optimal_retention::{RevlogEntry, RevlogReviewKind};
55
use crate::test_helpers::NdArrayAutodiff;
@@ -106,7 +106,7 @@ fn convert_to_date(timestamp: i64, minute_offset: i32) -> NaiveDate {
106106
fn convert_to_fsrs_items(
107107
mut entries: Vec<RevlogEntry>,
108108
minute_offset: i32,
109-
) -> Option<Vec<FSRSItem>> {
109+
) -> Option<Vec<(i64, FSRSItem)>> {
110110
// entries = filter_out_cram(entries);
111111
// entries = filter_out_manual(entries);
112112
entries = remove_revlog_before_last_first_learn(entries);
@@ -122,7 +122,7 @@ fn convert_to_fsrs_items(
122122
.iter()
123123
.enumerate()
124124
.skip(1)
125-
.map(|(idx, _)| {
125+
.map(|(idx, entry)| {
126126
let reviews = entries
127127
.iter()
128128
.take(idx + 1)
@@ -131,9 +131,9 @@ fn convert_to_fsrs_items(
131131
delta_t: r.last_interval as u32,
132132
})
133133
.collect();
134-
FSRSItem { reviews }
134+
(entry.id, FSRSItem { reviews })
135135
})
136-
.filter(|item| item.current().delta_t > 0)
136+
.filter(|(_, item)| item.current().delta_t > 0)
137137
.collect(),
138138
)
139139
}
@@ -188,8 +188,8 @@ pub fn anki_to_fsrs(revlogs: Vec<RevlogEntry>, minute_offset: i32) -> Vec<FSRSIt
188188
.filter_map(|(_cid, entries)| convert_to_fsrs_items(entries.collect(), minute_offset))
189189
.flatten()
190190
.collect_vec();
191-
revlogs.sort_by_cached_key(|r| r.reviews.len());
192-
revlogs
191+
revlogs.sort_by_cached_key(|(id, _)| *id);
192+
revlogs.into_iter().map(|(_, item)| item).collect()
193193
}
194194

195195
/*
@@ -310,10 +310,11 @@ fn conversion_works() {
310310
);
311311
312312
// convert a subset and check it matches expectations
313-
let mut fsrs_items = single_card_revlog
313+
let fsrs_items = single_card_revlog
314314
.into_iter()
315315
.filter_map(|entries| convert_to_fsrs_items(entries, 4, Tz::Asia__Shanghai))
316316
.flatten()
317+
.map(|(_, item)| item)
317318
.collect_vec();
318319
assert_eq!(
319320
fsrs_items,
@@ -441,9 +442,11 @@ fn conversion_works() {
441442
]
442443
);
443444
445+
let mut weighted_fsrs_items = constant_weighted_fsrs_items(fsrs_items);
446+
444447
let device = NdArrayDevice::Cpu;
445448
let batcher = FSRSBatcher::<NdArrayAutodiff>::new(device);
446-
let res = batcher.batch(vec![fsrs_items.pop().unwrap()]);
449+
let res = batcher.batch(vec![weighted_fsrs_items.pop().unwrap()]);
447450
assert_eq!(res.delta_ts.into_scalar(), 64.0);
448451
assert_eq!(
449452
res.r_historys.squeeze(1).to_data(),
@@ -497,7 +500,8 @@ fn delta_t_is_correct() -> Result<()> {
497500
],
498501
NEXT_DAY_AT,
499502
Tz::Asia__Shanghai
500-
),
503+
)
504+
.map(|items| items.into_iter().map(|(_, item)| item).collect_vec()),
501505
Some(vec![FSRSItem {
502506
reviews: vec![
503507
FSRSReview {
@@ -522,7 +526,8 @@ fn delta_t_is_correct() -> Result<()> {
522526
],
523527
NEXT_DAY_AT,
524528
Tz::Asia__Shanghai
525-
),
529+
)
530+
.map(|items| items.into_iter().map(|(_, item)| item).collect_vec()),
526531
Some(vec![
527532
FSRSItem {
528533
reviews: vec![

src/dataset.rs

Lines changed: 76 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,19 @@ pub struct FSRSItem {
1919
pub reviews: Vec<FSRSReview>,
2020
}
2121

22+
#[derive(Debug, Clone)]
23+
pub(crate) struct WeightedFSRSItem {
24+
pub weight: f32,
25+
pub item: FSRSItem,
26+
}
27+
2228
#[derive(Debug, Clone, Copy, Deserialize, Serialize, PartialEq, Eq)]
2329
pub struct FSRSReview {
2430
/// 1-4
2531
pub rating: u32,
2632
/// The number of days that passed
2733
/// # Warning
28-
/// [`delta_t`] for item first(initial) review must be 0
34+
/// `delta_t` for item first(initial) review must be 0
2935
pub delta_t: u32,
3036
}
3137

@@ -88,22 +94,26 @@ pub(crate) struct FSRSBatch<B: Backend> {
8894
pub r_historys: Tensor<B, 2, Float>,
8995
pub delta_ts: Tensor<B, 1, Float>,
9096
pub labels: Tensor<B, 1, Int>,
97+
pub weights: Tensor<B, 1, Float>,
9198
}
9299

93-
impl<B: Backend> Batcher<FSRSItem, FSRSBatch<B>> for FSRSBatcher<B> {
94-
fn batch(&self, items: Vec<FSRSItem>) -> FSRSBatch<B> {
95-
let pad_size = items
100+
impl<B: Backend> Batcher<WeightedFSRSItem, FSRSBatch<B>> for FSRSBatcher<B> {
101+
fn batch(&self, weighted_items: Vec<WeightedFSRSItem>) -> FSRSBatch<B> {
102+
let pad_size = weighted_items
96103
.iter()
97-
.map(|x| x.reviews.len())
104+
.map(|x| x.item.reviews.len())
98105
.max()
99106
.expect("FSRSItem is empty")
100107
- 1;
101108

102-
let (time_histories, rating_histories) = items
109+
let (time_histories, rating_histories) = weighted_items
103110
.iter()
104-
.map(|item| {
105-
let (mut delta_t, mut rating): (Vec<_>, Vec<_>) =
106-
item.history().map(|r| (r.delta_t, r.rating)).unzip();
111+
.map(|weighted_item| {
112+
let (mut delta_t, mut rating): (Vec<_>, Vec<_>) = weighted_item
113+
.item
114+
.history()
115+
.map(|r| (r.delta_t, r.rating))
116+
.unzip();
107117
delta_t.resize(pad_size, 0);
108118
rating.resize(pad_size, 0);
109119
let delta_t = Tensor::from_data(
@@ -130,19 +140,23 @@ impl<B: Backend> Batcher<FSRSItem, FSRSBatch<B>> for FSRSBatcher<B> {
130140
})
131141
.unzip();
132142

133-
let (delta_ts, labels) = items
143+
let (delta_ts, labels, weights) = weighted_items
134144
.iter()
135-
.map(|item| {
136-
let current = item.current();
137-
let delta_t = Tensor::from_data(Data::from([current.delta_t.elem()]), &self.device);
145+
.map(|weighted_item| {
146+
let current = weighted_item.item.current();
147+
let delta_t: Tensor<B, 1> =
148+
Tensor::from_data(Data::from([current.delta_t.elem()]), &self.device);
138149
let label = match current.rating {
139150
1 => 0.0,
140151
_ => 1.0,
141152
};
142-
let label = Tensor::from_data(Data::from([label.elem()]), &self.device);
143-
(delta_t, label)
153+
let label: Tensor<B, 1, Int> =
154+
Tensor::from_data(Data::from([label.elem()]), &self.device);
155+
let weight: Tensor<B, 1> =
156+
Tensor::from_data(Data::from([weighted_item.weight.elem()]), &self.device);
157+
(delta_t, label, weight)
144158
})
145-
.unzip();
159+
.multiunzip();
146160

147161
let t_historys = Tensor::cat(time_histories, 0)
148162
.transpose()
@@ -152,6 +166,7 @@ impl<B: Backend> Batcher<FSRSItem, FSRSBatch<B>> for FSRSBatcher<B> {
152166
.to_device(&self.device); // [seq_len, batch_size]
153167
let delta_ts = Tensor::cat(delta_ts, 0).to_device(&self.device);
154168
let labels = Tensor::cat(labels, 0).to_device(&self.device);
169+
let weights = Tensor::cat(weights, 0).to_device(&self.device);
155170

156171
// dbg!(&items[0].t_history);
157172
// dbg!(&t_historys);
@@ -161,27 +176,28 @@ impl<B: Backend> Batcher<FSRSItem, FSRSBatch<B>> for FSRSBatcher<B> {
161176
r_historys,
162177
delta_ts,
163178
labels,
179+
weights,
164180
}
165181
}
166182
}
167183

168184
pub(crate) struct FSRSDataset {
169-
pub(crate) items: Vec<FSRSItem>,
185+
pub(crate) items: Vec<WeightedFSRSItem>,
170186
}
171187

172-
impl Dataset<FSRSItem> for FSRSDataset {
188+
impl Dataset<WeightedFSRSItem> for FSRSDataset {
173189
fn len(&self) -> usize {
174190
self.items.len()
175191
}
176192

177-
fn get(&self, index: usize) -> Option<FSRSItem> {
193+
fn get(&self, index: usize) -> Option<WeightedFSRSItem> {
178194
// info!("get {}", index);
179195
self.items.get(index).cloned()
180196
}
181197
}
182198

183-
impl From<Vec<FSRSItem>> for FSRSDataset {
184-
fn from(items: Vec<FSRSItem>) -> Self {
199+
impl From<Vec<WeightedFSRSItem>> for FSRSDataset {
200+
fn from(items: Vec<WeightedFSRSItem>) -> Self {
185201
Self { items }
186202
}
187203
}
@@ -252,6 +268,33 @@ pub fn prepare_training_data(items: Vec<FSRSItem>) -> (Vec<FSRSItem>, Vec<FSRSIt
252268
(pretrainset.clone(), [pretrainset, trainset].concat())
253269
}
254270

271+
pub(crate) fn sort_items_by_review_length(
272+
mut weighted_items: Vec<WeightedFSRSItem>,
273+
) -> Vec<WeightedFSRSItem> {
274+
weighted_items.sort_by_cached_key(|weighted_item| weighted_item.item.reviews.len());
275+
weighted_items
276+
}
277+
278+
pub(crate) fn constant_weighted_fsrs_items(items: Vec<FSRSItem>) -> Vec<WeightedFSRSItem> {
279+
items
280+
.into_iter()
281+
.map(|item| WeightedFSRSItem { weight: 1.0, item })
282+
.collect()
283+
}
284+
285+
/// The input items should be sorted by the review timestamp.
286+
pub(crate) fn recency_weighted_fsrs_items(items: Vec<FSRSItem>) -> Vec<WeightedFSRSItem> {
287+
let length = items.len() as f32;
288+
items
289+
.into_iter()
290+
.enumerate()
291+
.map(|(idx, item)| WeightedFSRSItem {
292+
weight: 0.25 + 0.75 * (idx as f32 / length).powi(3),
293+
item,
294+
})
295+
.collect()
296+
}
297+
255298
#[cfg(test)]
256299
mod tests {
257300
use super::*;
@@ -261,19 +304,21 @@ mod tests {
261304
fn from_anki() {
262305
use burn::data::dataloader::Dataset;
263306

264-
let dataset = FSRSDataset::from(anki21_sample_file_converted_to_fsrs());
307+
let dataset = FSRSDataset::from(sort_items_by_review_length(constant_weighted_fsrs_items(
308+
anki21_sample_file_converted_to_fsrs(),
309+
)));
265310
assert_eq!(
266-
dataset.get(704).unwrap(),
311+
dataset.get(704).unwrap().item,
267312
FSRSItem {
268313
reviews: vec![
269314
FSRSReview {
270-
rating: 3,
271-
delta_t: 0,
315+
rating: 4,
316+
delta_t: 0
272317
},
273318
FSRSReview {
274319
rating: 3,
275-
delta_t: 1,
276-
},
320+
delta_t: 3
321+
}
277322
],
278323
}
279324
);
@@ -435,6 +480,10 @@ mod tests {
435480
],
436481
},
437482
];
483+
let items = items
484+
.into_iter()
485+
.map(|item| WeightedFSRSItem { weight: 1.0, item })
486+
.collect();
438487
let batch = batcher.batch(items);
439488
assert_eq!(
440489
batch.t_historys.to_data(),

0 commit comments

Comments
 (0)