@@ -671,6 +671,63 @@ fn test_memory_functions() {
671671 set_wired_limit ( 0 ) ;
672672}
673673
#[test]
fn test_scalar_helpers_preserve_bf16_and_f16_dtype() {
    // For each half-precision dtype, the results of `multiply_scalar` and
    // `divide_scalar` must report the same dtype the input was cast to.
    for half in [dtype::BFLOAT16, dtype::FLOAT16] {
        let source = from_slice_f32(&[1.0, 2.0, 3.0, 4.0], &[1, 4]);
        let input = astype(&source, half);

        let product = multiply_scalar(&input, 2.0);
        eval(&product);
        assert_eq!(array_dtype(&product), half);

        let quotient = divide_scalar(&input, 2.0);
        eval(&quotient);
        assert_eq!(array_dtype(&quotient), half);
    }
}
688+
#[test]
fn test_softcap_helper_preserves_bf16_and_f16_dtype() {
    // `crate::utils::softcap` applied to a bf16/f16 input must keep both the
    // [1, 5] shape and the input dtype.
    for half in [dtype::BFLOAT16, dtype::FLOAT16] {
        let source = from_slice_f32(&[0.0, 10.0, -10.0, 50.0, -50.0], &[1, 5]);
        let input = astype(&source, half);

        let capped = crate::utils::softcap(&input, 30.0);
        eval(&capped);

        assert_eq!(array_shape(&capped), vec![1, 5]);
        assert_eq!(array_dtype(&capped), half);
    }
}
703+
#[test]
fn test_attention_masks_intentionally_remain_float32() {
    // Each mask builder is checked in turn: build, evaluate, then assert the
    // reported dtype is FLOAT32.

    // Plain causal mask.
    let causal_mask = crate::utils::create_causal_mask(2, 1);
    eval(&causal_mask);
    assert_eq!(array_dtype(&causal_mask), dtype::FLOAT32);

    // Causal mask with a window argument of `Some(2)`.
    let window_mask = crate::utils::create_causal_mask_with_window(4, 0, Some(2));
    eval(&window_mask);
    assert_eq!(array_dtype(&window_mask), dtype::FLOAT32);

    // Padded prefill mask.
    let prefill_mask = crate::utils::create_padded_prefill_mask(2, 4, 0);
    eval(&prefill_mask);
    assert_eq!(array_dtype(&prefill_mask), dtype::FLOAT32);
}
718+
#[test]
fn test_clip_residual_f16_widens_and_returns_f16() {
    // NOTE(review): the first pair (65500 and 10) sums to 65510, above the
    // largest finite f16 value (65504) — presumably chosen to exercise an
    // overflow path inside `clip_residual_f16`; confirm against the helper.
    // Only the output shape and dtype are asserted here, not the values.
    let lhs_f32 = from_slice_f32(&[65500.0, 1.0], &[1, 2]);
    let rhs_f32 = from_slice_f32(&[10.0, 2.0], &[1, 2]);
    let lhs = astype(&lhs_f32, dtype::FLOAT16);
    let rhs = astype(&rhs_f32, dtype::FLOAT16);

    let clipped = crate::utils::clip_residual_f16(&lhs, &rhs);
    eval(&clipped);

    assert_eq!(array_shape(&clipped), vec![1, 2]);
    assert_eq!(array_dtype(&clipped), dtype::FLOAT16);
}
730+
674731#[ test]
675732fn test_compiled_gelu ( ) {
676733 let x = from_slice_f32 ( & [ 0.0 , 1.0 , -1.0 , 2.0 ] , & [ 1 , 4 ] ) ;
@@ -702,6 +759,42 @@ fn test_compiled_gelu_approx() {
702759 assert ! ( item_f32( & total) > 0.0 ) ;
703760}
704761
#[test]
fn test_compiled_gelu_preserves_bf16_and_f16_dtype() {
    // `compiled_gelu` on a bf16/f16 input must return the same [1, 4] shape
    // and the same dtype as its input.
    for half in [dtype::BFLOAT16, dtype::FLOAT16] {
        let source = from_slice_f32(&[0.0, 1.0, -1.0, 2.0], &[1, 4]);
        let input = astype(&source, half);

        let activated = compiled_gelu(&input);
        eval(&activated);

        assert_eq!(array_shape(&activated), vec![1, 4]);
        assert_eq!(array_dtype(&activated), half);
    }
}
773+
#[test]
fn test_compiled_gelu_approx_preserves_bf16_and_f16_dtype() {
    // `compiled_gelu_approx` on a bf16/f16 input must return the same [1, 4]
    // shape and the same dtype as its input.
    for half in [dtype::BFLOAT16, dtype::FLOAT16] {
        let source = from_slice_f32(&[0.0, 1.0, -1.0, 2.0], &[1, 4]);
        let input = astype(&source, half);

        let activated = compiled_gelu_approx(&input);
        eval(&activated);

        assert_eq!(array_shape(&activated), vec![1, 4]);
        assert_eq!(array_dtype(&activated), half);
    }
}
785+
#[test]
fn test_compiled_gelu_topk_preserves_bf16_and_f16_dtype() {
    // `compiled_gelu_topk` (called with a second argument of 1.0) on a
    // bf16/f16 input must return the same [1, 4] shape and input dtype.
    for half in [dtype::BFLOAT16, dtype::FLOAT16] {
        let source = from_slice_f32(&[-2.0, -1.0, 0.5, 4.0], &[1, 4]);
        let input = astype(&source, half);

        let activated = compiled_gelu_topk(&input, 1.0);
        eval(&activated);

        assert_eq!(array_shape(&activated), vec![1, 4]);
        assert_eq!(array_dtype(&activated), half);
    }
}
797+
705798#[ test]
706799fn test_gelu_approx_bf16_negative_values ( ) {
707800 // Verify gelu_approx does not produce NaN for negative bf16 inputs.
@@ -763,6 +856,47 @@ fn test_compiled_geglu_activation() {
763856 assert ! ( item_f32( & total) > 0.0 ) ;
764857}
765858
#[test]
fn test_compiled_geglu_preserves_bf16_and_f16_dtype() {
    // With both operands cast to the same half-precision dtype,
    // `compiled_geglu_activation` must return a [1, 4] array of that dtype.
    for half in [dtype::BFLOAT16, dtype::FLOAT16] {
        let gate_f32 = from_slice_f32(&[1.0, 2.0, 3.0, 4.0], &[1, 4]);
        let gate_half = astype(&gate_f32, half);
        let value_f32 = from_slice_f32(&[0.5, 1.0, 1.5, 2.0], &[1, 4]);
        let value_half = astype(&value_f32, half);

        let gated = compiled_geglu_activation(&gate_half, &value_half);
        eval(&gated);

        assert_eq!(array_shape(&gated), vec![1, 4]);
        assert_eq!(array_dtype(&gated), half);
    }
}
871+
#[test]
fn test_compiled_geglu_approx_preserves_bf16_and_f16_dtype() {
    // With both operands cast to the same half-precision dtype,
    // `compiled_geglu_approx_activation` must return a [1, 4] array of that
    // dtype.
    for half in [dtype::BFLOAT16, dtype::FLOAT16] {
        let gate_f32 = from_slice_f32(&[1.0, 2.0, 3.0, 4.0], &[1, 4]);
        let gate_half = astype(&gate_f32, half);
        let value_f32 = from_slice_f32(&[0.5, 1.0, 1.5, 2.0], &[1, 4]);
        let value_half = astype(&value_f32, half);

        let gated = compiled_geglu_approx_activation(&gate_half, &value_half);
        eval(&gated);

        assert_eq!(array_shape(&gated), vec![1, 4]);
        assert_eq!(array_dtype(&gated), half);
    }
}
884+
#[test]
fn test_gegelu_preserves_bf16_and_f16_dtype() {
    // `crate::utils::gegelu` is fed a [1, 8] bf16/f16 input and is expected
    // to return a [1, 4] array of the same dtype.
    // NOTE(review): the halved last axis presumably comes from a gate/value
    // split inside `gegelu` — confirm against the helper.
    for half in [dtype::BFLOAT16, dtype::FLOAT16] {
        let source = from_slice_f32(&[-1.0, 0.5, 2.0, 3.0, -0.5, 1.0, 4.0, 5.0], &[1, 8]);
        let input = astype(&source, half);

        let gated = crate::utils::gegelu(&input, 7.0);
        eval(&gated);

        assert_eq!(array_shape(&gated), vec![1, 4]);
        assert_eq!(array_dtype(&gated), half);
    }
}
899+
766900#[ test]
767901fn test_compiled_geglu_matches_manual ( ) {
768902 // compiled_geglu_activation(gate, x) == gelu(gate) * x
@@ -827,6 +961,21 @@ fn test_compiled_softcap_zero_input() {
827961 assert ! ( ( item_f32( & out) ) . abs( ) < 1e-5 , "softcap(0) should be 0" ) ;
828962}
829963
#[test]
fn test_compiled_softcap_preserves_bf16_and_f16_dtype() {
    // `compiled_softcap` on bf16/f16 scores must keep the [1, 5] shape and
    // the input dtype.
    for half in [dtype::BFLOAT16, dtype::FLOAT16] {
        let scores_f32 = from_slice_f32(&[0.0, 10.0, -10.0, 50.0, -50.0], &[1, 5]);
        let scores_half = astype(&scores_f32, half);

        let capped = compiled_softcap(&scores_half, 30.0);
        eval(&capped);

        assert_eq!(array_shape(&capped), vec![1, 5]);
        assert_eq!(array_dtype(&capped), half);
    }
}
978+
830979#[ test]
831980fn test_compiled_clip_residual ( ) {
832981 let x = from_slice_f32 ( & [ 1.0 , 2.0 , 3.0 , 4.0 ] , & [ 1 , 4 ] ) ;
@@ -863,6 +1012,21 @@ fn test_compiled_softcap_sdpa_shape() {
8631012 assert_eq ! ( array_shape( & out) , vec![ 1 , 2 , 4 , 8 ] ) ;
8641013}
8651014
#[test]
fn test_compiled_softcap_sdpa_preserves_v_dtype() {
    // Q, K and V are all-ones [1, 2, 4, 8] arrays cast to the same
    // half-precision dtype; the output must keep that shape and dtype.
    for half in [dtype::BFLOAT16, dtype::FLOAT16] {
        let queries = astype(&ones(&[1, 2, 4, 8], dtype::FLOAT32), half);
        let keys = astype(&ones(&[1, 2, 4, 8], dtype::FLOAT32), half);
        let values = astype(&ones(&[1, 2, 4, 8], dtype::FLOAT32), half);

        // SAFETY: NOTE(review): the trailing pointer argument is null, which
        // presumably means "no mask" — confirm against the contract of
        // `compiled_softcap_sdpa`. The array arguments are live borrows.
        let attended = unsafe {
            compiled_softcap_sdpa(&queries, &keys, &values, 0.125, 30.0, std::ptr::null())
        };
        eval(&attended);

        assert_eq!(array_shape(&attended), vec![1, 2, 4, 8]);
        assert_eq!(array_dtype(&attended), half);
    }
}
1029+
8661030#[ test]
8671031fn test_compiled_softcap_sdpa_gqa_shape ( ) {
8681032 // Verify compiled_softcap_sdpa_gqa: Q has n_heads, K/V have n_kv_heads
@@ -878,6 +1042,22 @@ fn test_compiled_softcap_sdpa_gqa_shape() {
8781042 assert_eq ! ( array_shape( & out) , vec![ 1 , 4 , 2 , 8 ] ) ;
8791043}
8801044
#[test]
fn test_compiled_softcap_sdpa_gqa_preserves_v_dtype() {
    // Q is [1, 4, 2, 8] while K and V are [1, 2, 2, 8], all cast to the same
    // half-precision dtype; the output must have Q's shape and that dtype.
    for half in [dtype::BFLOAT16, dtype::FLOAT16] {
        let queries = astype(&ones(&[1, 4, 2, 8], dtype::FLOAT32), half);
        let keys = astype(&ones(&[1, 2, 2, 8], dtype::FLOAT32), half);
        let values = astype(&ones(&[1, 2, 2, 8], dtype::FLOAT32), half);

        // SAFETY: NOTE(review): the trailing pointer argument is null, which
        // presumably means "no mask" — confirm against the contract of
        // `compiled_softcap_sdpa_gqa`. The array arguments are live borrows.
        let attended = unsafe {
            compiled_softcap_sdpa_gqa(&queries, &keys, &values, 0.125, 30.0, 2, std::ptr::null())
        };
        eval(&attended);

        assert_eq!(array_shape(&attended), vec![1, 4, 2, 8]);
        assert_eq!(array_dtype(&attended), half);
    }
}
1060+
8811061#[ test]
8821062fn test_unified_linear_quantized_weight_accessor ( ) {
8831063 use crate :: layers:: { QuantizedWeight , UnifiedLinear } ;
0 commit comments