@@ -19,13 +19,19 @@ pub struct FSRSItem {
1919 pub reviews : Vec < FSRSReview > ,
2020}
2121
22+ #[ derive( Debug , Clone ) ]
23+ pub ( crate ) struct WeightedFSRSItem {
24+ pub weight : f32 ,
25+ pub item : FSRSItem ,
26+ }
27+
2228#[ derive( Debug , Clone , Copy , Deserialize , Serialize , PartialEq , Eq ) ]
2329pub struct FSRSReview {
2430 /// 1-4
2531 pub rating : u32 ,
2632 /// The number of days that passed
2733 /// # Warning
28- /// [ `delta_t`] for item first(initial) review must be 0
34+ /// `delta_t` for item first(initial) review must be 0
2935 pub delta_t : u32 ,
3036}
3137
@@ -88,22 +94,26 @@ pub(crate) struct FSRSBatch<B: Backend> {
8894 pub r_historys : Tensor < B , 2 , Float > ,
8995 pub delta_ts : Tensor < B , 1 , Float > ,
9096 pub labels : Tensor < B , 1 , Int > ,
97+ pub weights : Tensor < B , 1 , Float > ,
9198}
9299
93- impl < B : Backend > Batcher < FSRSItem , FSRSBatch < B > > for FSRSBatcher < B > {
94- fn batch ( & self , items : Vec < FSRSItem > ) -> FSRSBatch < B > {
95- let pad_size = items
100+ impl < B : Backend > Batcher < WeightedFSRSItem , FSRSBatch < B > > for FSRSBatcher < B > {
101+ fn batch ( & self , weighted_items : Vec < WeightedFSRSItem > ) -> FSRSBatch < B > {
102+ let pad_size = weighted_items
96103 . iter ( )
97- . map ( |x| x. reviews . len ( ) )
104+ . map ( |x| x. item . reviews . len ( ) )
98105 . max ( )
99106 . expect ( "FSRSItem is empty" )
100107 - 1 ;
101108
102- let ( time_histories, rating_histories) = items
109+ let ( time_histories, rating_histories) = weighted_items
103110 . iter ( )
104- . map ( |item| {
105- let ( mut delta_t, mut rating) : ( Vec < _ > , Vec < _ > ) =
106- item. history ( ) . map ( |r| ( r. delta_t , r. rating ) ) . unzip ( ) ;
111+ . map ( |weighted_item| {
112+ let ( mut delta_t, mut rating) : ( Vec < _ > , Vec < _ > ) = weighted_item
113+ . item
114+ . history ( )
115+ . map ( |r| ( r. delta_t , r. rating ) )
116+ . unzip ( ) ;
107117 delta_t. resize ( pad_size, 0 ) ;
108118 rating. resize ( pad_size, 0 ) ;
109119 let delta_t = Tensor :: from_data (
@@ -130,19 +140,23 @@ impl<B: Backend> Batcher<FSRSItem, FSRSBatch<B>> for FSRSBatcher<B> {
130140 } )
131141 . unzip ( ) ;
132142
133- let ( delta_ts, labels) = items
143+ let ( delta_ts, labels, weights ) = weighted_items
134144 . iter ( )
135- . map ( |item| {
136- let current = item. current ( ) ;
137- let delta_t = Tensor :: from_data ( Data :: from ( [ current. delta_t . elem ( ) ] ) , & self . device ) ;
145+ . map ( |weighted_item| {
146+ let current = weighted_item. item . current ( ) ;
147+ let delta_t: Tensor < B , 1 > =
148+ Tensor :: from_data ( Data :: from ( [ current. delta_t . elem ( ) ] ) , & self . device ) ;
138149 let label = match current. rating {
139150 1 => 0.0 ,
140151 _ => 1.0 ,
141152 } ;
142- let label = Tensor :: from_data ( Data :: from ( [ label. elem ( ) ] ) , & self . device ) ;
143- ( delta_t, label)
153+ let label: Tensor < B , 1 , Int > =
154+ Tensor :: from_data ( Data :: from ( [ label. elem ( ) ] ) , & self . device ) ;
155+ let weight: Tensor < B , 1 > =
156+ Tensor :: from_data ( Data :: from ( [ weighted_item. weight . elem ( ) ] ) , & self . device ) ;
157+ ( delta_t, label, weight)
144158 } )
145- . unzip ( ) ;
159+ . multiunzip ( ) ;
146160
147161 let t_historys = Tensor :: cat ( time_histories, 0 )
148162 . transpose ( )
@@ -152,6 +166,7 @@ impl<B: Backend> Batcher<FSRSItem, FSRSBatch<B>> for FSRSBatcher<B> {
152166 . to_device ( & self . device ) ; // [seq_len, batch_size]
153167 let delta_ts = Tensor :: cat ( delta_ts, 0 ) . to_device ( & self . device ) ;
154168 let labels = Tensor :: cat ( labels, 0 ) . to_device ( & self . device ) ;
169+ let weights = Tensor :: cat ( weights, 0 ) . to_device ( & self . device ) ;
155170
156171 // dbg!(&items[0].t_history);
157172 // dbg!(&t_historys);
@@ -161,27 +176,28 @@ impl<B: Backend> Batcher<FSRSItem, FSRSBatch<B>> for FSRSBatcher<B> {
161176 r_historys,
162177 delta_ts,
163178 labels,
179+ weights,
164180 }
165181 }
166182}
167183
168184pub ( crate ) struct FSRSDataset {
169- pub ( crate ) items : Vec < FSRSItem > ,
185+ pub ( crate ) items : Vec < WeightedFSRSItem > ,
170186}
171187
172- impl Dataset < FSRSItem > for FSRSDataset {
188+ impl Dataset < WeightedFSRSItem > for FSRSDataset {
173189 fn len ( & self ) -> usize {
174190 self . items . len ( )
175191 }
176192
177- fn get ( & self , index : usize ) -> Option < FSRSItem > {
193+ fn get ( & self , index : usize ) -> Option < WeightedFSRSItem > {
178194 // info!("get {}", index);
179195 self . items . get ( index) . cloned ( )
180196 }
181197}
182198
183- impl From < Vec < FSRSItem > > for FSRSDataset {
184- fn from ( items : Vec < FSRSItem > ) -> Self {
199+ impl From < Vec < WeightedFSRSItem > > for FSRSDataset {
200+ fn from ( items : Vec < WeightedFSRSItem > ) -> Self {
185201 Self { items }
186202 }
187203}
@@ -252,6 +268,33 @@ pub fn prepare_training_data(items: Vec<FSRSItem>) -> (Vec<FSRSItem>, Vec<FSRSIt
252268 ( pretrainset. clone ( ) , [ pretrainset, trainset] . concat ( ) )
253269}
254270
271+ pub ( crate ) fn sort_items_by_review_length (
272+ mut weighted_items : Vec < WeightedFSRSItem > ,
273+ ) -> Vec < WeightedFSRSItem > {
274+ weighted_items. sort_by_cached_key ( |weighted_item| weighted_item. item . reviews . len ( ) ) ;
275+ weighted_items
276+ }
277+
278+ pub ( crate ) fn constant_weighted_fsrs_items ( items : Vec < FSRSItem > ) -> Vec < WeightedFSRSItem > {
279+ items
280+ . into_iter ( )
281+ . map ( |item| WeightedFSRSItem { weight : 1.0 , item } )
282+ . collect ( )
283+ }
284+
285+ /// The input items should be sorted by the review timestamp.
286+ pub ( crate ) fn recency_weighted_fsrs_items ( items : Vec < FSRSItem > ) -> Vec < WeightedFSRSItem > {
287+ let length = items. len ( ) as f32 ;
288+ items
289+ . into_iter ( )
290+ . enumerate ( )
291+ . map ( |( idx, item) | WeightedFSRSItem {
292+ weight : 0.25 + 0.75 * ( idx as f32 / length) . powi ( 3 ) ,
293+ item,
294+ } )
295+ . collect ( )
296+ }
297+
255298#[ cfg( test) ]
256299mod tests {
257300 use super :: * ;
@@ -261,19 +304,21 @@ mod tests {
261304 fn from_anki ( ) {
262305 use burn:: data:: dataloader:: Dataset ;
263306
264- let dataset = FSRSDataset :: from ( anki21_sample_file_converted_to_fsrs ( ) ) ;
307+ let dataset = FSRSDataset :: from ( sort_items_by_review_length ( constant_weighted_fsrs_items (
308+ anki21_sample_file_converted_to_fsrs ( ) ,
309+ ) ) ) ;
265310 assert_eq ! (
266- dataset. get( 704 ) . unwrap( ) ,
311+ dataset. get( 704 ) . unwrap( ) . item ,
267312 FSRSItem {
268313 reviews: vec![
269314 FSRSReview {
270- rating: 3 ,
271- delta_t: 0 ,
315+ rating: 4 ,
316+ delta_t: 0
272317 } ,
273318 FSRSReview {
274319 rating: 3 ,
275- delta_t: 1 ,
276- } ,
320+ delta_t: 3
321+ }
277322 ] ,
278323 }
279324 ) ;
@@ -435,6 +480,10 @@ mod tests {
435480 ] ,
436481 } ,
437482 ] ;
483+ let items = items
484+ . into_iter ( )
485+ . map ( |item| WeightedFSRSItem { weight : 1.0 , item } )
486+ . collect ( ) ;
438487 let batch = batcher. batch ( items) ;
439488 assert_eq ! (
440489 batch. t_historys. to_data( ) ,
0 commit comments