@@ -79,12 +79,14 @@ impl FSRSItem {
7979
8080#[ derive( Clone ) ]
8181pub ( crate ) struct FSRSBatcher < B : Backend > {
82- device : B :: Device ,
82+ _backend : core :: marker :: PhantomData < B > ,
8383}
8484
8585impl < B : Backend > FSRSBatcher < B > {
86- pub const fn new ( device : B :: Device ) -> Self {
87- Self { device }
86+ pub const fn new ( ) -> Self {
87+ Self {
88+ _backend : core:: marker:: PhantomData ,
89+ }
8890 }
8991}
9092
@@ -98,7 +100,7 @@ pub(crate) struct FSRSBatch<B: Backend> {
98100}
99101
100102impl < B : Backend > Batcher < B , WeightedFSRSItem , FSRSBatch < B > > for FSRSBatcher < B > {
101- fn batch ( & self , weighted_items : Vec < WeightedFSRSItem > , _device : & B :: Device ) -> FSRSBatch < B > {
103+ fn batch ( & self , weighted_items : Vec < WeightedFSRSItem > , device : & B :: Device ) -> FSRSBatch < B > {
102104 let pad_size = weighted_items
103105 . iter ( )
104106 . map ( |x| x. item . reviews . len ( ) )
@@ -123,7 +125,7 @@ impl<B: Backend> Batcher<B, WeightedFSRSItem, FSRSBatch<B>> for FSRSBatcher<B> {
123125 dims : vec ! [ 1 , pad_size] ,
124126 } ,
125127 ) ,
126- & self . device ,
128+ device,
127129 ) ;
128130 let rating = Tensor :: < B , 2 > :: from_data (
129131 TensorData :: new (
@@ -132,7 +134,7 @@ impl<B: Backend> Batcher<B, WeightedFSRSItem, FSRSBatch<B>> for FSRSBatcher<B> {
132134 dims : vec ! [ 1 , pad_size] ,
133135 } ,
134136 ) ,
135- & self . device ,
137+ device,
136138 ) ;
137139 ( delta_t, rating)
138140 } )
@@ -142,28 +144,24 @@ impl<B: Backend> Batcher<B, WeightedFSRSItem, FSRSBatch<B>> for FSRSBatcher<B> {
142144 . iter ( )
143145 . map ( |weighted_item| {
144146 let current = weighted_item. item . current ( ) ;
145- let delta_t: Tensor < B , 1 > =
146- Tensor :: from_floats ( [ current. delta_t as f32 ] , & self . device ) ;
147+ let delta_t: Tensor < B , 1 > = Tensor :: from_floats ( [ current. delta_t as f32 ] , device) ;
147148 let label = match current. rating {
148149 1 => 0 ,
149150 _ => 1 ,
150151 } ;
151- let label: Tensor < B , 1 , Int > = Tensor :: from_ints ( [ label] , & self . device ) ;
152- let weight: Tensor < B , 1 > =
153- Tensor :: from_floats ( [ weighted_item. weight ] , & self . device ) ;
152+ let label: Tensor < B , 1 , Int > = Tensor :: from_ints ( [ label] , device) ;
153+ let weight: Tensor < B , 1 > = Tensor :: from_floats ( [ weighted_item. weight ] , device) ;
154154 ( delta_t, label, weight)
155155 } )
156156 . multiunzip ( ) ;
157157
158- let t_historys = Tensor :: cat ( time_histories, 0 )
159- . transpose ( )
160- . to_device ( & self . device ) ; // [seq_len, batch_size]
158+ let t_historys = Tensor :: cat ( time_histories, 0 ) . transpose ( ) . to_device ( device) ; // [seq_len, batch_size]
161159 let r_historys = Tensor :: cat ( rating_histories, 0 )
162160 . transpose ( )
163- . to_device ( & self . device ) ; // [seq_len, batch_size]
164- let delta_ts = Tensor :: cat ( delta_ts, 0 ) . to_device ( & self . device ) ;
165- let labels = Tensor :: cat ( labels, 0 ) . to_device ( & self . device ) ;
166- let weights = Tensor :: cat ( weights, 0 ) . to_device ( & self . device ) ;
161+ . to_device ( device) ; // [seq_len, batch_size]
162+ let delta_ts = Tensor :: cat ( delta_ts, 0 ) . to_device ( device) ;
163+ let labels = Tensor :: cat ( labels, 0 ) . to_device ( device) ;
164+ let weights = Tensor :: cat ( weights, 0 ) . to_device ( device) ;
167165
168166 // dbg!(&items[0].t_history);
169167 // dbg!(&t_historys);
@@ -329,7 +327,7 @@ mod tests {
329327 }
330328 ) ;
331329
332- let batcher = FSRSBatcher :: < Backend > :: new ( DEVICE ) ;
330+ let batcher = FSRSBatcher :: < Backend > :: new ( ) ;
333331 use burn:: data:: dataloader:: DataLoaderBuilder ;
334332 let dataloader = DataLoaderBuilder :: new ( batcher)
335333 . batch_size ( 1 )
@@ -347,7 +345,7 @@ mod tests {
347345
348346 #[ test]
349347 fn test_batcher ( ) {
350- let batcher = FSRSBatcher :: < Backend > :: new ( DEVICE ) ;
348+ let batcher = FSRSBatcher :: < Backend > :: new ( ) ;
351349 let items = [
352350 FSRSItem {
353351 reviews : [ ( 4 , 0 ) , ( 3 , 5 ) ]
0 commit comments