Skip to content

Commit f1c1371

Browse files
committed
don't sort by length of reviews at first
1 parent edf60be commit f1c1371

File tree

4 files changed

+38
-20
lines changed

4 files changed

+38
-20
lines changed

src/batch_shuffle.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,7 @@ mod tests {
105105
backend::{ndarray::NdArrayDevice, NdArray},
106106
tensor::Shape,
107107
};
108+
use itertools::Itertools;
108109

109110
use super::*;
110111
use crate::{
@@ -114,7 +115,10 @@ mod tests {
114115

115116
#[test]
116117
fn test_simple_dataloader() {
117-
let train_set = anki21_sample_file_converted_to_fsrs();
118+
let train_set = anki21_sample_file_converted_to_fsrs()
119+
.into_iter()
120+
.sorted_by_cached_key(|item| item.reviews.len())
121+
.collect();
118122
let (_pre_train_set, train_set) = prepare_training_data(train_set);
119123
let dataset = FSRSDataset::from(simple_weighted_fsrs_items(train_set));
120124
let batch_size = 512;

src/convertor_tests.rs

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ fn convert_to_fsrs_items(
9494
mut entries: Vec<RevlogEntry>,
9595
next_day_starts_at: i64,
9696
timezone: Tz,
97-
) -> Option<Vec<FSRSItem>> {
97+
) -> Option<Vec<(i64, FSRSItem)>> {
9898
// entries = filter_out_cram(entries);
9999
// entries = filter_out_manual(entries);
100100
entries = remove_revlog_before_last_first_learn(entries);
@@ -110,7 +110,7 @@ fn convert_to_fsrs_items(
110110
.iter()
111111
.enumerate()
112112
.skip(1)
113-
.map(|(idx, _)| {
113+
.map(|(idx, entry)| {
114114
let reviews = entries
115115
.iter()
116116
.take(idx + 1)
@@ -119,9 +119,9 @@ fn convert_to_fsrs_items(
119119
delta_t: r.last_interval.max(0) as u32,
120120
})
121121
.collect();
122-
FSRSItem { reviews }
122+
(entry.id, FSRSItem { reviews })
123123
})
124-
.filter(|item| item.current().delta_t > 0)
124+
.filter(|(_, item)| item.current().delta_t > 0)
125125
.collect(),
126126
)
127127
}
@@ -137,8 +137,8 @@ pub(crate) fn anki_to_fsrs(revlogs: Vec<RevlogEntry>) -> Vec<FSRSItem> {
137137
})
138138
.flatten()
139139
.collect_vec();
140-
revlogs.sort_by_cached_key(|r| r.reviews.len());
141-
revlogs
140+
revlogs.sort_by_cached_key(|(id, _)| *id);
141+
revlogs.into_iter().map(|(_, item)| item).collect()
142142
}
143143

144144
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
@@ -260,6 +260,7 @@ fn conversion_works() {
260260
.into_iter()
261261
.filter_map(|entries| convert_to_fsrs_items(entries, 4, Tz::Asia__Shanghai))
262262
.flatten()
263+
.map(|(_, item)| item)
263264
.collect_vec();
264265
assert_eq!(
265266
fsrs_items,
@@ -445,7 +446,8 @@ fn delta_t_is_correct() -> Result<()> {
445446
],
446447
NEXT_DAY_AT,
447448
Tz::Asia__Shanghai
448-
),
449+
)
450+
.map(|items| items.into_iter().map(|(_, item)| item).collect_vec()),
449451
Some(vec![FSRSItem {
450452
reviews: vec![
451453
FSRSReview {
@@ -470,7 +472,8 @@ fn delta_t_is_correct() -> Result<()> {
470472
],
471473
NEXT_DAY_AT,
472474
Tz::Asia__Shanghai
473-
),
475+
)
476+
.map(|items| items.into_iter().map(|(_, item)| item).collect_vec()),
474477
Some(vec![
475478
FSRSItem {
476479
reviews: vec![

src/dataset.rs

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -265,6 +265,12 @@ pub fn prepare_training_data(items: Vec<FSRSItem>) -> (Vec<FSRSItem>, Vec<FSRSIt
265265
(pretrainset.clone(), [pretrainset, trainset].concat())
266266
}
267267

268+
pub(crate) fn sort_items_by_review_length(items: Vec<WeightedFSRSItem>) -> Vec<WeightedFSRSItem> {
269+
let mut items = items;
270+
items.sort_by_cached_key(|item| item.item.reviews.len());
271+
items
272+
}
273+
268274
pub(crate) fn simple_weighted_fsrs_items(items: Vec<FSRSItem>) -> Vec<WeightedFSRSItem> {
269275
items
270276
.into_iter()
@@ -294,21 +300,21 @@ mod tests {
294300
fn from_anki() {
295301
use burn::data::dataloader::Dataset;
296302

297-
let dataset = FSRSDataset::from(simple_weighted_fsrs_items(
303+
let dataset = FSRSDataset::from(sort_items_by_review_length(simple_weighted_fsrs_items(
298304
anki21_sample_file_converted_to_fsrs(),
299-
));
305+
)));
300306
assert_eq!(
301307
dataset.get(704).unwrap().item,
302308
FSRSItem {
303309
reviews: vec![
304310
FSRSReview {
305-
rating: 3,
306-
delta_t: 0,
311+
rating: 4,
312+
delta_t: 0
307313
},
308314
FSRSReview {
309315
rating: 3,
310-
delta_t: 1,
311-
},
316+
delta_t: 3
317+
}
312318
],
313319
}
314320
);

src/training.rs

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
use crate::batch_shuffle::{BatchTensorDataset, ShuffleDataLoader};
22
use crate::cosine_annealing::CosineAnnealingLR;
3-
use crate::dataset::{prepare_training_data, recency_weighted_fsrs_items, FSRSDataset, FSRSItem};
3+
use crate::dataset::{
4+
prepare_training_data, recency_weighted_fsrs_items, sort_items_by_review_length, FSRSDataset,
5+
FSRSItem,
6+
};
47
use crate::error::Result;
58
use crate::model::{Model, ModelConfig};
69
use crate::parameter_clipper::parameter_clipper;
@@ -238,7 +241,6 @@ impl<B: Backend> FSRS<B> {
238241
AdamConfig::new().with_epsilon(1e-8),
239242
);
240243
train_set.retain(|item| item.reviews.len() <= config.max_seq_len);
241-
train_set.sort_by_cached_key(|item| item.reviews.len());
242244

243245
if let Some(progress) = &progress {
244246
let progress_state = ProgressState {
@@ -308,7 +310,6 @@ impl<B: Backend> FSRS<B> {
308310
AdamConfig::new().with_epsilon(1e-8),
309311
);
310312
train_set.retain(|item| item.reviews.len() <= config.max_seq_len);
311-
train_set.sort_by_cached_key(|item| item.reviews.len());
312313
let model =
313314
train::<Autodiff<B>>(train_set.clone(), train_set, &config, self.device(), None);
314315
let parameters: Vec<f32> = model.unwrap().w.val().to_data().convert().value;
@@ -328,14 +329,18 @@ fn train<B: AutodiffBackend>(
328329
// Training data
329330
let iterations = (train_set.len() / config.batch_size + 1) * config.num_epochs;
330331
let batch_dataset = BatchTensorDataset::<B>::new(
331-
FSRSDataset::from(recency_weighted_fsrs_items(train_set)),
332+
FSRSDataset::from(sort_items_by_review_length(recency_weighted_fsrs_items(
333+
train_set,
334+
))),
332335
config.batch_size,
333336
device.clone(),
334337
);
335338
let dataloader_train = ShuffleDataLoader::new(batch_dataset, config.seed);
336339

337340
let batch_dataset = BatchTensorDataset::<B::InnerBackend>::new(
338-
FSRSDataset::from(recency_weighted_fsrs_items(test_set.clone())),
341+
FSRSDataset::from(sort_items_by_review_length(recency_weighted_fsrs_items(
342+
test_set.clone(),
343+
))),
339344
config.batch_size,
340345
device,
341346
);

0 commit comments

Comments (0)