Skip to content

Commit 10e461c

Browse files
committed
refactor complex type
1 parent e0665e8 commit 10e461c

File tree

1 file changed

+19
-15
lines changed

1 file changed

+19
-15
lines changed

src/inference.rs

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,14 @@ pub fn next_interval(stability: f32, desired_retention: f32) -> f32 {
6969
stability / FACTOR as f32 * (desired_retention.powf(1.0 / DECAY as f32) - 1.0)
7070
}
7171

72+
#[derive(Default)]
73+
struct RMatrixValue {
74+
predicted: f32,
75+
actual: f32,
76+
count: f32,
77+
weight: f32,
78+
}
79+
7280
impl<B: Backend> FSRS<B> {
7381
/// Calculate the current memory state for a given card's history of reviews.
7482
/// In the case of truncated reviews, [starting_state] can be set to the value of
@@ -219,7 +227,7 @@ impl<B: Backend> FSRS<B> {
219227
total: items.len(),
220228
};
221229
let model = self.model();
222-
let mut r_matrix: HashMap<(u32, u32, u32), (f32, f32, f32, f32)> = HashMap::new();
230+
let mut r_matrix: HashMap<(u32, u32, u32), RMatrixValue> = HashMap::new();
223231

224232
for chunk in items.chunks(512) {
225233
let batch = batcher.batch(chunk.to_vec());
@@ -231,12 +239,11 @@ impl<B: Backend> FSRS<B> {
231239
all_weights.push(batch.weights);
232240
izip!(chunk, pred, true_val).for_each(|(item, p, y)| {
233241
let bin = item.item.r_matrix_index();
234-
let (pred, real, count, weight) =
235-
r_matrix.entry(bin).or_insert((0.0, 0.0, 0.0, 0.0));
236-
*pred += p;
237-
*real += y;
238-
*count += 1.0;
239-
*weight += item.weight;
242+
let value = r_matrix.entry(bin).or_default();
243+
value.predicted += p;
244+
value.actual += y;
245+
value.count += 1.0;
246+
value.weight += item.weight;
240247
});
241248
progress_info.current += chunk.len();
242249
if !progress(progress_info) {
@@ -245,16 +252,13 @@ impl<B: Backend> FSRS<B> {
245252
}
246253
let rmse = (r_matrix
247254
.values()
248-
.map(|(pred, real, count, weight)| {
249-
let pred = pred / count;
250-
let real = real / count;
251-
(pred - real).powi(2) * weight
255+
.map(|v| {
256+
let pred = v.predicted / v.count;
257+
let real = v.actual / v.count;
258+
(pred - real).powi(2) * v.weight
252259
})
253260
.sum::<f32>()
254-
/ r_matrix
255-
.values()
256-
.map(|(_, _, _, weight)| weight)
257-
.sum::<f32>())
261+
/ r_matrix.values().map(|v| v.weight).sum::<f32>())
258262
.sqrt();
259263
let all_retention = Tensor::cat(all_retention, 0);
260264
let all_labels = Tensor::cat(all_labels, 0).float();

0 commit comments

Comments
 (0)