
Commit 16d4de5

Refactor FSRSBatcher initialization to remove device parameter, enhancing code consistency across multiple files.
1 parent bce1023 commit 16d4de5

File tree: 4 files changed (+24, -27 lines)

src/batch_shuffle.rs
src/convertor_tests.rs
src/dataset.rs
src/inference.rs
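In short, the batcher no longer captures a device when it is constructed; every call site passes the device through the `batch` method instead, which burn's `Batcher` trait already supplies per call. A minimal before/after sketch using the `NdArrayAutodiff` backend and the variable names from the test in src/convertor_tests.rs below:

// Before this commit: the device was baked into the batcher up front.
let device = NdArrayDevice::Cpu;
let batcher = FSRSBatcher::<NdArrayAutodiff>::new(device);

// After this commit: construction is device-free; the device travels with each call.
let batcher = FSRSBatcher::<NdArrayAutodiff>::new();
let res = batcher.batch(vec![weighted_fsrs_items.pop().unwrap()], &device);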

src/batch_shuffle.rs

Lines changed: 1 addition & 1 deletion
@@ -17,7 +17,7 @@ pub(crate) struct BatchTensorDataset<B: Backend> {
 impl<B: Backend> BatchTensorDataset<B> {
     /// Creates a new shuffled dataset.
     pub fn new(dataset: FSRSDataset, batch_size: usize, device: B::Device) -> Self {
-        let batcher = FSRSBatcher::<B>::new(device.clone());
+        let batcher = FSRSBatcher::<B>::new();
         let dataset = dataset
             .items
             .chunks(batch_size)

src/convertor_tests.rs

Lines changed: 1 addition & 1 deletion
@@ -304,7 +304,7 @@ fn test_conversion_works() {
     let mut weighted_fsrs_items = constant_weighted_fsrs_items(fsrs_items);
 
     let device = NdArrayDevice::Cpu;
-    let batcher = FSRSBatcher::<NdArrayAutodiff>::new(device);
+    let batcher = FSRSBatcher::<NdArrayAutodiff>::new();
     let res = batcher.batch(vec![weighted_fsrs_items.pop().unwrap()], &device);
     assert_eq!(res.delta_ts.into_scalar(), 64.0);
     assert_eq!(

src/dataset.rs

Lines changed: 18 additions & 20 deletions
@@ -79,12 +79,14 @@ impl FSRSItem {
 
 #[derive(Clone)]
 pub(crate) struct FSRSBatcher<B: Backend> {
-    device: B::Device,
+    _backend: core::marker::PhantomData<B>,
 }
 
 impl<B: Backend> FSRSBatcher<B> {
-    pub const fn new(device: B::Device) -> Self {
-        Self { device }
+    pub const fn new() -> Self {
+        Self {
+            _backend: core::marker::PhantomData,
+        }
     }
 }
 
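Because the struct no longer stores a device, a zero-sized `PhantomData<B>` field keeps the `B: Backend` type parameter in use. A standalone sketch of the pattern (illustrative names, not the crate's actual types):

use core::marker::PhantomData;

// Zero-sized marker: the struct stays generic over `B` without holding any value of it.
struct Marker<B> {
    _backend: PhantomData<B>,
}

impl<B> Marker<B> {
    // `const fn` still compiles because `PhantomData` is a constant expression.
    const fn new() -> Self {
        Self { _backend: PhantomData }
    }
}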

@@ -98,7 +100,7 @@ pub(crate) struct FSRSBatch<B: Backend> {
 }
 
 impl<B: Backend> Batcher<B, WeightedFSRSItem, FSRSBatch<B>> for FSRSBatcher<B> {
-    fn batch(&self, weighted_items: Vec<WeightedFSRSItem>, _device: &B::Device) -> FSRSBatch<B> {
+    fn batch(&self, weighted_items: Vec<WeightedFSRSItem>, device: &B::Device) -> FSRSBatch<B> {
         let pad_size = weighted_items
             .iter()
             .map(|x| x.item.reviews.len())
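Renaming `_device` to `device` is what makes the stored field redundant: the trait already hands a device to every `batch` call, so the batcher can build all tensors on it directly. Simplified shape of burn's `Batcher` trait as implied by the impl above (bounds and supertraits omitted; see burn's documentation for the authoritative definition):

pub trait Batcher<B: Backend, I, O> {
    // Builds one batch `O` from raw items `I`, allocating tensors on `device`.
    fn batch(&self, items: Vec<I>, device: &B::Device) -> O;
}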
@@ -123,7 +125,7 @@ impl<B: Backend> Batcher<B, WeightedFSRSItem, FSRSBatch<B>> for FSRSBatcher<B> {
                             dims: vec![1, pad_size],
                         },
                     ),
-                    &self.device,
+                    device,
                 );
                 let rating = Tensor::<B, 2>::from_data(
                     TensorData::new(
@@ -132,7 +134,7 @@ impl<B: Backend> Batcher<B, WeightedFSRSItem, FSRSBatch<B>> for FSRSBatcher<B> {
                             dims: vec![1, pad_size],
                         },
                     ),
-                    &self.device,
+                    device,
                 );
                 (delta_t, rating)
             })
@@ -142,28 +144,24 @@ impl<B: Backend> Batcher<B, WeightedFSRSItem, FSRSBatch<B>> for FSRSBatcher<B> {
             .iter()
             .map(|weighted_item| {
                 let current = weighted_item.item.current();
-                let delta_t: Tensor<B, 1> =
-                    Tensor::from_floats([current.delta_t as f32], &self.device);
+                let delta_t: Tensor<B, 1> = Tensor::from_floats([current.delta_t as f32], device);
                 let label = match current.rating {
                     1 => 0,
                     _ => 1,
                 };
-                let label: Tensor<B, 1, Int> = Tensor::from_ints([label], &self.device);
-                let weight: Tensor<B, 1> =
-                    Tensor::from_floats([weighted_item.weight], &self.device);
+                let label: Tensor<B, 1, Int> = Tensor::from_ints([label], device);
+                let weight: Tensor<B, 1> = Tensor::from_floats([weighted_item.weight], device);
                 (delta_t, label, weight)
             })
             .multiunzip();
 
-        let t_historys = Tensor::cat(time_histories, 0)
-            .transpose()
-            .to_device(&self.device); // [seq_len, batch_size]
+        let t_historys = Tensor::cat(time_histories, 0).transpose().to_device(device); // [seq_len, batch_size]
         let r_historys = Tensor::cat(rating_histories, 0)
             .transpose()
-            .to_device(&self.device); // [seq_len, batch_size]
-        let delta_ts = Tensor::cat(delta_ts, 0).to_device(&self.device);
-        let labels = Tensor::cat(labels, 0).to_device(&self.device);
-        let weights = Tensor::cat(weights, 0).to_device(&self.device);
+            .to_device(device); // [seq_len, batch_size]
+        let delta_ts = Tensor::cat(delta_ts, 0).to_device(device);
+        let labels = Tensor::cat(labels, 0).to_device(device);
+        let weights = Tensor::cat(weights, 0).to_device(device);
 
         // dbg!(&items[0].t_history);
         // dbg!(&t_historys);
@@ -329,7 +327,7 @@ mod tests {
            }
        );
 
-        let batcher = FSRSBatcher::<Backend>::new(DEVICE);
+        let batcher = FSRSBatcher::<Backend>::new();
        use burn::data::dataloader::DataLoaderBuilder;
        let dataloader = DataLoaderBuilder::new(batcher)
            .batch_size(1)
@@ -347,7 +345,7 @@ mod tests {
 
    #[test]
    fn test_batcher() {
-        let batcher = FSRSBatcher::<Backend>::new(DEVICE);
+        let batcher = FSRSBatcher::<Backend>::new();
        let items = [
            FSRSItem {
                reviews: [(4, 0), (3, 5)]

src/inference.rs

Lines changed: 4 additions & 5 deletions
@@ -389,7 +389,7 @@ impl<B: Backend> FSRS<B> {
         }
         let weighted_items = recency_weighted_fsrs_items(items);
         let device = self.device();
-        let batcher = FSRSBatcher::new(device.clone());
+        let batcher = FSRSBatcher::new();
         let mut all_retrievability = vec![];
         let mut all_labels = vec![];
         let mut all_weights = vec![];
@@ -458,7 +458,7 @@ impl<B: Backend> FSRS<B> {
         }
         let weighted_items = constant_weighted_fsrs_items(items);
         let device = self.device();
-        let batcher = FSRSBatcher::new(device.clone());
+        let batcher = FSRSBatcher::new();
         let mut all_predictions_self = vec![];
         let mut all_predictions_other = vec![];
         let mut all_true_val = vec![];
@@ -521,9 +521,9 @@ fn batch_predict(items: Vec<FSRSItem>, parameters: &[f32]) -> Result<Vec<Predict
     }
     let weighted_items = constant_weighted_fsrs_items(items);
     let device = NdArrayDevice::Cpu;
-    let batcher = FSRSBatcher::new(device);
+    let batcher = FSRSBatcher::new();
 
-    let fsrs = FSRS::<NdArray>::new_with_backend(parameters, device)?;
+    let fsrs = FSRS::new(parameters)?;
     let model = fsrs.model();
     let mut predicted_items = Vec::with_capacity(weighted_items.len());
 
@@ -551,7 +551,6 @@
 /// # Returns
 /// A ModelEvaluation containing log loss and RMSE metrics
 fn evaluate(predicted_items: Vec<PredictedFSRSItem>) -> Result<ModelEvaluation> {
-    use burn::backend::NdArray;
     if predicted_items.is_empty() {
         return Err(FSRSError::NotEnoughData);
     }
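Condensing the batch_predict hunks, the setup after this commit reads roughly as follows. Only the lines visible in this commit are shown; the loop that feeds batches to the model is elided, and the `batcher.batch(..., &device)` comment is an inference from the `Batcher` impl in src/dataset.rs, not part of this diff:

let weighted_items = constant_weighted_fsrs_items(items);
let device = NdArrayDevice::Cpu;
let batcher = FSRSBatcher::new();

let fsrs = FSRS::new(parameters)?;
let model = fsrs.model();
let mut predicted_items = Vec::with_capacity(weighted_items.len());
// ...batches are then built with `batcher.batch(items_chunk, &device)` and run
// through `model`; that loop is unchanged and not shown in this commit...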
