Commit 9fda48b

Fix/clamp the minimum of w[19] (#354)

* Fix/clamp the minimum of w[19], same as open-spaced-repetition/fsrs-optimizer#186
* bump version
* only clamp when enable short term
* revert unit test update
* Refactor parameter clipping logic to remove the Option type for enable_short_term, simplifying the function signature and internal logic.
1 parent ad5f69b commit 9fda48b
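The heart of the change is a conditional lower bound on w[19]: when short-term behavior is enabled, the parameter is clamped to at least 0.01 instead of 0.0, while its upper bound stays at 0.8. A minimal standalone sketch of that rule (the `clamp_w19` helper is illustrative only, not part of the crate's API):

```rust
/// Illustrative helper mirroring the new clamping rule for w[19]:
/// the floor is 0.01 when short-term behavior is enabled, 0.0 otherwise.
fn clamp_w19(w19: f32, enable_short_term: bool) -> f32 {
    let floor = if enable_short_term { 0.01 } else { 0.0 };
    w19.clamp(floor, 0.8)
}

fn main() {
    assert_eq!(clamp_w19(0.0, true), 0.01); // pulled up to the new floor
    assert_eq!(clamp_w19(0.0, false), 0.0); // unchanged when short term is off
    assert_eq!(clamp_w19(0.9, true), 0.8); // upper bound unchanged
}
```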

File tree

5 files changed: +33, -11 lines

Cargo.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default.

Cargo.toml

Lines changed: 1 addition & 1 deletion
```diff
@@ -1,6 +1,6 @@
 [package]
 name = "fsrs"
-version = "5.0.0"
+version = "5.0.1"
 authors = ["Open Spaced Repetition"]
 categories = ["algorithms", "science"]
 edition = "2024"
```

src/model.rs

Lines changed: 1 addition & 1 deletion
```diff
@@ -273,7 +273,7 @@ pub(crate) fn parameters_to_model<B: Backend>(parameters: &Parameters) -> Model<
     let mut model = Model::new(config.clone());
     model.w = Param::from_tensor(Tensor::from_floats(
         TensorData::new(
-            clip_parameters(parameters, config.num_relearning_steps),
+            clip_parameters(parameters, config.num_relearning_steps, Default::default()),
             Shape { dims: vec![21] },
         ),
         &B::Device::default(),
```
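Worth noting: `Default::default()` for a `bool` is `false`, so this call site opts out of the new floor and keeps the previous `(0.0, 0.8)` range for w[19]. A one-liner confirming that language guarantee:

```rust
fn main() {
    // bool::default() is false, so Default::default() passed as
    // enable_short_term means the 0.01 floor is off at this call site.
    assert!(!bool::default());
}
```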

src/parameter_clipper.rs

Lines changed: 15 additions & 5 deletions
```diff
@@ -11,17 +11,26 @@ use burn::{
 pub(crate) fn parameter_clipper<B: Backend>(
     parameters: Param<Tensor<B, 1>>,
     num_relearning_steps: usize,
+    enable_short_term: bool,
 ) -> Param<Tensor<B, 1>> {
     let (id, val) = parameters.consume();
-    let clipped = clip_parameters(&val.to_data().to_vec().unwrap(), num_relearning_steps);
+    let clipped = clip_parameters(
+        &val.to_data().to_vec().unwrap(),
+        num_relearning_steps,
+        enable_short_term,
+    );
     Param::initialized(
         id,
         Tensor::from_data(TensorData::new(clipped, val.shape()), &B::Device::default())
             .require_grad(),
     )
 }
 
-pub(crate) fn clip_parameters(parameters: &Parameters, num_relearning_steps: usize) -> Vec<f32> {
+pub(crate) fn clip_parameters(
+    parameters: &Parameters,
+    num_relearning_steps: usize,
+    enable_short_term: bool,
+) -> Vec<f32> {
     let mut parameters = parameters.to_vec();
     // PLS = w11 * D ^ -w12 * [(S + 1) ^ w13 - 1] * e ^ (w14 * (1 - R))
     // PLS * e ^ (num_relearning_steps * w17 * w18) should be <= S
@@ -38,6 +47,7 @@ pub(crate) fn clip_parameters(parameters: &Parameters, num_relearning_steps: usi
     } else {
         2.0
     };
+    let w19_floor = if enable_short_term { 0.01 } else { 0.0 };
     // https://regex101.com/r/21mXNI/1
     let clamps: [(f32, f32); 21] = [
         (S_MIN, INIT_S_MAX),
@@ -59,7 +69,7 @@ pub(crate) fn clip_parameters(parameters: &Parameters, num_relearning_steps: usi
         (1.0, 6.0),
         (0.0, w17_w18_ceiling),
         (0.0, w17_w18_ceiling),
-        (0.0, 0.8),
+        (w19_floor, 0.8),
         (0.1, 0.8),
     ];

@@ -84,7 +94,7 @@ mod tests {
             &device,
         );
 
-        let param = parameter_clipper(Param::from_tensor(tensor), 1);
+        let param = parameter_clipper(Param::from_tensor(tensor), 1, true);
         let values = &param.to_data().to_vec::<f32>().unwrap();
 
         assert_eq!(
@@ -99,7 +109,7 @@
         let device = NdArrayDevice::Cpu;
         let tensor = Tensor::from_floats(DEFAULT_PARAMETERS, &device);
 
-        let param = parameter_clipper(Param::from_tensor(tensor), 2);
+        let param = parameter_clipper(Param::from_tensor(tensor), 2, true);
         let values = &param.to_data().to_vec::<f32>().unwrap();
 
         values[17..=19].assert_approx_eq([0.5425, 0.0912, 0.0658]);
```
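The clipper's overall shape is an element-wise clamp: each of the 21 parameters is paired with a `(min, max)` range and clamped into it. A self-contained sketch of that pattern using three hypothetical ranges (only the last one, standing in for w[19] with short term enabled, uses the new `(0.01, 0.8)` bounds):

```rust
/// Element-wise clamping in the style of clip_parameters: each parameter
/// is clamped into its corresponding (min, max) range.
fn clip(parameters: &[f32], clamps: &[(f32, f32)]) -> Vec<f32> {
    parameters
        .iter()
        .zip(clamps)
        .map(|(&w, &(lo, hi))| w.clamp(lo, hi))
        .collect()
}

fn main() {
    // Hypothetical ranges; only the last mirrors the new w[19] bounds.
    let clamps = [(0.0, 100.0), (0.1, 0.8), (0.01, 0.8)];
    let clipped = clip(&[42.0, 0.05, 0.0], &clamps);
    assert_eq!(clipped, vec![42.0, 0.1, 0.01]);
}
```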

src/training.rs

Lines changed: 15 additions & 3 deletions
```diff
@@ -473,7 +473,11 @@ fn train<B: AutodiffBackend>(
         }
         let grads = GradientsParams::from_grads(gradients, &model);
         model = optim.step(lr, model, grads);
-        model.w = parameter_clipper(model.w, config.model.num_relearning_steps);
+        model.w = parameter_clipper(
+            model.w,
+            config.model.num_relearning_steps,
+            !config.model.freeze_short_term_stability,
+        );
         // info!("epoch: {:?} iteration: {:?} lr: {:?}", epoch, iteration, lr);
         renderer.render_train(TrainingProgress {
             progress,
@@ -653,7 +657,11 @@ mod tests {
         let lr = 0.04;
         let grads = GradientsParams::from_grads(gradients, &model);
         model = optim.step(lr, model, grads);
-        model.w = parameter_clipper(model.w, config.model.num_relearning_steps);
+        model.w = parameter_clipper(
+            model.w,
+            config.model.num_relearning_steps,
+            !config.model.freeze_short_term_stability,
+        );
         model
             .w
             .val()
@@ -783,7 +791,11 @@
         ]);
         let grads = GradientsParams::from_grads(gradients, &model);
         model = optim.step(lr, model, grads);
-        model.w = parameter_clipper(model.w, config.model.num_relearning_steps);
+        model.w = parameter_clipper(
+            model.w,
+            config.model.num_relearning_steps,
+            !config.model.freeze_short_term_stability,
+        );
         model
             .w
             .val()
```
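In the training loop the flag is derived from the model config: the 0.01 floor on w[19] applies exactly when short-term stability is not frozen. A simplified sketch of that mapping (this `ModelConfig` is a stand-in showing only the relevant field, not the crate's real config type):

```rust
// Stand-in config showing only the field relevant to this commit.
struct ModelConfig {
    freeze_short_term_stability: bool,
}

fn enable_short_term(config: &ModelConfig) -> bool {
    // The training loop passes the negation of the freeze flag.
    !config.freeze_short_term_stability
}

fn main() {
    let frozen = ModelConfig { freeze_short_term_stability: true };
    let unfrozen = ModelConfig { freeze_short_term_stability: false };
    assert!(!enable_short_term(&frozen)); // frozen: no floor on w[19]
    assert!(enable_short_term(&unfrozen)); // unfrozen: floor applies
}
```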
