22# SPDX-License-Identifier: Apache-2.0
33
44import argparse
5- import random
65
76import numpy as np
87from neuronpy .core .language import bfloat16
98
9+ from autotune .cache .visualize import plot_metric
1010from autotune .core .benchmark import Benchmark
1111from autotune .core .gemm_config import generate_gemm_configs
1212from autotune .core .job import ProfileJobs
@@ -32,15 +32,15 @@ def add_jobs(all_jobs: ProfileJobs, transposed_lhs: bool = False):
3232 meta_kernel = ("/home/ec2-user/workplace/nki-autotune/autotune/core/gemm.py" , "lhs_rhs_meta_gemm" )
3333
3434 # for M, N, K in [(4096, 4096, 4096), (8192, 8192, 8192), (16384, 16384, 16384), (24576, 24576, 24576)]:
35- for M , N , K in [(128 , 512 , 1279 )]:
35+ for M , N , K in [(3757 , 1647 , 2539 )]:
3636 if transposed_lhs :
3737 lhs_shape = (K , M )
3838 else :
3939 lhs_shape = (M , K )
4040 rhs_shape = (K , N )
4141 # Generate all possible configurations using the new function
4242 configs = generate_gemm_configs (M = M , N = N , K = K )
43- configs = random .sample (configs , 1 )
43+ # configs = random.sample(configs, 1)
4444 for config in configs :
4545 all_jobs .add_job (
4646 kernel = meta_kernel ,
@@ -50,14 +50,14 @@ def add_jobs(all_jobs: ProfileJobs, transposed_lhs: bool = False):
5050 compiler_flags = "--target=trn1 --auto-cast=none --internal-tensorizer-opt-level=nki" ,
5151 postprocessing = postprocessing ,
5252 )
53- # all_jobs.add_job(
54- # kernel=baseline_kernel,
55- # input_tensor_shapes=[lhs_shape, rhs_shape],
56- # data_type=data_type,
57- # kernel_kwargs={},
58- # compiler_flags="--target=trn1 --auto-cast=none --model-type=transformer --tensorizer-options='--print-nki'",
59- # postprocessing=postprocessing,
60- # )
53+ all_jobs .add_job (
54+ kernel = baseline_kernel ,
55+ input_tensor_shapes = [lhs_shape , rhs_shape ],
56+ data_type = data_type ,
57+ kernel_kwargs = {},
58+ compiler_flags = "--target=trn1 --auto-cast=none --model-type=transformer --tensorizer-options='--print-nki'" ,
59+ postprocessing = postprocessing ,
60+ )
6161
6262
6363if __name__ == "__main__" :
@@ -80,11 +80,11 @@ def add_jobs(all_jobs: ProfileJobs, transposed_lhs: bool = False):
8080 tuner = Benchmark (jobs = all_jobs , cache_root_dir = args .cache_dir )
8181 tuner ()
8282
83- # if args.mode == "lhsT_rhs" or args.mode == "both":
84- # kernel_names = ["lhsT_rhs_gemm_np", "lhsT_rhs_meta_gemm"]
85- # plot_metric(args.cache_dir, "min_ms", kernel_names)
86- # plot_metric(args.cache_dir, "mfu_estimated_percent", kernel_names)
87- # if args.mode == "lhs_rhs" or args.mode == "both":
88- # kernel_names = ["lhs_rhs_gemm_np", "lhs_rhs_meta_gemm"]
89- # plot_metric(args.cache_dir, "min_ms", kernel_names)
90- # plot_metric(args.cache_dir, "mfu_estimated_percent", kernel_names)
83+ if args .mode == "lhsT_rhs" or args .mode == "both" :
84+ kernel_names = ["lhsT_rhs_gemm_np" , "lhsT_rhs_meta_gemm" ]
85+ plot_metric (args .cache_dir , "min_ms" , kernel_names )
86+ plot_metric (args .cache_dir , "mfu_estimated_percent" , kernel_names )
87+ if args .mode == "lhs_rhs" or args .mode == "both" :
88+ kernel_names = ["lhs_rhs_gemm_np" , "lhs_rhs_meta_gemm" ]
89+ plot_metric (args .cache_dir , "min_ms" , kernel_names )
90+ plot_metric (args .cache_dir , "mfu_estimated_percent" , kernel_names )