@@ -2373,10 +2373,8 @@ impl KVCache {
23732373
23742374 let k_slice = ffi:: slice ( k_int8, & [ 0 , 0 , 0 , 0 ] , & [ ks[ 0 ] , ks[ 1 ] , live_len, ks[ 3 ] ] ) ;
23752375 let v_slice = ffi:: slice ( v_int8, & [ 0 , 0 , 0 , 0 ] , & [ vs[ 0 ] , vs[ 1 ] , live_len, vs[ 3 ] ] ) ;
2376- let ks_slice =
2377- ffi:: slice ( k_scales, & [ 0 , 0 , 0 , 0 ] , & [ kss[ 0 ] , kss[ 1 ] , live_len, 1 ] ) ;
2378- let vs_slice =
2379- ffi:: slice ( v_scales, & [ 0 , 0 , 0 , 0 ] , & [ vss[ 0 ] , vss[ 1 ] , live_len, 1 ] ) ;
2376+ let ks_slice = ffi:: slice ( k_scales, & [ 0 , 0 , 0 , 0 ] , & [ kss[ 0 ] , kss[ 1 ] , live_len, 1 ] ) ;
2377+ let vs_slice = ffi:: slice ( v_scales, & [ 0 , 0 , 0 , 0 ] , & [ vss[ 0 ] , vss[ 1 ] , live_len, 1 ] ) ;
23802378
23812379 (
23822380 dequantize ( & k_slice, & ks_slice) ,
@@ -3785,8 +3783,7 @@ impl RotatingKVCache {
37853783 return ( new_keys, new_values) ;
37863784 }
37873785
3788- let ( base_k, base_v, current_seq_len) =
3789- self . visible_fp16_prefix_for_concat ( ) ;
3786+ let ( base_k, base_v, current_seq_len) = self . visible_fp16_prefix_for_concat ( ) ;
37903787
37913788 let concat_k = concatenate ( & base_k, & new_keys, 2 ) ;
37923789 let concat_v = concatenate ( & base_v, & new_values, 2 ) ;
@@ -5681,10 +5678,7 @@ mod tests {
56815678 . collect :: < Vec < _ > > ( )
56825679 } ;
56835680 assert_eq ! ( to_f32( & visible_keys) , vec![ 1.0 , 5.0 , 6.0 , 7.0 , 8.0 ] ) ;
5684- assert_eq ! (
5685- to_f32( & visible_values) ,
5686- vec![ 10.0 , 50.0 , 60.0 , 70.0 , 80.0 ]
5687- ) ;
5681+ assert_eq ! ( to_f32( & visible_values) , vec![ 10.0 , 50.0 , 60.0 , 70.0 , 80.0 ] ) ;
56885682 }
56895683
56905684 #[ test]
@@ -6190,10 +6184,8 @@ mod tests {
61906184 }
61916185 let ( q_unrot, _) = unit_token ( 42 ) ;
61926186 let q_ref = rotate_at ( & q_unrot, M ) ;
6193- let ( k_ref, v_ref) = cache_ref. update_and_fetch (
6194- rotate_at ( & unit_token ( 99 ) . 0 , M ) ,
6195- unit_token ( 99 ) . 1 ,
6196- ) ;
6187+ let ( k_ref, v_ref) =
6188+ cache_ref. update_and_fetch ( rotate_at ( & unit_token ( 99 ) . 0 , M ) , unit_token ( 99 ) . 1 ) ;
61976189 let out_ref = mlxcel_core:: causal_attention ( & q_ref, & k_ref, & v_ref, scale, 0.0 , 0 ) ;
61986190 let out_ref_f32 = to_f32 ( & out_ref) ;
61996191
@@ -6216,11 +6208,10 @@ mod tests {
62166208 // position `M` to simulate the pre-fix offset decrement.
62176209 assert_eq ! ( cache_broken. trim_front( N ) , N ) ;
62186210 let q_broken = rotate_at ( & q_unrot, M ) ;
6219- let ( k_broken, v_broken) = cache_broken. update_and_fetch (
6220- rotate_at ( & unit_token ( 99 ) . 0 , M ) ,
6221- unit_token ( 99 ) . 1 ,
6222- ) ;
6223- let out_broken = mlxcel_core:: causal_attention ( & q_broken, & k_broken, & v_broken, scale, 0.0 , 0 ) ;
6211+ let ( k_broken, v_broken) =
6212+ cache_broken. update_and_fetch ( rotate_at ( & unit_token ( 99 ) . 0 , M ) , unit_token ( 99 ) . 1 ) ;
6213+ let out_broken =
6214+ mlxcel_core:: causal_attention ( & q_broken, & k_broken, & v_broken, scale, 0.0 , 0 ) ;
62246215 let out_broken_f32 = to_f32 ( & out_broken) ;
62256216
62266217 let mut sq_err = 0.0_f64 ;
0 commit comments