@@ -6,28 +6,18 @@ extern crate intel_mkl_src;
66
77use anyhow:: Result ;
88use candle_core:: { Device , Tensor } ;
9-
9+ // xs: [1024, 64, 1924], c Tensor[dims 128, 64, 8; f32, cuda:0] Conv1dConfig { padding: 0, stride: 4, dilation: 1, groups: 1 }
1010fn main ( ) -> Result < ( ) > {
1111 let device = Device :: new_cuda ( 0 ) ?;
12- let x = Tensor :: randn ( 0f32 , 1.0 , ( 8 * 4096 , 8 * 4096 ) , & device) ?
13- . to_dtype ( candle_core:: DType :: BF16 ) ?;
14- candle_core:: cuda:: set_gemm_reduced_precision_f32 ( false ) ;
15- candle_core:: cuda:: set_gemm_reduced_precision_bf16 ( false ) ;
16- let _x1 = x. matmul ( & x) ?;
17- drop ( _x1) ;
18- let start_time = std:: time:: Instant :: now ( ) ;
19- let _x1 = x. matmul ( & x) ?;
20- device. synchronize ( ) ?;
21- println ! ( "fp32: {:?}" , start_time. elapsed( ) ) ;
22- drop ( _x1) ;
23- candle_core:: cuda:: set_gemm_reduced_precision_f32 ( true ) ;
24- candle_core:: cuda:: set_gemm_reduced_precision_bf16 ( true ) ;
25- let _x1 = x. matmul ( & x) ?;
26- drop ( _x1) ;
27- let start_time = std:: time:: Instant :: now ( ) ;
28- let _x1 = x. matmul ( & x) ?;
29- device. synchronize ( ) ?;
30- println ! ( "tf32: {:?}" , start_time. elapsed( ) ) ;
12+ let x = Tensor :: randn ( 0f32 , 1.0 , ( 1024 , 64 , 1924 ) , & device) ?;
13+ let c = Tensor :: randn ( 0f32 , 1.0 , ( 128 , 64 , 8 ) , & device) ?;
14+ let _x1 = x. conv1d ( & c, 0 , 4 , 1 , 1 ) ?;
3115 drop ( _x1) ;
16+ for _ in 0 ..20 {
17+ let start_time = std:: time:: Instant :: now ( ) ;
18+ let _x1 = x. conv1d ( & c, 0 , 4 , 1 , 1 ) ?;
19+ device. synchronize ( ) ?;
20+ println ! ( "conv1d: {:?}" , start_time. elapsed( ) ) ;
21+ }
3222 Ok ( ( ) )
3323}
0 commit comments