@@ -621,33 +621,33 @@ mod tests {
621621 Reduction :: Sum ,
622622 ) ;
623623
624- assert_eq ! ( loss. clone( ) . into_scalar( ) . to_f32( ) , 4.047898 ) ;
624+ assert_eq ! ( loss. clone( ) . into_scalar( ) . to_f32( ) , 4.0466027 ) ;
625625 let gradients = loss. backward ( ) ;
626626
627627 let w_grad = model. w . grad ( & gradients) . unwrap ( ) ;
628628
629629 w_grad. to_data ( ) . to_vec :: < f32 > ( ) . unwrap ( ) . assert_approx_eq ( [
630630 -0.095688485 ,
631631 -0.0051607806 ,
632- -0.00080300873 ,
632+ -0.0012249565 ,
633633 0.007462064 ,
634- 0.03677408 ,
635- -0.084962785 ,
636- 0.059571628 ,
637- -2.1566951 ,
638- 0.5738574 ,
639- -2.8749206 ,
640- 0.7123072 ,
634+ 0.03650761 ,
635+ -0.082112335 ,
636+ 0.0593964 ,
637+ -2.1474836 ,
638+ 0.57626534 ,
639+ -2.8751316 ,
640+ 0.7154875 ,
641641 -0.028993709 ,
642642 0.0099172965 ,
643643 -0.2189217 ,
644644 -0.0017800558 ,
645645 -0.089381434 ,
646646 0.299141 ,
647- 0.0708902 ,
648- -0.01219162 ,
649- -0.25424173 ,
650- 0.27452517 ,
647+ 0.068104014 ,
648+ -0.011605468 ,
649+ -0.25398168 ,
650+ 0.27700496 ,
651651 ] ) ;
652652
653653 let config =
@@ -693,7 +693,7 @@ mod tests {
693693
694694 let penalty =
695695 model. l2_regularization ( init_w. clone ( ) , params_stddev. clone ( ) , 512 , 1000 , 2.0 ) ;
696- assert_eq ! ( penalty. clone( ) . into_scalar( ) . to_f32( ) , 0.6771115 ) ;
696+ assert_eq ! ( penalty. clone( ) . into_scalar( ) . to_f32( ) , 0.67711145 ) ;
697697
698698 let gradients = penalty. backward ( ) ;
699699 let w_grad = model. w . grad ( & gradients) . unwrap ( ) ;
@@ -757,7 +757,7 @@ mod tests {
757757 item. weights ,
758758 Reduction :: Sum ,
759759 ) ;
760- assert_eq ! ( loss. clone( ) . into_scalar( ) . to_f32( ) , 3.76888 ) ;
760+ assert_eq ! ( loss. clone( ) . into_scalar( ) . to_f32( ) , 3.767796 ) ;
761761 let gradients = loss. backward ( ) ;
762762 let w_grad = model. w . grad ( & gradients) . unwrap ( ) ;
763763 w_grad
@@ -768,25 +768,25 @@ mod tests {
768768 . assert_approx_eq ( [
769769 -0.040530164 ,
770770 -0.0041278866 ,
771- -0.0006833144 ,
771+ -0.0010157757 ,
772772 0.007239434 ,
773- 0.009416521 ,
774- -0.12156768 ,
775- 0.039193563 ,
776- -0.86553144 ,
777- 0.57743585 ,
778- -2.571437 ,
779- 0.76415884 ,
773+ 0.009321215 ,
774+ -0.120117955 ,
775+ 0.039143264 ,
776+ -0.8628009 ,
777+ 0.5794302 ,
778+ -2.5713828 ,
779+ 0.7669307 ,
780780 -0.024242667 ,
781781 0.0 ,
782782 -0.16912507 ,
783783 -0.0017008218 ,
784784 -0.061857328 ,
785785 0.28093633 ,
786- 0.06636292 ,
787- 0.0057900245 ,
788- -0.19041246 ,
789- 0.6214733 ,
786+ 0.064058185 ,
787+ 0.0063592787 ,
788+ -0.1903223 ,
789+ 0.6257775 ,
790790 ] ) ;
791791 let grads = GradientsParams :: from_grads ( gradients, & model) ;
792792 model = optim. step ( lr, model, grads) ;
@@ -802,9 +802,9 @@ mod tests {
802802 . to_vec :: < f32 > ( )
803803 . unwrap ( )
804804 . assert_approx_eq ( [
805- 0.2882918 , 1.3726242 , 2.3862023 , 8.215636 , 6.339949 , 0.9131501 , 2.940647 ,
806- 0.07696302 , 1.7921939 , 0.2464219 , 0.71595156 , 1.5631561 , 0.001 , 0.34230903 ,
807- 1.7282416 , 0.68038 , 1.7929853 , 0.46259063 , 0.1426339 , 0.14509763 , 0.1 ,
805+ 0.2882918 , 1.3726242 , 2.3861322 , 8.215636 , 6.339965 , 0.9130969 , 2.940639 ,
806+ 0.07696985 , 1.7921946 , 0.2464217 , 0.71595186 , 1.5631561 , 0.001 , 0.34230903 ,
807+ 1.7282416 , 0.68038 , 1.7929853 , 0.46258268 , 0.14039303 , 0.14509967 , 0.1 ,
808808 ] ) ;
809809 }
810810
0 commit comments