
Commit 8b74039

Non-decreasing SInc(Hard) in same-day reviews (#376)

* Non-decreasing SInc(Hard) in same-day reviews

  See: https://discord.com/channels/368267295601983490/1347982145418694747/1435352142973108285
  and https://forums.ankiweb.net/t/fsrs-hard-acting-as-a-again-in-learning-relearning-phase/67334

* pass ci

Co-authored-by: Jarrett Ye <jarrett.ye@outlook.com>

Parent: d22b53b

2 files changed: +32 -32 lines

src/model.rs (2 additions, 2 deletions)
```diff
@@ -115,7 +115,7 @@ impl<B: Backend> Model<B> {
         last_s
             * sinc
                 .clone()
-                .mask_where(rating.greater_equal_elem(3), sinc.clamp_min(1.0))
+                .mask_where(rating.greater_equal_elem(2), sinc.clamp_min(1.0))
     }
 
     fn mean_reversion(&self, new_d: Tensor<B, 1>) -> Tensor<B, 1> {
```
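The functional change is the single comparison above: the same-day stability multiplier `sinc` is now clamped to at least 1.0 for ratings of Hard (2) and above, not only Good (3) and above, so a Hard answer in a same-day (learning/relearning) review can no longer shrink stability. A minimal scalar sketch of the logic, assuming the usual 1 = Again through 4 = Easy rating scale (the `same_day_stability` helper is hypothetical, not the crate's tensor code):

```rust
/// Hypothetical scalar version of the masked clamp in the diff above.
/// `last_s`: previous stability; `sinc`: raw same-day multiplier;
/// `rating`: 1 = Again, 2 = Hard, 3 = Good, 4 = Easy (assumed scale).
fn same_day_stability(last_s: f32, sinc: f32, rating: u8) -> f32 {
    // Before: the clamp applied only when rating >= 3, so Hard with
    // sinc < 1.0 decreased stability. After: rating >= 2 includes Hard.
    let effective = if rating >= 2 { sinc.max(1.0) } else { sinc };
    last_s * effective
}

fn main() {
    assert_eq!(same_day_stability(5.0, 0.5, 2), 5.0); // Hard: no longer shrinks
    assert_eq!(same_day_stability(5.0, 0.5, 1), 2.5); // Again: still shrinks
}
```

This matches the linked reports: Hard was effectively acting like Again during the learning/relearning phase because its multiplier could fall below 1.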
```diff
@@ -521,7 +521,7 @@ mod tests {
             .to_data()
             .to_vec::<f32>()
             .unwrap()
-            .assert_approx_eq([1.596818, 2.7470093, 5.0, 8.12961]);
+            .assert_approx_eq([1.596818, 5.0, 5.0, 8.12961]);
     }
 
     #[test]
```
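The updated test expectation follows directly. The four entries look like post-review stabilities for Again, Hard, Good, and Easy with a previous stability of 5.0 (inferred from the unchanged Good entry); back-solving the old Hard value gives a raw multiplier of about 0.549, which the clamp now lifts to 1.0:

```rust
fn main() {
    let last_s = 5.0_f32;
    // Raw Hard multiplier back-solved from the old expected value 2.7470093.
    let sinc_hard = 2.7470093_f32 / last_s; // ~0.54940186
    assert!((last_s * sinc_hard - 2.7470093).abs() < 1e-6); // old behavior
    assert_eq!(last_s * sinc_hard.max(1.0), 5.0); // new behavior: clamped
}
```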

src/training.rs (30 additions, 30 deletions): the expected losses, gradients, and trained parameters in these tests shift slightly because the forward pass now clamps SInc for Hard same-day reviews.
```diff
@@ -621,33 +621,33 @@ mod tests {
             Reduction::Sum,
         );
 
-        assert_eq!(loss.clone().into_scalar().to_f32(), 4.047898);
+        assert_eq!(loss.clone().into_scalar().to_f32(), 4.0466027);
         let gradients = loss.backward();
 
         let w_grad = model.w.grad(&gradients).unwrap();
 
         w_grad.to_data().to_vec::<f32>().unwrap().assert_approx_eq([
             -0.095688485,
             -0.0051607806,
-            -0.00080300873,
+            -0.0012249565,
             0.007462064,
-            0.03677408,
-            -0.084962785,
-            0.059571628,
-            -2.1566951,
-            0.5738574,
-            -2.8749206,
-            0.7123072,
+            0.03650761,
+            -0.082112335,
+            0.0593964,
+            -2.1474836,
+            0.57626534,
+            -2.8751316,
+            0.7154875,
             -0.028993709,
             0.0099172965,
             -0.2189217,
             -0.0017800558,
             -0.089381434,
             0.299141,
-            0.0708902,
-            -0.01219162,
-            -0.25424173,
-            0.27452517,
+            0.068104014,
+            -0.011605468,
+            -0.25398168,
+            0.27700496,
         ]);
 
         let config =
```
```diff
@@ -693,7 +693,7 @@ mod tests {
 
         let penalty =
             model.l2_regularization(init_w.clone(), params_stddev.clone(), 512, 1000, 2.0);
-        assert_eq!(penalty.clone().into_scalar().to_f32(), 0.6771115);
+        assert_eq!(penalty.clone().into_scalar().to_f32(), 0.67711145);
 
         let gradients = penalty.backward();
         let w_grad = model.w.grad(&gradients).unwrap();
```
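For context, `l2_regularization` appears to be called with the initial parameters, their per-parameter spread, a batch size of 512, a total size of 1000, and a weight of 2.0; the expected penalty here changes only in the last decimal place. A generic sketch of such a penalty, with all names and the exact scaling assumed for illustration rather than taken from the crate:

```rust
/// Generic sketch of an L2 penalty pulling weights toward their defaults,
/// scaled per-parameter by spread and by the batch fraction. Everything
/// here (names, scaling) is an assumption, not fsrs-rs's exact formula.
fn l2_penalty(
    w: &[f32],
    init_w: &[f32],
    stddev: &[f32],
    batch_size: usize,
    total_size: usize,
    gamma: f32,
) -> f32 {
    let sum: f32 = w
        .iter()
        .zip(init_w)
        .zip(stddev)
        .map(|((wi, w0), sd)| ((wi - w0) / sd).powi(2))
        .sum();
    gamma * sum * batch_size as f32 / total_size as f32
}

fn main() {
    let w = [0.5_f32, 1.2];
    let w0 = [0.4_f32, 1.0];
    let sd = [0.2_f32, 0.5];
    println!("{}", l2_penalty(&w, &w0, &sd, 512, 1000, 2.0));
}
```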
```diff
@@ -757,7 +757,7 @@ mod tests {
             item.weights,
             Reduction::Sum,
         );
-        assert_eq!(loss.clone().into_scalar().to_f32(), 3.76888);
+        assert_eq!(loss.clone().into_scalar().to_f32(), 3.767796);
         let gradients = loss.backward();
         let w_grad = model.w.grad(&gradients).unwrap();
         w_grad
```
```diff
@@ -768,25 +768,25 @@ mod tests {
             .assert_approx_eq([
                 -0.040530164,
                 -0.0041278866,
-                -0.0006833144,
+                -0.0010157757,
                 0.007239434,
-                0.009416521,
-                -0.12156768,
-                0.039193563,
-                -0.86553144,
-                0.57743585,
-                -2.571437,
-                0.76415884,
+                0.009321215,
+                -0.120117955,
+                0.039143264,
+                -0.8628009,
+                0.5794302,
+                -2.5713828,
+                0.7669307,
                 -0.024242667,
                 0.0,
                 -0.16912507,
                 -0.0017008218,
                 -0.061857328,
                 0.28093633,
-                0.06636292,
-                0.0057900245,
-                -0.19041246,
-                0.6214733,
+                0.064058185,
+                0.0063592787,
+                -0.1903223,
+                0.6257775,
             ]);
         let grads = GradientsParams::from_grads(gradients, &model);
         model = optim.step(lr, model, grads);
```
```diff
@@ -802,9 +802,9 @@ mod tests {
             .to_vec::<f32>()
             .unwrap()
             .assert_approx_eq([
-                0.2882918, 1.3726242, 2.3862023, 8.215636, 6.339949, 0.9131501, 2.940647,
-                0.07696302, 1.7921939, 0.2464219, 0.71595156, 1.5631561, 0.001, 0.34230903,
-                1.7282416, 0.68038, 1.7929853, 0.46259063, 0.1426339, 0.14509763, 0.1,
+                0.2882918, 1.3726242, 2.3861322, 8.215636, 6.339965, 0.9130969, 2.940639,
+                0.07696985, 1.7921946, 0.2464217, 0.71595186, 1.5631561, 0.001, 0.34230903,
+                1.7282416, 0.68038, 1.7929853, 0.46258268, 0.14039303, 0.14509967, 0.1,
             ]);
     }
 
```