1313import ctypes
1414from contextlib import contextmanager
1515from typing import Optional
16-
1716import numpy as np
17+
1818from mlir import ir
1919from mlir .dialects import transform
2020from mlir .dialects .transform .bufferization import OneShotBufferizeOp
2525 make_nd_memref_descriptor ,
2626 as_ctype ,
2727)
28+
2829from lighthouse .utils .memref import (
2930 to_ctype as memref_to_ctype ,
3031 deallocate_memrefs_on_exit ,
3132)
3233from lighthouse .pipeline .helper import apply_registered_pass , match
33- from lighthouse .workload import Workload , benchmark
34+ from lighthouse .workload import Workload , benchmark , get_bench_wrapper_schedule
3435from lighthouse .schedule .x86 import tile_and_vector_matmul
3536from ff_weight_stationary import generate_ff_payload
3637
@@ -103,8 +104,6 @@ class DistFF(Workload):
103104 where A, B, C are (M,K), (K,N), (N,K) matrices respectively.
104105 """
105106
106- payload_function_name = "payload"
107-
108107 def __init__ (self , args , P : int , R : int ):
109108 self .M = args .sizes [0 ]
110109 self .N = args .sizes [1 ]
@@ -308,8 +307,9 @@ def schedule_modules(
308307 transform .YieldOp ()
309308 func = None
310309
310+ bench_schedule = get_bench_wrapper_schedule (self )
311+
311312 tile_schedule = tile_and_vector_matmul .create ()
312- tile_schedule .body .operations [0 ].verify ()
313313
314314 main_schedule = ir .Module .create ()
315315 main_schedule .operation .attributes ["transform.with_named_sequence" ] = (
@@ -377,7 +377,7 @@ def schedule_modules(
377377 transform .PrintOp (target = mod )
378378 transform .YieldOp ()
379379
380- return [pre_schedule , tile_schedule , main_schedule ]
380+ return [pre_schedule , bench_schedule , tile_schedule , main_schedule ]
381381
382382
383383if __name__ == "__main__" :
@@ -394,7 +394,7 @@ def schedule_modules(
394394 # execute(wload, verbose=args.verbose)
395395 rprint (" Benchmark" .center (60 , "-" ))
396396 times = benchmark (
397- wload , nruns = 10 , nwarmup = 1 , check_correctness = True , verbose = args .verbose
397+ wload , nruns = 100 , nwarmup = 10 , check_correctness = True , verbose = args .verbose
398398 )
399399 # compute statistics
400400 times *= 1e6
0 commit comments