@@ -5,7 +5,6 @@ extern crate intel_mkl_src;
55extern crate accelerate_src;
66
77use candle:: { test_device, test_utils:: to_vec3_round, Device , IndexOp , Result , Tensor } ;
8- use candle_nn:: Activation ;
98
109fn softmax ( device : & Device ) -> Result < ( ) > {
1110 let data = & [ [ [ 3f32 , 1. , 4. ] , [ 1. , 5. , 9. ] ] , [ [ 2. , 1. , 7. ] , [ 8. , 2. , 8. ] ] ] ;
@@ -53,22 +52,6 @@ fn softmax(device: &Device) -> Result<()> {
5352 Ok ( ( ) )
5453}
5554
56- fn inplace_softmax ( device : & Device ) -> Result < ( ) > {
57- let data = & [ [ [ 3f32 , 1. , 4. ] , [ 1. , 5. , 9. ] ] , [ [ 2. , 1. , 7. ] , [ 8. , 2. , 8. ] ] ] ;
58- let mut tensor = Tensor :: new ( data, device) ?. log ( ) ?;
59- candle_nn:: ops:: inplace_softmax_last_dim ( & mut tensor) ?;
60- assert_eq ! (
61- to_vec3_round( & tensor, 4 ) ?,
62- & [
63- // (3, 1, 4) / 8, (1, 5, 9) / 15
64- [ [ 0.375 , 0.125 , 0.5 ] , [ 0.0667 , 0.3333 , 0.6 ] ] ,
65- // (2, 1, 7) / 10, (8, 2, 8) / 18
66- [ [ 0.2 , 0.1 , 0.7 ] , [ 0.4444 , 0.1111 , 0.4444 ] ]
67- ]
68- ) ;
69- Ok ( ( ) )
70- }
71-
7255fn rms_norm ( device : & Device ) -> Result < ( ) > {
7356 let data = & [ [ [ 3f32 , 1. , 4. ] , [ 1. , 5. , 9. ] ] , [ [ 2. , 1. , 7. ] , [ 8. , 2. , 8. ] ] ] ;
7457 let tensor = Tensor :: new ( data, device) ?;
@@ -341,44 +324,12 @@ fn sigmoid(device: &Device) -> Result<()> {
341324 Ok ( ( ) )
342325}
343326
344- fn mul_and_act ( device : & Device ) -> Result < ( ) > {
345- let data = & [ [ [ 3f32 , 1. , 4. ] , [ 1. , 5. , 9. ] ] , [ [ 2. , 1. , 7. ] , [ 8. , 2. , 8. ] ] ] ;
346- let cpu = Tensor :: new ( data, & Device :: Cpu ) ?;
347- let x = Tensor :: new ( data, device) ?;
348-
349- for act in [ Activation :: Gelu , Activation :: Relu , Activation :: Silu ] {
350- let truth = candle_nn:: ops:: mul_and_act ( & cpu, & cpu, act) ?;
351- let test = candle_nn:: ops:: mul_and_act ( & x, & x, act) ?. to_device ( & Device :: Cpu ) ?;
352-
353- let sum_diff = ( truth - test) ?. abs ( ) ?. sum_all ( ) ?. to_vec0 :: < f32 > ( ) ?;
354- if device. is_cpu ( ) {
355- assert_eq ! ( sum_diff, 0. , "act = {act:?}" ) ;
356- } else {
357- assert ! ( sum_diff < 3e-3 , "act = {act:?}" ) ;
358- }
359- }
360-
361- Ok ( ( ) )
362- }
363-
364327test_device ! ( ropei, ropei_cpu, ropei_gpu, ropei_metal) ;
365328test_device ! ( rope, rope_cpu, rope_gpu, rope_metal) ;
366329test_device ! ( rope_thd, rope_thd_cpu, rope_thd_gpu, rope_thd_metal) ;
367330test_device ! ( softmax, softmax_cpu, softmax_gpu, softmax_metal) ;
368- test_device ! (
369- inplace_softmax,
370- inplace_softmax_cpu,
371- inplace_softmax_gpu,
372- inplace_softmax_metal
373- ) ;
374331test_device ! ( rms_norm, rms_norm_cpu, rms_norm_gpu, rms_norm_metal) ;
375332test_device ! ( rms_norml, rms_norml_cpu, rms_norml_gpu, rms_norml_metal) ;
376333test_device ! ( layer_norm, ln_cpu, ln_gpu, ln_metal) ;
377334test_device ! ( layer_norml, lnl_cpu, lnl_gpu, lnl_metal) ;
378335test_device ! ( sigmoid, sigmoid_cpu, sigmoid_gpu, sigmoid_metal) ;
379- test_device ! (
380- mul_and_act,
381- mul_and_act_cpu,
382- mul_and_act_gpu,
383- mul_and_act_metal
384- ) ;
0 commit comments