Skip to content

Commit 163f031

Browse files
L-M-Sherlock, autofix-ci[bot], and gemini-code-assist[bot]
authored
Refactor/extract standalone functions from structure (#367)
* Refactor/extract standalone functions from structure * Change visibility of several functions in dataset, parameter initialization, and training modules to `pub(crate)` for better encapsulation. * Rename test functions for consistency across modules, prefixing with 'test_' to enhance clarity and maintainability. * bump version * Remove unused import of `izip` in `measure_a_by_b` function to clean up code. * make optimal_retention standalone * make evaluate_with_time_series_splits and compute_parameters standalone * [autofix.ci] apply automated fixes * Refactor FSRS initialization to remove unnecessary `Some` wrapper around parameters in multiple files, enhancing code clarity and consistency. * Refactor FSRS initialization across multiple files to use `FSRS::default()` instead of `FSRS::new(&[])`, improving code consistency and readability. * make benchmark standalone & remove redundant documents * Refactor FSRS initialization in benchmarks and tests to use `FSRS::default()`, ensuring consistency and improving code readability. * [autofix.ci] apply automated fixes * Refactor tests to use a static `DEVICE` for `NdArrayDevice::Cpu`, improving consistency and reducing redundancy in multiple test files. * [autofix.ci] apply automated fixes * Refactor FSRS and Model implementations to utilize device-specific tensor creation, enhancing device management and consistency across the codebase. * Update src/inference.rs Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> * Refactor FSRSBatcher initialization to remove device parameter, enhancing code consistency across multiple files. * Refactor BatchTensorDataset and MemoryStateTensors to eliminate device parameter, improving code consistency and simplifying tensor creation across multiple files. 
* Update src/model.rs Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> * Refactor device parameter handling in FSRS initialization across multiple files for improved consistency and clarity. --------- Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
1 parent 611980a commit 163f031

20 files changed: +656 −725 lines changed

Cargo.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[package]
22
name = "fsrs"
3-
version = "5.2.0"
3+
version = "6.0.0"
44
authors = ["Open Spaced Repetition"]
55
categories = ["algorithms", "science"]
66
edition = "2024"

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ Read [this](https://github.com/open-spaced-repetition/fsrs4anki/wiki/The-Optimal
2020
// Pick whichever percentage is to your liking (see above)
2121
let optimal_retention = 0.75;
2222
// Use default parameters/weights for the scheduler
23-
let fsrs = FSRS::new(Some(&[]))?;
23+
let fsrs = FSRS::default();
2424

2525
// Create a completely new card
2626
let day1_states = fsrs.next_states(None, optimal_retention, 0)?;

benches/benchmark.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ pub(crate) fn next_states(inf: &FSRS) -> NextStates {
6060
}
6161

6262
pub fn criterion_benchmark(c: &mut Criterion) {
63-
let fsrs = FSRS::new(Some(&[
63+
let fsrs = FSRS::new(&[
6464
0.81497127,
6565
1.5411042,
6666
4.007436,
@@ -78,7 +78,7 @@ pub fn criterion_benchmark(c: &mut Criterion) {
7878
1.3384504,
7979
0.22278537,
8080
2.6646678,
81-
]))
81+
])
8282
.unwrap();
8383

8484
c.bench_function("next_states", |b| b.iter(|| black_box(next_states(&fsrs))));

benches/parameters.rs

Lines changed: 5 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,7 @@
11
use criterion::{Criterion, criterion_group, criterion_main};
22
use fsrs::{
3-
// dataset::prepare_training_data, // Will be inlined
4-
// convertor_tests::anki21_sample_file_converted_to_fsrs, // Will be inlined
5-
ComputeParametersInput,
6-
DEFAULT_PARAMETERS,
7-
FSRS,
8-
FSRSItem,
9-
FSRSReview,
3+
ComputeParametersInput, FSRS, FSRSItem, FSRSReview, compute_parameters,
4+
evaluate_with_time_series_splits,
105
};
116
// Add necessary imports for inlined code
127
use chrono::prelude::*;
@@ -212,7 +207,7 @@ fn load_and_prepare_data() -> Vec<FSRSItem> {
212207
fn benchmark_evaluate(c: &mut Criterion) {
213208
let items = load_and_prepare_data();
214209
// Evaluate uses the FSRS instance's existing parameters.
215-
let fsrs = FSRS::new(Some(&DEFAULT_PARAMETERS)).unwrap();
210+
let fsrs = FSRS::default();
216211

217212
let mut group = c.benchmark_group("parameters");
218213
group.sample_size(10); // Reduce sample size if benchmarks are too long
@@ -229,7 +224,6 @@ fn benchmark_evaluate(c: &mut Criterion) {
229224
fn benchmark_evaluate_with_time_series_splits(c: &mut Criterion) {
230225
let items = load_and_prepare_data();
231226
// evaluate_with_time_series_splits computes parameters internally for each split.
232-
let fsrs = FSRS::new(None).unwrap();
233227
let input = ComputeParametersInput {
234228
train_set: items.clone(),
235229
progress: None,
@@ -242,17 +236,14 @@ fn benchmark_evaluate_with_time_series_splits(c: &mut Criterion) {
242236

243237
group.bench_function("evaluate_with_time_series_splits", |b| {
244238
b.iter(|| {
245-
fsrs.evaluate_with_time_series_splits(black_box(input.clone()), |_| true)
246-
.unwrap();
239+
evaluate_with_time_series_splits(black_box(input.clone()), |_| true).unwrap();
247240
})
248241
});
249242
group.finish();
250243
}
251244

252245
fn benchmark_compute_parameters(c: &mut Criterion) {
253246
let items = load_and_prepare_data();
254-
// compute_parameters computes new parameters, so initial FSRS parameters don't matter.
255-
let fsrs = FSRS::new(None).unwrap();
256247
let input = ComputeParametersInput {
257248
train_set: items.clone(), // Using the full prepared dataset
258249
progress: None,
@@ -265,7 +256,7 @@ fn benchmark_compute_parameters(c: &mut Criterion) {
265256

266257
group.bench_function("compute_parameters", |b| {
267258
b.iter(|| {
268-
fsrs.compute_parameters(black_box(input.clone())).unwrap();
259+
compute_parameters(black_box(input.clone())).unwrap();
269260
})
270261
});
271262

benches/simulation.rs

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use criterion::{Criterion, criterion_group, criterion_main};
22
use fsrs::{Card, expected_workload, expected_workload_with_existing_cards};
3-
use fsrs::{DEFAULT_PARAMETERS, SimulationResult, SimulatorConfig, simulate};
3+
use fsrs::{DEFAULT_PARAMETERS, SimulationResult, SimulatorConfig, optimal_retention, simulate};
44
use fsrs::{FSRS, FSRSError};
55
use rayon::iter::IntoParallelIterator;
66
use rayon::iter::ParallelIterator;
@@ -28,9 +28,8 @@ pub(crate) fn parallel_simulate(config: &SimulatorConfig) -> Result<Vec<f32>, FS
2828
.collect()
2929
}
3030

31-
pub(crate) fn optimal_retention(inf: &FSRS, config: &SimulatorConfig) -> f32 {
32-
inf.optimal_retention(config, &[], |_v| true, None, None)
33-
.unwrap()
31+
pub(crate) fn bench_optimal_retention(_inf: &FSRS, config: &SimulatorConfig) -> f32 {
32+
optimal_retention(config, &[], |_v| true, None, None).unwrap()
3433
}
3534

3635
pub(crate) fn run_expected_workload_for_30_retentions() {
@@ -66,7 +65,7 @@ pub(crate) fn run_expected_workload_with_10000_existing_cards() {
6665
}
6766

6867
pub fn criterion_benchmark(c: &mut Criterion) {
69-
let fsrs = FSRS::new(Some(&DEFAULT_PARAMETERS)).unwrap();
68+
let fsrs = FSRS::default();
7069
let config = SimulatorConfig {
7170
deck_size: 36500,
7271
learn_span: 90,
@@ -79,7 +78,7 @@ pub fn criterion_benchmark(c: &mut Criterion) {
7978
b.iter(|| black_box(parallel_simulate(&config)))
8079
});
8180
c.bench_function("optimal_retention", |b| {
82-
b.iter(|| black_box(optimal_retention(&fsrs, &config)))
81+
b.iter(|| black_box(bench_optimal_retention(&fsrs, &config)))
8382
});
8483
c.bench_function("expected_workload_30_retentions", |b| {
8584
b.iter(run_expected_workload_for_30_retentions)

examples/migrate.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ use fsrs::{FSRS, FSRSItem, FSRSReview};
22

33
fn migrate_with_full_history() -> Result<(), Box<dyn std::error::Error>> {
44
// Create a new FSRS model
5-
let fsrs = FSRS::new(Some(&[]))?;
5+
let fsrs = FSRS::default();
66

77
// Simulate a full review history for a card
88
let reviews = vec![
@@ -35,7 +35,7 @@ fn migrate_with_full_history() -> Result<(), Box<dyn std::error::Error>> {
3535

3636
fn migrate_with_partial_history() -> Result<(), Box<dyn std::error::Error>> {
3737
// Create a new FSRS model
38-
let fsrs = FSRS::new(Some(&[]))?;
38+
let fsrs = FSRS::default();
3939

4040
// Set the true retention of the original algorithm
4141
let sm2_retention = 0.9;
@@ -76,7 +76,7 @@ fn migrate_with_partial_history() -> Result<(), Box<dyn std::error::Error>> {
7676

7777
fn migrate_with_latest_state() -> Result<(), Box<dyn std::error::Error>> {
7878
// Create a new FSRS model
79-
let fsrs = FSRS::new(Some(&[]))?;
79+
let fsrs = FSRS::default();
8080

8181
// Set the true retention of the original algorithm
8282
let sm2_retention = 0.9;

examples/optimize.rs

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
use chrono::NaiveDate;
2-
use fsrs::{ComputeParametersInput, DEFAULT_PARAMETERS, FSRS, FSRSItem, FSRSReview};
2+
use fsrs::{ComputeParametersInput, FSRSItem, FSRSReview, compute_parameters};
33

44
fn main() -> Result<(), Box<dyn std::error::Error>> {
55
// Create review histories for cards
@@ -13,12 +13,8 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
1313

1414
println!("Size of FSRSItems: {}", fsrs_items.len());
1515

16-
// Create an FSRS instance with default parameters
17-
let fsrs = FSRS::new(Some(&[]))?;
18-
println!("Default parameters: {:?}", DEFAULT_PARAMETERS);
19-
2016
// Optimize the FSRS model using the created items
21-
let optimized_parameters = fsrs.compute_parameters(ComputeParametersInput {
17+
let optimized_parameters = compute_parameters(ComputeParametersInput {
2218
train_set: fsrs_items,
2319
..Default::default()
2420
})?;

examples/schedule.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ fn schedule_new_card() -> Result<(), Box<dyn std::error::Error>> {
2727
let desired_retention = 0.9;
2828

2929
// Create a new FSRS model
30-
let fsrs = FSRS::new(Some(&[]))?;
30+
let fsrs = FSRS::default();
3131

3232
// Get next states for a new card
3333
let next_states = fsrs.next_states(card.memory_state, desired_retention, 0)?;
@@ -81,7 +81,7 @@ fn schedule_existing_card() -> Result<(), Box<dyn std::error::Error>> {
8181
let desired_retention = 0.9;
8282

8383
// Create a new FSRS model
84-
let fsrs = FSRS::new(Some(&[]))?;
84+
let fsrs = FSRS::default();
8585

8686
// Calculate the elapsed time since the last review
8787
let elapsed_days = (Utc::now() - card.last_review.unwrap()).num_days() as u32;

src/batch_shuffle.rs

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,9 @@ pub(crate) struct BatchTensorDataset<B: Backend> {
1616

1717
impl<B: Backend> BatchTensorDataset<B> {
1818
/// Creates a new shuffled dataset.
19-
pub fn new(dataset: FSRSDataset, batch_size: usize, device: B::Device) -> Self {
20-
let batcher = FSRSBatcher::<B>::new(device.clone());
19+
pub fn new(dataset: FSRSDataset, batch_size: usize) -> Self {
20+
let device = B::Device::default();
21+
let batcher = FSRSBatcher::<B>::new();
2122
let dataset = dataset
2223
.items
2324
.chunks(batch_size)
@@ -102,10 +103,7 @@ impl<B: Backend> ShuffleDataLoader<B> {
102103

103104
#[cfg(test)]
104105
mod tests {
105-
use burn::{
106-
backend::{NdArray, ndarray::NdArrayDevice},
107-
tensor::Shape,
108-
};
106+
use burn::{backend::NdArray, tensor::Shape};
109107
use itertools::Itertools;
110108

111109
use super::*;
@@ -124,10 +122,9 @@ mod tests {
124122
let dataset = FSRSDataset::from(constant_weighted_fsrs_items(train_set));
125123
let batch_size = 512;
126124
let seed = 114514;
127-
let device = NdArrayDevice::Cpu;
128125
type Backend = NdArray<f32>;
129126

130-
let dataset = BatchTensorDataset::<Backend>::new(dataset, batch_size, device);
127+
let dataset = BatchTensorDataset::<Backend>::new(dataset, batch_size);
131128
let dataloader = ShuffleDataLoader::new(dataset, seed);
132129
let mut iterator = dataloader.iter();
133130
// dbg!(&iterator.indices);

0 commit comments

Comments (0)