From ce5bbe32f6a5f790d8d67fd07d84dfd4b0ca0cb6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=8B=E5=AE=97=E7=A5=A5?= Date: Tue, 17 Feb 2026 10:39:49 +0800 Subject: [PATCH] update tsingmicro code to 3.3.x --- third_party/tsingmicro/CMakeLists.txt | 16 +- third_party/tsingmicro/backend/compiler.py | 54 +- third_party/tsingmicro/backend/driver.py | 352 ++-- .../tsingmicro/backend/include/logger.h | 106 ++ .../tsingmicro/backend/logger_config.py | 135 ++ third_party/tsingmicro/backend/txda_tools.py | 33 +- third_party/tsingmicro/benchmark/benchmark.py | 1668 +++++++++++++++++ third_party/tsingmicro/crt/CMakeLists.txt | 10 +- .../tsingmicro/crt/include/Tx81/tx81.h | 8 +- third_party/tsingmicro/crt/lib/Tx81/abs.c | 1 + third_party/tsingmicro/crt/lib/Tx81/arith.c | 8 + .../crt/lib/Tx81/atomic_barrier_in.c | 1 + .../crt/lib/Tx81/atomic_barrier_out.c | 1 + .../tsingmicro/crt/lib/Tx81/bf16_fp16.c | 1 + .../tsingmicro/crt/lib/Tx81/bf16_int16.c | 1 + .../tsingmicro/crt/lib/Tx81/bf16_int32.c | 1 + .../tsingmicro/crt/lib/Tx81/bf16_int8.c | 1 + .../tsingmicro/crt/lib/Tx81/bf16_tf32.c | 1 + .../tsingmicro/crt/lib/Tx81/bilinear.c | 1 + third_party/tsingmicro/crt/lib/Tx81/bit2fp.c | 1 + third_party/tsingmicro/crt/lib/Tx81/concat.c | 1 + third_party/tsingmicro/crt/lib/Tx81/conv.c | 1 + third_party/tsingmicro/crt/lib/Tx81/cos.c | 1 + third_party/tsingmicro/crt/lib/Tx81/count.c | 1 + third_party/tsingmicro/crt/lib/Tx81/exp.c | 1 + third_party/tsingmicro/crt/lib/Tx81/explp.c | 1 + .../tsingmicro/crt/lib/Tx81/fp16_bf16.c | 1 + .../tsingmicro/crt/lib/Tx81/fp16_int16.c | 1 + .../tsingmicro/crt/lib/Tx81/fp16_int32.c | 1 + .../tsingmicro/crt/lib/Tx81/fp16_int8.c | 2 +- .../tsingmicro/crt/lib/Tx81/fp16_tf32.c | 1 + .../tsingmicro/crt/lib/Tx81/fp32_bf16.c | 1 + .../tsingmicro/crt/lib/Tx81/fp32_int16.c | 1 + .../tsingmicro/crt/lib/Tx81/fp32_int8.c | 1 + .../tsingmicro/crt/lib/Tx81/gelu_none.c | 1 + .../tsingmicro/crt/lib/Tx81/gelu_tanh.c | 1 + third_party/tsingmicro/crt/lib/Tx81/gemm.c | 1 + third_party/tsingmicro/crt/lib/Tx81/img2col.c | 1 + .../tsingmicro/crt/lib/Tx81/int16_bf16.c | 1 + .../tsingmicro/crt/lib/Tx81/int16_fp16.c | 1 + .../tsingmicro/crt/lib/Tx81/int16_fp32.c | 2 +- .../tsingmicro/crt/lib/Tx81/int16_tf32.c | 2 +- .../tsingmicro/crt/lib/Tx81/int32_bf16.c | 2 +- .../tsingmicro/crt/lib/Tx81/int32_fp16.c | 2 +- .../tsingmicro/crt/lib/Tx81/int32_tf32.c | 2 +- .../tsingmicro/crt/lib/Tx81/int8_bf16.c | 1 + .../tsingmicro/crt/lib/Tx81/int8_fp16.c | 1 + .../tsingmicro/crt/lib/Tx81/int8_tf32.c | 1 + .../tsingmicro/crt/lib/Tx81/leakyrelu.c | 2 +- third_party/tsingmicro/crt/lib/Tx81/ln.c | 2 +- third_party/tsingmicro/crt/lib/Tx81/log2.c | 2 +- third_party/tsingmicro/crt/lib/Tx81/logic.c | 4 +- third_party/tsingmicro/crt/lib/Tx81/lut16.c | 2 +- third_party/tsingmicro/crt/lib/Tx81/lut32.c | 2 +- .../tsingmicro/crt/lib/Tx81/mask_move.c | 1 + third_party/tsingmicro/crt/lib/Tx81/mirror.c | 2 +- .../tsingmicro/crt/lib/Tx81/mxfp_bf16.c | 3 + .../tsingmicro/crt/lib/Tx81/mxfp_scale_bf16.c | 1 + .../tsingmicro/crt/lib/Tx81/mxfp_scale_fp16.c | 1 + .../tsingmicro/crt/lib/Tx81/nchw2nhwc.c | 1 + .../tsingmicro/crt/lib/Tx81/nhwc2nchw.c | 1 + third_party/tsingmicro/crt/lib/Tx81/op_gelu.c | 2 + .../crt/lib/Tx81/op_reduce_mul_impl.c | 5 + third_party/tsingmicro/crt/lib/Tx81/pad.c | 1 + third_party/tsingmicro/crt/lib/Tx81/pow2.c | 1 + third_party/tsingmicro/crt/lib/Tx81/print.c | 2 +- third_party/tsingmicro/crt/lib/Tx81/randgen.c | 1 + .../tsingmicro/crt/lib/Tx81/relation.c | 22 + third_party/tsingmicro/crt/lib/Tx81/relu.c | 1 + .../tsingmicro/crt/lib/Tx81/rotate180.c | 1 + .../tsingmicro/crt/lib/Tx81/rotate270.c | 1 + .../tsingmicro/crt/lib/Tx81/rotate90.c | 1 + third_party/tsingmicro/crt/lib/Tx81/rsqrt.c | 1 + third_party/tsingmicro/crt/lib/Tx81/satrelu.c | 1 + third_party/tsingmicro/crt/lib/Tx81/sigmoid.c | 1 + third_party/tsingmicro/crt/lib/Tx81/sin.c | 1 + .../tsingmicro/crt/lib/Tx81/softplus.c | 1 + third_party/tsingmicro/crt/lib/Tx81/sqrt.c | 1 + third_party/tsingmicro/crt/lib/Tx81/tanh.c | 1 + .../tsingmicro/crt/lib/Tx81/tensornorm.c | 1 + .../tsingmicro/crt/lib/Tx81/tf32_bf16.c | 1 + .../tsingmicro/crt/lib/Tx81/tf32_fp16.c | 1 + .../tsingmicro/crt/lib/Tx81/tf32_fp32.c | 1 + .../tsingmicro/crt/lib/Tx81/tf32_int16.c | 1 + .../tsingmicro/crt/lib/Tx81/tf32_int32.c | 1 + .../tsingmicro/crt/lib/Tx81/tf32_int8.c | 1 + .../tsingmicro/crt/lib/Tx81/transpose.c | 1 + .../triton-shared/Analysis/MaskAnalysis.h | 12 + .../AnalysisStructured/PtrAnalysis.h | 17 + .../TritonArithToLinalg/ConversionPatterns.h | 89 +- .../include/utils/ReduceScanCommon.h | 64 +- .../tsingmicro/lib/Analysis/MaskAnalysis.cpp | 48 + .../lib/AnalysisStructured/PtrAnalysis.cpp | 284 ++- .../AllocateSharedMemoryPass.cpp | 4 +- .../LegalizeTensorFormLoops.cpp | 17 +- .../lib/Conversion/LinalgToMK/LinalgToMK.cpp | 185 +- .../Conversion/LinalgToMK/LinalgToMKPass.cpp | 2 +- .../lib/Conversion/MKToTx81/MKToTx81.cpp | 39 +- .../StructuredToMemref/StructuredToMemref.cpp | 944 +++++----- .../TritonToStructuredPass.cpp | 4 +- .../TritonToUnstructuredPass.cpp | 112 +- third_party/tsingmicro/scripts/READMD_DEV.md | 2 +- .../tsingmicro/scripts/base/base_run.sh | 7 +- .../tsingmicro/scripts/build_tsingmicro.sh | 10 +- .../tsingmicro/scripts/build_tx8_deps.sh | 34 +- .../scripts/ci/run_triton_flaggems_ci_test.sh | 324 ++++ .../tsingmicro/scripts/copy_config.conf | 4 +- .../tsingmicro/scripts/publish/README.md | 3 +- .../tsingmicro/scripts/publish/build_wheel.sh | 3 + .../tsingmicro/scripts/publish/publish.sh | 3 +- .../publish/run_flaggems_on_multicards.sh | 94 + 111 files changed, 3988 insertions(+), 828 deletions(-) create mode 100644 third_party/tsingmicro/backend/include/logger.h create mode 100644 third_party/tsingmicro/backend/logger_config.py create mode 100755 third_party/tsingmicro/benchmark/benchmark.py create mode 100755 third_party/tsingmicro/scripts/ci/run_triton_flaggems_ci_test.sh create mode 100755 third_party/tsingmicro/scripts/publish/run_flaggems_on_multicards.sh diff --git a/third_party/tsingmicro/CMakeLists.txt b/third_party/tsingmicro/CMakeLists.txt index cdfb69298e..bea4b5b64f 100644 --- a/third_party/tsingmicro/CMakeLists.txt +++ b/third_party/tsingmicro/CMakeLists.txt @@ -5,6 +5,7 @@ if(NOT DEFINED TX8_DEPS_ROOT) message(FATAL_ERROR "TX8_DEPS_ROOT environment variable is not defined") endif() endif() + # Enable ccache if available find_program(CCACHE_PROGRAM ccache) if(CCACHE_PROGRAM) @@ -18,9 +19,9 @@ if(NOT DEFINED USE_HOST_PROFILE) endif() endif() -if(NOT DEFINED USE_PROFILE) - if(DEFINED ENV{USE_PROFILE}) - set(USE_PROFILE $ENV{USE_PROFILE}) +if(NOT DEFINED ENABLE_PROFILING) + if(DEFINED ENV{ENABLE_PROFILING}) + set(ENABLE_PROFILING $ENV{ENABLE_PROFILING}) endif() endif() @@ -30,10 +31,13 @@ if(NOT DEFINED NO_INTRNISIC_RUN) endif() endif() +if(NOT DEFINED ENABLE_SYNCHRONOUS_INTRINSIC) + if(DEFINED ENV{ENABLE_SYNCHRONOUS_INTRINSIC}) + set(ENABLE_SYNCHRONOUS_INTRINSIC $ENV{ENABLE_SYNCHRONOUS_INTRINSIC}) + endif() +endif() + set(CMAKE_EXPORT_COMPILE_COMMANDS ON) -execute_process(COMMAND llvm-config --libs OUTPUT_VARIABLE LIBS) -execute_process(COMMAND llvm-config --system-libs OUTPUT_VARIABLE SYS_LIBS) -execute_process(COMMAND llvm-config --ldflags OUTPUT_VARIABLE LDF) set(XUANTIE_NAME Xuantie-900-gcc-elf-newlib-x86_64-V2.10.2) set(INSTALL_TSINGMICRO_DIR ${CMAKE_INSTALL_PREFIX}/triton/backends/tsingmicro/) diff --git a/third_party/tsingmicro/backend/compiler.py b/third_party/tsingmicro/backend/compiler.py index 25a1ad8120..b6b4aa22bc 100644 --- a/third_party/tsingmicro/backend/compiler.py +++ b/third_party/tsingmicro/backend/compiler.py @@ -13,6 +13,9 @@ import functools from pathlib import Path from triton.backends.tsingmicro import txda_tools +from triton.backends.tsingmicro.logger_config import setup_logger + +logger = setup_logger("tsingmicro_launch") ir_index = -1 @@ -80,21 +83,20 @@ def compile_accelerator(src, metadata, o_path): dst_path ] - if txda_tools.is_use_profile(): - gcc_args.append("-lprofiler_riscv") + # if txda_tools.is_use_profile(): + # gcc_args.append("-lprofiler_riscv") txda_tools.runLoweringCmd(dst_path, gcc_args) with open(dst_path, 'rb') as f: cache_path = cache.put(f.read(), f"{name}.so", binary=True) txda_tools.dump_file_if_needed(cache_path, f"kernel_{ir_index}.so") - else: - print("cache_path: ", cache_path, flush=True) with open(cache_path, 'rb') as fd_out: so = fd_out.read() metadata["kernel_path"] = cache_path - txda_tools.record_key_v(last_ir, cache_path) + metadata["so_key"] = os.path.basename(os.path.dirname(cache_path)) + logger.debug(f"{last_ir}:{cache_path}") return so @@ -103,7 +105,7 @@ def _ttir_to_coreir(mod): ir_index = ir_index + 1 # Get Triton-MLIR as string ttir_code = str(mod) - txda_tools.record_log(f"get ttir:{txda_tools.calculate_str_md5(ttir_code.encode())}\n") + logger.debug(f"get ttir:{txda_tools.calculate_str_md5(ttir_code.encode())}") with tempfile.TemporaryDirectory() as tmpdir: src_path = os.path.join(tmpdir, f"tt_{ir_index}.mlir") dst_path = os.path.join(tmpdir, f"core_{ir_index}.mlir") @@ -247,19 +249,19 @@ def _txir_to_llir(mod, metadata): dest_name = ir_file break else: - print(f"error name {ir_file}: use customized ir like this : test_0.mlir, test_1.mlir") + logger.error(f"error name {ir_file}: use customized ir like this : test_0.mlir, test_1.mlir") if dest_name: - print(f"get CUSTOMIZED_IR path:{dest_name}") + logger.info(f"get CUSTOMIZED_IR path:{dest_name}") dump_path = os.getenv("TRITON_DUMP_PATH", "") if not dump_path: - print("TRITON_DUMP_PATH not find!") + logger.error("TRITON_DUMP_PATH not find!") return cust_ir = os.path.join(dump_path, dest_name) if os.path.exists(cust_ir): llvmir_path = cust_ir - print(f"!!!!!!!!!!!!!!!!!!using customized ir:{llvmir_path}") + logger.info(f"!!!!!!!!!!!!!!!!!!using customized ir:{llvmir_path}") # Get spm memory use metadata from mlir.ir import Context, Module @@ -276,22 +278,24 @@ def _txir_to_llir(mod, metadata): txda_tools.dump_ir_if_needed([llir_path]) dest_ir = llir_path - if txda_tools.is_use_profile(): - profile_ir_path = os.path.join(tmpdir, f"profile_ll_{ir_index}.ir") - trace_points = os.getenv("TRACE_POINTS", "") - profiler_path = txda_tools.get_tx8_profiler_path() - compile_args = [ - profiler_path, llir_path, f"--trace-points={trace_points}", "-index", f"{ir_index}", "-o", - profile_ir_path - ] - txda_tools.dump_cmd_if_needed(compile_args, "trace-points") - txda_tools.runLoweringCmd(profile_ir_path, compile_args) - txda_tools.dump_ir_if_needed([profile_ir_path]) - dest_ir = profile_ir_path + # if txda_tools.is_use_profile(): + # profile_ir_path = os.path.join(tmpdir, f"profile_ll_{ir_index}.ir") + # trace_points = os.getenv("TRACE_POINTS", "") + # profiler_path = txda_tools.get_tx8_profiler_path() + # compile_args = [ + # profiler_path, llir_path, + # f"--trace-points={trace_points}", + # "-index", f"{ir_index}", + # "-o", profile_ir_path + # ] + # txda_tools.dump_cmd_if_needed(compile_args, "trace-points") + # txda_tools.runLoweringCmd(profile_ir_path, compile_args) + # txda_tools.dump_ir_if_needed([profile_ir_path]) + # dest_ir = profile_ir_path global last_ir last_ir = os.path.basename(dest_ir) - txda_tools.record_log(f"last ir:{dest_ir}, {txda_tools.calculate_file_md5(dest_ir)}\n") + logger.debug(f"last ir:{dest_ir}, {txda_tools.calculate_file_md5(dest_ir)}") return Path(dest_ir).read_text() @@ -365,8 +369,8 @@ def _llir_to_bin(llir: str, metadata): if not sim_mode: compile_args.extend(["--target=riscv64-unknown-elf", "-march=rv64imfdc"]) - if txda_tools.is_use_profile(): - compile_args.append("-DUSE_PROFILE") + # if txda_tools.is_use_profile(): + # compile_args.append("-DENABLE_PROFILING") txda_tools.runLoweringCmd(dst_path, compile_args) diff --git a/third_party/tsingmicro/backend/driver.py b/third_party/tsingmicro/backend/driver.py index 590672310b..dc29f20eee 100644 --- a/third_party/tsingmicro/backend/driver.py +++ b/third_party/tsingmicro/backend/driver.py @@ -18,8 +18,12 @@ from triton.backends.compiler import GPUTarget from triton.backends.tsingmicro import txda_tools +from triton.backends.tsingmicro.logger_config import setup_logger, logger_to_custom_level_number, log_at_current_level + +logger = setup_logger("tsingmicro_launch") dirname = os.path.dirname(os.path.realpath(__file__)) +profiling_lib_dir = os.path.join(txda_tools.get_tx8_deps_path("profiling_tool"), "lib") if (os.getenv("USE_SIM_MODE", "0").lower() in ("1", "true", "yes")): scheme = sysconfig.get_default_scheme() py_include_dir = sysconfig.get_paths(scheme=scheme)["include"] @@ -32,6 +36,7 @@ os.path.join(dirname, "include"), txda_tools.get_kuiper_path("include"), txda_tools.get_tx8_deps_path("include"), + os.path.join(txda_tools.get_tx8_deps_path("profiling_tool"), "include"), os.path.join(sysconfig.get_path('platlib'), "pybind11", "include"), os.path.join(sysconfig.get_path('platlib'), "torch", "include"), os.path.join(sysconfig.get_path('platlib'), "torch", "include", "torch", "csrc", "api", "include"), @@ -40,7 +45,7 @@ library_dirs = [ os.path.join(dirname, "lib"), txda_tools.get_kuiper_path("lib"), - txda_tools.get_tx8_deps_path("lib"), + txda_tools.get_tx8_deps_path("lib"), profiling_lib_dir, os.path.join(sysconfig.get_path('platlib'), "torch", "lib") ] libraries = ['hpgr', 'torch', 'torch_cpu', 'torch_python', 'c10'] @@ -71,13 +76,15 @@ def _build(name, src, srcdir, library_dirs, include_dirs, libraries): include_dirs = include_dirs + [srcdir, py_include_dir, *custom_backend_dirs] # for -Wno-psabi, see https://gcc.gnu.org/bugzilla/show_bug.cgi?id=111047 cc_cmd = [cc, src, "-O3", "-shared", "-fPIC", "-std=c++17", "-Wno-psabi", "-o", so] - if txda_tools.is_use_host_profile(): - cc_cmd += ["-DUSE_PROFILE"] + if txda_tools.is_use_profile(): + cc_cmd += ["-DENABLE_PROFILING"] if txda_tools.is_debug(): cc_cmd += ["-DCMAKE_BUILD_TYPE=Debug"] cc_cmd += [f'-l{lib}' for lib in libraries] - if txda_tools.is_use_host_profile(): - cc_cmd += ["-lprofiler_x86"] + if txda_tools.is_use_profile(): + profiling_flag = "tx8_profiling" + cc_cmd += [f"-l{profiling_flag}"] + cc_cmd += [f"-Wl,-rpath={profiling_lib_dir}"] cc_cmd += [f"-L{dir}" for dir in library_dirs] cc_cmd += [f"-I{dir}" for dir in include_dirs if dir is not None] txda_tools.runLoweringCmd(so, cc_cmd) @@ -91,6 +98,7 @@ def compile_native(src, name): key = hashlib.sha256(src.encode("utf-8")).hexdigest() cache = get_cache_manager(key) cache_path = cache.get_file(f"{fname}.so") + if cache_path is None: with tempfile.TemporaryDirectory() as tmpdir: src_path = os.path.join(tmpdir, f"{name}.cpp") @@ -102,8 +110,6 @@ def compile_native(src, name): with open(so, "rb") as f: cache_path = cache.put(f.read(), f"{fname}.so", binary=True) txda_tools.dump_ir_if_needed([cache_path]) - else: - print("cache_path: ", cache_path, flush=True) spec = importlib.util.spec_from_file_location(name, cache_path) mod = importlib.util.module_from_spec(spec) @@ -172,7 +178,7 @@ def make_launcher(constants, signature, kernel_name, kernel_path): # Basic declarations. Arguments in triton kernel. arg_decls = ', '.join(f"{_ty_to_cpp(ty)} arg{i}" for i, ty in signature.items() if ty != "constexpr") args_format = ''.join([_format_of(ty) for ty in signature.values()]) - format = "issis" + "iiiOKOOOO" + args_format + format = "isssisi" + "iiiOKOOOO" + args_format args_list = ', ' + ', '.join(f"&_arg{i}" for i, ty in signature.items()) if len(signature) > 0 else '' # Parameters to pass to the kernel function. Arguments in triton kernel except constants. @@ -401,13 +407,11 @@ def make_launcher(constants, signature, kernel_name, kernel_path): return f""" #include #include +#define PY_SSIZE_T_CLEAN #include +#include #include -#include #include -#include -#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION -//#include #include #include #include @@ -417,8 +421,16 @@ def make_launcher(constants, signature, kernel_name, kernel_path): #include #include +#include "logger.h" + #include "tx_runtime.h" -#include "profiler.h" + +#ifdef ENABLE_PROFILING + #include "tsm_profiler.h" + #define PROFILE_CALL(func, ...) func(__VA_ARGS__) +#else + #define PROFILE_CALL(func, ...) +#endif enum DATA_TYPE {{ SCALAR, @@ -444,18 +456,86 @@ def make_launcher(constants, signature, kernel_name, kernel_path): data.scalar = v; data_type = SCALAR; }} +}}; + +//============================== launch data begin ============================== +struct LaunchRes {{ + int res; + std::string log_buffer; + + LaunchRes() : res(0), log_buffer() {{}} + LaunchRes(int v, const char* n) : res(v), log_buffer(n ? n : "") {{}} +}}; + +typedef struct {{ + PyObject_HEAD + LaunchRes data; +}} LaunchResObj; + +// Deallocator +static void LaunchRes_dealloc(LaunchResObj* self) {{ + self->data.~LaunchRes(); + Py_TYPE(self)->tp_free((PyObject*)self); +}} +static PyObject* LaunchRes_repr(LaunchResObj* self) {{ + return PyUnicode_FromFormat("LaunchRes(res=%d, log_buffer='%s')", + self->data.res, + self->data.log_buffer.c_str()); +}} + +static PyMemberDef LaunchRes_members[] = {{ + {{"res", T_INT, offsetof(LaunchResObj, data.res), 0, "integer res"}}, + {{NULL}} }}; -// 释放的时候打印profiler数据 +static PyObject* LaunchRes_get_name(LaunchResObj* self, void* closure) {{ + return PyUnicode_FromString(self->data.log_buffer.c_str()); +}} + +static PyGetSetDef LaunchRes_getsetters[] = {{ + {{"log_buffer", (getter)LaunchRes_get_name, NULL, "log_buffer of the object", NULL}}, + {{NULL}} +}}; + +static PyTypeObject LaunchResObjType = {{ + PyVarObject_HEAD_INIT(NULL, 0) +}}; + +static PyObject* make_LaunchRes(int res, const char* log_buffer) {{ + LaunchResObj* obj = (LaunchResObj*)LaunchResObjType.tp_alloc(&LaunchResObjType, 0); + if (obj != NULL) {{ + new (&obj->data) LaunchRes(res, log_buffer); + }} + return (PyObject*)obj; +}} + + +//============================== launch data end ============================== + +simple_logger::Logger logger(simple_logger::ERROR); + auto g_guard = std::shared_ptr( nullptr, [](void*) {{ - PROFILE_CALL(printProfileAll); printf("guard release.\\n"); }} ); +struct Launch_args {{ + int gridX = 0; + int gridY = 0; + int gridZ = 0; + int device_id; + const char* so_key = nullptr; + const char* kernel_file = nullptr; + const char* kernel_fun_name = nullptr; + int is_dump_args = 0; + const char* dump_path = nullptr; + std::vector kargs; + txStream_t stream = nullptr; + int log_level = simple_logger::ERROR; +}}; static int read_bin_file(const char *file_name, char **content, size_t *length) {{ FILE *file; @@ -465,28 +545,28 @@ def make_launcher(constants, signature, kernel_name, kernel_path): file = fopen(file_name, "r"); if (file == NULL) {{ - printf("don't open file %s\\n", file_name); + logger.log(simple_logger::ERROR, "don't open file %s\\n", file_name); return -1; }} if (fseek(file, 0L, SEEK_END) != 0) {{ - printf("fseek to end failed \\n"); + logger.log(simple_logger::ERROR, "fseek to end failed \\n"); fclose(file); return -1; }} file_size = ftell(file); if (file_size == -1) {{ - printf("ftell file failed\\n"); + logger.log(simple_logger::ERROR, "ftell file failed\\n"); fclose(file); return -1; }} rewind(file); *length = file_size; *content = reinterpret_cast(malloc(sizeof(char) * (file_size + 1))); - printf("filename:%s length:%ld\\n", file_name, file_size); if (*content == NULL) {{ fclose(file); - printf("file content malloc error %s file_size:%ld\\n", file_name, file_size); + logger.log(simple_logger::ERROR, + "file content malloc error %s file_size:%ld\\n", file_name, file_size); return -1; }} @@ -497,69 +577,60 @@ def make_launcher(constants, signature, kernel_name, kernel_path): return 0; }} -void dump_kernel_args(int gridX, int gridY, int gridZ, - const std::string &kernel_file, const std::string &kernel_fun_name, - const std::vector &kargs, const std::string &dump_path) {{ +void dump_kernel_args(Launch_args &l_args) {{ - std::string dumpfile = dump_path + "/kernel_args.txt"; - std::ofstream outfile(dumpfile, std::ios::app); + std::ostringstream oss; - if (!outfile) {{ - printf("error can't open file:%s\\n", dumpfile.c_str()); - return; - }} + oss << "==============================" << std::endl; + oss << "kernel_file:"; + oss << l_args.kernel_file << ", "; + oss << "kernel_func:"; + oss << l_args.kernel_fun_name << ", "; - outfile << "==============================" << std::endl; - outfile << "kernel_file:"; - outfile << kernel_file << ", "; - outfile << "kernel_func:"; - outfile << kernel_fun_name << ", "; + oss << "gridX:"; + oss << l_args.gridX << ", "; + oss << "gridY:"; + oss << l_args.gridY << ", "; + oss << "gridZ:"; + oss << l_args.gridZ << ", "; - outfile << "gridX:"; - outfile << gridX << ", "; - outfile << "gridY:"; - outfile << gridY << ", "; - outfile << "gridZ:"; - outfile << gridZ << ", "; - - outfile << "blockX:"; - outfile << 1 << ", "; - outfile << "blockY:"; - outfile << 1 << ", "; - outfile << "blockZ:"; - outfile << 1 << ", "; - outfile << std::endl; + oss << "blockX:"; + oss << 1 << ", "; + oss << "blockY:"; + oss << 1 << ", "; + oss << "blockZ:"; + oss << 1 << ", "; + oss << std::endl; std::vector rtKargs; const char* str_args = "args"; int count = 0; - for (const KernelArg& karg : kargs) {{ - outfile << str_args << count++ << "_"; + for (const KernelArg& karg : l_args.kargs) {{ + oss << str_args << count++ << "_"; if (karg.data_type == POINT) {{ - outfile << "v:" << 1 << ", "; - outfile << str_args << count++ << "_" + oss << "v:" << 1 << ", "; + oss << str_args << count++ << "_" << "p:" << karg.data.ptr << ", "; }} else {{ - outfile << "v:" << std::hex << "0x" << karg.data.scalar << ", "; + oss << "v:" << std::hex << "0x" << karg.data.scalar << ", "; }} }} - outfile << str_args << count++ << "_v:" << std::hex << "0x" << gridX << ", "; - outfile << str_args << count++ << "_v:" << std::hex << "0x" << gridY << ", "; - outfile << str_args << count++ << "_v:" << std::hex << "0x" << gridZ << ", "; - outfile << str_args << count++ << "_v:" << std::hex << "0x" << 0 << ", "; - outfile << str_args << count++ << "_v:" << std::hex << "0x" << 0 << ", "; - outfile << str_args << count++ << "_v:" << std::hex << "0x" << 0 << ", "; - outfile << std::endl; - - outfile << "point size: "; - for (const KernelArg& karg : kargs) {{ + oss << str_args << count++ << "_v:" << std::hex << "0x" << l_args.gridX << ", "; + oss << str_args << count++ << "_v:" << std::hex << "0x" << l_args.gridY << ", "; + oss << str_args << count++ << "_v:" << std::hex << "0x" << l_args.gridZ << ", "; + oss << str_args << count++ << "_v:" << std::hex << "0x" << 0 << ", "; + oss << str_args << count++ << "_v:" << std::hex << "0x" << 0 << ", "; + oss << str_args << count++ << "_v:" << std::hex << "0x" << 0 << ", "; + oss << std::endl; + + oss << "point size: "; + for (const KernelArg& karg : l_args.kargs) {{ if (karg.data_type == POINT) {{ - outfile << karg.data.ptr << ":" << std::hex << "0x" << karg.size << ", "; + oss << karg.data.ptr << ":" << std::hex << "0x" << karg.size << ", "; }} }} - outfile << std::endl; - outfile.close(); + logger.log(simple_logger::DEBUG, "%s", oss.str().c_str()); }} static bool set_device_id(int device_id) {{ @@ -570,28 +641,23 @@ def make_launcher(constants, signature, kernel_name, kernel_path): return true; }} -static void _launch(int gridX, int gridY, int gridZ, - int device_id, std::string kernel_file, std::string kernel_fun_name, - int is_dump_args, std::string dump_path, txStream_t stream, - std::vector kargs) {{ - if (gridX*gridY*gridZ <= 0) {{ +static void _launch(Launch_args &l_args) {{ + if (l_args.gridX*l_args.gridY*l_args.gridZ <= 0) {{ return; // No work to do }} - if (!set_device_id(device_id)) {{ + if (!set_device_id(l_args.device_id)) {{ return; }} - - if (is_dump_args != 0) {{ - dump_kernel_args(gridX, gridY, gridZ, - kernel_file, kernel_fun_name, kargs, dump_path); + if (l_args.is_dump_args != 0) {{ + dump_kernel_args(l_args); }} // TODO::mv uint64_t kernel_len = 0; char* kernel_ptr = nullptr; - int ret = read_bin_file(kernel_file.c_str(), &kernel_ptr, &kernel_len); + int ret = read_bin_file(l_args.kernel_file, &kernel_ptr, &kernel_len); if (ret != 0 || kernel_ptr == nullptr) {{ PyErr_SetString(PyExc_RuntimeError, "Failed to read kernel so"); return; @@ -599,7 +665,7 @@ def make_launcher(constants, signature, kernel_name, kernel_path): // Allocate the device memory for all kernel arguments std::vector rtKargs; - for (KernelArg& karg : kargs) {{ + for (KernelArg& karg : l_args.kargs) {{ if (karg.data_type == POINT) {{ rtKargs.push_back(1); rtKargs.push_back((uint64_t)(karg.data.ptr)); @@ -607,25 +673,29 @@ def make_launcher(constants, signature, kernel_name, kernel_path): rtKargs.push_back((uint64_t)(karg.data.scalar)); }} }} - rtKargs.push_back(gridX); - rtKargs.push_back(gridY); - rtKargs.push_back(gridZ); + rtKargs.push_back(l_args.gridX); + rtKargs.push_back(l_args.gridY); + rtKargs.push_back(l_args.gridZ); rtKargs.push_back(0); rtKargs.push_back(0); rtKargs.push_back(0); - // txError_t txLaunchKernelGGL(const char *funcName, uint64_t elfAddr, uint64_t elfLen, dim3 gridDim, dim3 blockDim, - // void *kernelArg, uint32_t kernelArgLen, uint32_t sharedMemBytes, - // txStream_t tStream = nullptr); - uint32_t eventId = EVENT_INIT; - PROFILE_CALL(addOrderProfile, TIME_RUNTIME, TIME_LAUNCH_START, &eventId); - if (txLaunchKernelGGL(kernel_fun_name.c_str(), (uint64_t)kernel_ptr, kernel_len, - dim3({{(uint32_t)gridX, (uint32_t)gridY, (uint32_t)gridZ}}), dim3({{1u, 1u, 1u}}), - (void*)(&rtKargs[0]), rtKargs.size()*sizeof(uint64_t), 0, stream) != TX_SUCCESS){{ +#ifdef ENABLE_PROFILING + static int run_count = 0; + run_count++; + std::string profiling_key(l_args.so_key); + profiling_key.append("_").append(l_args.kernel_fun_name); + profiling_key.append("_").append(std::to_string(run_count)); +#endif + + PROFILE_CALL(TsmProcessProfData, l_args.device_id, profiling_key, PROF_START, 7); + if (txLaunchKernelGGL(l_args.kernel_fun_name, (uint64_t)kernel_ptr, kernel_len, + dim3({{(uint32_t)l_args.gridX, (uint32_t)l_args.gridY, (uint32_t)l_args.gridZ}}), dim3({{1u, 1u, 1u}}), + (void*)(&rtKargs[0]), rtKargs.size()*sizeof(uint64_t), 0, l_args.stream) != TX_SUCCESS){{ PyErr_SetString(PyExc_RuntimeError, "Failed to txLaunchKernelGGL"); }} - txStreamSynchronize(stream); - PROFILE_CALL(addOrderProfile, TIME_RUNTIME, TIME_LAUNCH_END, &eventId); + txStreamSynchronize(l_args.stream); + PROFILE_CALL(TsmProcessProfData, l_args.device_id, profiling_key, PROF_STOP, 0); }} // Structure to represent a device pointer @@ -645,7 +715,6 @@ def make_launcher(constants, signature, kernel_name, kernel_path): PyObject *element_size_method = PyObject_GetAttrString(obj, "element_size"); if (numel_method && element_size_method) {{ - printf("============= has numel_method and element_size_method ==============\\n"); fflush(stdout); PyObject *empty_tuple1 = PyTuple_New(0); PyObject *empty_tuple2 = PyTuple_New(0); @@ -662,7 +731,6 @@ def make_launcher(constants, signature, kernel_name, kernel_path): size_t element_size = (size_t)PyLong_AsLongLong(element_size_obj); size_t total_size = numel * element_size; - printf("============= numel size: %ld\\n", total_size); Py_DECREF(numel_obj); Py_DECREF(element_size_obj); return total_size; @@ -675,7 +743,6 @@ def make_launcher(constants, signature, kernel_name, kernel_path): if (element_size_method) Py_DECREF(element_size_method); }} - printf("==== zero size ========\\n"); fflush(stdout); return 0; // Return 0 if unable to determine size }} @@ -686,25 +753,22 @@ def make_launcher(constants, signature, kernel_name, kernel_path): ptr_info.valid = true; ptr_info.size = 0; // Initialize size - printf("idx: %d, PyObject : %p \\n", idx, obj); fflush(stdout); if (PyLong_Check(obj)) {{ ptr_info.dev_ptr = (void*) PyLong_AsLongLong(obj); - printf("PyLong_AsLongLong %p\\n", ptr_info.dev_ptr); + logger.log(simple_logger::DEBUG, "PyLong_AsLongLong %p\\n", ptr_info.dev_ptr); return ptr_info; }} if (obj == Py_None) {{ // valid nullptr - - printf("Py_None\\n"); + logger.log(simple_logger::DEBUG, "Py_None\\n"); fflush(stdout); return ptr_info; }} PyObject *ptr = PyObject_GetAttrString(obj, "data_ptr"); if(ptr){{ - printf("PyObject_GetAttrString\\n"); fflush(stdout); PyObject *empty_tuple = PyTuple_New(0); @@ -713,15 +777,13 @@ def make_launcher(constants, signature, kernel_name, kernel_path): Py_DECREF(ptr); if (!PyLong_Check(ret)) {{ PyErr_SetString(PyExc_TypeError, "data_ptr method of Pointer object must return 64-bit int"); + logger.log(simple_logger::ERROR, "data_ptr method of Pointer object must return 64-bit int\\n"); ptr_info.valid = false; - printf("data_ptr method of Pointer object must return 64-bit int\\n"); fflush(stdout); return ptr_info; }} ptr_info.dev_ptr = (void*) PyLong_AsLongLong(ret); - printf("============= ptr_info.dev_ptr: %p\\n", ptr_info.dev_ptr); if(!ptr_info.dev_ptr) {{ - printf("ptr_info.dev_ptr null\\n"); fflush(stdout); return ptr_info; }} @@ -733,7 +795,9 @@ def make_launcher(constants, signature, kernel_name, kernel_path): return ptr_info; }} - PyErr_SetString(PyExc_TypeError, "Pointer argument must be either uint64 or have data_ptr method"); + std::string error_msg = "Pointer argument must be either uint64 or have data_ptr method\\n"; + PyErr_SetString(PyExc_TypeError, error_msg.c_str()); + logger.log(simple_logger::ERROR, error_msg.c_str()); ptr_info.valid = false; return ptr_info; }} @@ -741,7 +805,7 @@ def make_launcher(constants, signature, kernel_name, kernel_path): static size_t getTensorStorageSize(PyObject* tensor_obj) {{ const at::Tensor& tensor = THPVariable_Unpack(tensor_obj); - printf("========== total size ================: %ld\\n", tensor.storage().nbytes()); + logger.log(simple_logger::DEBUG, "========== total size ================: %ld\\n", tensor.storage().nbytes()); return tensor.storage().nbytes(); }} @@ -749,12 +813,11 @@ def make_launcher(constants, signature, kernel_name, kernel_path): static void* extractTensor(PyObject* tensor_obj) {{ const at::Tensor& tensor = THPVariable_Unpack(tensor_obj); torch::Tensor contiguous_tensor = tensor.contiguous(); - printf("========== ptr ================: %p\\n", contiguous_tensor.data_ptr()); + logger.log(simple_logger::DEBUG, "========== ptr ================: %p\\n", contiguous_tensor.data_ptr()); return contiguous_tensor.data_ptr(); }} static PyObject* release(PyObject* self, PyObject* args) {{ - PROFILE_CALL(printProfileAll); Py_RETURN_NONE; }} @@ -771,43 +834,46 @@ def make_launcher(constants, signature, kernel_name, kernel_path): const char* kernel_file = "base_kernel_path"; const char* kernel_fun_name = "base_kernel_func_name"; const char* dump_path = ""; + const char* so_key = ""; int is_dump_args = 0; uint32_t sharedMemBytes = 0; int device_id = 0; txStream_t stream = nullptr; + int log_level = simple_logger::ERROR; + + Launch_args l_args; // Define the actual kernel arguments {' '.join([f"{_extracted_type(ty)} _arg{i}; " for i, ty in signature.items()])} // Init kernel arguments from python side - if(!PyArg_ParseTuple(args, \"{format}\", &device_id, &kernel_file, - &kernel_fun_name, &is_dump_args, &dump_path, - &gridX, &gridY, &gridZ, &py_obj_stream, &pKrnl, + if(!PyArg_ParseTuple(args, \"{format}\", &l_args.device_id, &l_args.so_key, &l_args.kernel_file, + &l_args.kernel_fun_name, &l_args.is_dump_args, &l_args.dump_path, &l_args.log_level, + &l_args.gridX, &l_args.gridY, &l_args.gridZ, &py_obj_stream, &pKrnl, &kernel_metadata, &launch_metadata, &launch_enter_hook, &launch_exit_hook {args_list})) {{ - return NULL; + return make_LaunchRes(-1, ""); }} - // Construct a data kernel arguments list data structure - std::vector kargs; - //{' '.join([f"kargs.emplace_back(_arg{i}, PyObject_Size(_arg{i})*4);" if ty[0]=="*" else f"kargs.emplace_back(*(uint64_t*)&_arg{i}, sizeof(_arg{i}));" for i, ty in signature.items() if ty != "constexpr"])} - // {' '.join([f"kargs.emplace_back(extractTensor(_arg{i}), getTensorStorageSize(_arg{i}));" - if ty[0]=="*" else f"kargs.emplace_back(*(uint64_t*)&_arg{i}, sizeof(_arg{i}));" + logger.setLogLevel((simple_logger::LogLevel)l_args.log_level); + + //{' '.join([f"l_args.kargs.emplace_back(_arg{i}, PyObject_Size(_arg{i})*4);" if ty[0]=="*" else f"l_args.kargs.emplace_back(*(uint64_t*)&_arg{i}, sizeof(_arg{i}));" for i, ty in signature.items() if ty != "constexpr"])} + // {' '.join([f"l_args.kargs.emplace_back(extractTensor(_arg{i}), getTensorStorageSize(_arg{i}));" + if ty[0]=="*" else f"l_args.kargs.emplace_back(*(uint64_t*)&_arg{i}, sizeof(_arg{i}));" for i, ty in signature.items() if ty != "constexpr"])} {"; ".join([f"DevicePtrInfo ptr_info{i} = getPointer(_arg{i}, {i}); if (!ptr_info{i}.valid) return NULL;" if ty[0] == "*" else "" for i, ty in signature.items() if ty != "constexpr"])}; - {' '.join([f"kargs.emplace_back(ptr_info{i}.dev_ptr, ptr_info{i}.size);" - if ty[0]=="*" else f"kargs.emplace_back(*(uint64_t*)&_arg{i}, sizeof(_arg{i}));" + {' '.join([f"l_args.kargs.emplace_back(ptr_info{i}.dev_ptr, ptr_info{i}.size);" + if ty[0]=="*" else f"l_args.kargs.emplace_back(*(uint64_t*)&_arg{i}, sizeof(_arg{i}));" for i, ty in signature.items() if ty != "constexpr"])} // Launch the kernel - _launch(gridX, gridY, gridZ, device_id, std::string(kernel_file), - std::string(kernel_fun_name), is_dump_args, std::string(dump_path), stream, kargs); + _launch(l_args); if (PyErr_Occurred()) {{ - return NULL; + return make_LaunchRes(-1, ""); }} // Call the exit hook if provided @@ -816,12 +882,10 @@ def make_launcher(constants, signature, kernel_name, kernel_path): PyObject* ret = PyObject_CallObject(launch_exit_hook, hook_args); Py_DECREF(hook_args); if (!ret) - return NULL; + return make_LaunchRes(-1, ""); }} - // Return None to Python - Py_INCREF(Py_None); - return Py_None; + return make_LaunchRes(0, ""); }} // Python module method definitions @@ -848,6 +912,27 @@ def make_launcher(constants, signature, kernel_name, kernel_path): }} PyModule_AddFunctions(m, ModuleMethods); + + LaunchResObjType.tp_name = "__triton_launcher.LaunchRes"; + LaunchResObjType.tp_doc = "Custom LaunchRes objects"; + LaunchResObjType.tp_basicsize = sizeof(LaunchResObj); + LaunchResObjType.tp_itemsize = 0; + LaunchResObjType.tp_flags = Py_TPFLAGS_DEFAULT; + LaunchResObjType.tp_new = PyType_GenericNew; + LaunchResObjType.tp_dealloc = (destructor)LaunchRes_dealloc; + LaunchResObjType.tp_repr = (reprfunc)LaunchRes_repr; + LaunchResObjType.tp_members = LaunchRes_members; + LaunchResObjType.tp_getset = LaunchRes_getsetters; + + if (PyType_Ready(&LaunchResObjType) < 0) + return NULL; + + Py_INCREF(&LaunchResObjType); + if (PyModule_AddObject(m, "LaunchRes", (PyObject*)&LaunchResObjType) < 0) {{ + Py_DECREF(&LaunchResObjType); + Py_DECREF(m); + return NULL; + }} return m; }} """ @@ -904,12 +989,10 @@ def __init__(self, src, metadata): signature = {cst_key(key): value for key, value in src.signature.items()} # Compiler runtime kernel launcher source code - kernel_path = metadata.kernel_path - print("==== kernel_path: ", kernel_path) - launcher_src = make_launcher(constants, signature, src.fn.__name__, kernel_path) + self.metadata = metadata + launcher_src = make_launcher(constants, signature, src.fn.__name__, metadata.kernel_path) mod = compile_native(launcher_src, "__triton_launcher") self.launch = mod.launch - self.kernel_path = kernel_path self.func_name = src.fn.__name__ def __call__(self, *args, **kwargs): @@ -920,8 +1003,16 @@ def __call__(self, *args, **kwargs): # 9~N: Actual triton kernel args. import torch device_id = torch.txda.current_device() - self.launch(device_id, self.kernel_path, self.func_name, txda_tools.is_dump_args_profile(), - txda_tools.get_dump_dir(), *args, **kwargs) + logger.info(f"{self.func_name} launch card:{device_id} begin") + log_level = logger_to_custom_level_number(logger) + launchRes = self.launch(device_id, self.metadata.so_key, self.metadata.kernel_path, self.func_name, + txda_tools.is_dump_args_profile(), txda_tools.get_dump_dir(), log_level, *args, + **kwargs) + + if launchRes.res != 0: + logger.error(f"launch error code:{launchRes.res}") + + logger.info(f"{self.func_name} launch card:{device_id} end") class TXDADriver(GPUDriver): @@ -929,7 +1020,6 @@ class TXDADriver(GPUDriver): def __init__(self): import torch super().__init__() - print("============= call TXDADriver test") if (os.getenv("USE_SIM_MODE", "0").lower() in ("1", "true", "yes")): self.utils = SimulatorUtils() else: diff --git a/third_party/tsingmicro/backend/include/logger.h b/third_party/tsingmicro/backend/include/logger.h new file mode 100644 index 0000000000..ab6e05b873 --- /dev/null +++ b/third_party/tsingmicro/backend/include/logger.h @@ -0,0 +1,106 @@ +#pragma once +#include +#include +#include +#include +#include + +namespace simple_logger { + +enum LogLevel { DEBUG = 0, INFO = 1, WARN = 2, ERROR = 3 }; + +class Logger { +public: + explicit Logger(LogLevel level = INFO) : current_level_(level) {} + + void setLogLevel(LogLevel level) { current_level_ = level; } + + void log(LogLevel level, const char *format, ...) { + if (level < current_level_) { + return; + } + + va_list args; + va_start(args, format); + int msg_len = std::vsnprintf(nullptr, 0, format, args); + va_end(args); + + if (msg_len <= 0) { + const char *err = "<>"; + msg_len = static_cast(std::strlen(err)); + } + + auto now = std::chrono::system_clock::now(); + auto now_ms = std::chrono::time_point_cast(now); + auto ms = now_ms.time_since_epoch() % 1000; + + std::time_t now_time_t = std::chrono::system_clock::to_time_t(now_ms); + std::tm tm{}; + localtime_r(&now_time_t, &tm); + + char time_buf[64]; + std::strftime(time_buf, sizeof(time_buf), "%Y%m%d %H:%M:%S", &tm); + size_t time_len = std::strlen(time_buf); + std::snprintf(time_buf + time_len, sizeof(time_buf) - time_len, ".%03d", + static_cast(ms.count())); + time_len += 4; + + const char *level_prefix; + switch (level) { + case DEBUG: + level_prefix = "[DEBUG] "; + break; + case INFO: + level_prefix = "[INFO ] "; + break; + case WARN: + level_prefix = "[WARN ] "; + break; + case ERROR: + level_prefix = "[ERROR] "; + break; + default: + level_prefix = "[?????] "; + } + size_t level_len = std::strlen(level_prefix); + + size_t total_add_len = + 1 + time_len + 2 + level_len + + static_cast( + msg_len > 0 ? msg_len + : static_cast(std::strlen("<>"))) + + 1; + + std::string buffer; + buffer.resize(total_add_len); + char *out = &buffer[0]; + + char *p = out; + *p++ = '['; + std::memcpy(p, time_buf, time_len); + p += time_len; + *p++ = ']'; + + std::memcpy(p, level_prefix, level_len); + p += level_len; + + if (msg_len > 0) { + va_start(args, format); + std::vsnprintf(p, static_cast(msg_len) + 1, format, args); + va_end(args); + p += msg_len; + } else { + const char *err = "<>"; + std::memcpy(p, err, std::strlen(err)); + p += std::strlen(err); + } + + *p++ = '\n'; + printf("%s", buffer.c_str()); + } + +private: + LogLevel current_level_; +}; + +} // namespace simple_logger diff --git a/third_party/tsingmicro/backend/logger_config.py b/third_party/tsingmicro/backend/logger_config.py new file mode 100644 index 0000000000..3d474c1525 --- /dev/null +++ b/third_party/tsingmicro/backend/logger_config.py @@ -0,0 +1,135 @@ +# logger_config.py +import logging +import os +import sys + +# Standard mapping: custom number (0~4) -> logging constant +CUSTOM_NUMBER_TO_LOGGING = { + 0: logging.DEBUG, 1: logging.INFO, 2: logging.WARNING, 3: logging.ERROR, 4: logging.CRITICAL +} + +# Reverse for validation +LOGGING_TO_CUSTOM_NUMBER = {v: k for k, v in CUSTOM_NUMBER_TO_LOGGING.items()} + +# Standard level names for validation +STANDARD_LEVEL_NAMES = { + name.upper(): getattr(logging, name) + for name in ['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'] +} + + +def get_log_level_from_env(env_var='TX_LOG_LEVEL', default='info'): + """ + Read log level from environment variable. + Supports: + - String: 'DEBUG', 'info', 'WARNING', etc. (case-insensitive) + - Number: '0', '1', '2', '3', '4' + Returns a standard logging level integer (e.g., logging.INFO = 20). + """ + raw_value = os.getenv(env_var, default).strip() + if not raw_value: + raw_value = default + + # Try to interpret as integer (custom 0~4) + try: + num = int(raw_value) + if num in CUSTOM_NUMBER_TO_LOGGING: + return CUSTOM_NUMBER_TO_LOGGING[num] + else: + print(f"Warning: Invalid numeric log level '{num}'. Must be 0~4. Using default '{default}'.") + except ValueError: + # Not an integer, treat as string + level_str = raw_value.upper() + if level_str in STANDARD_LEVEL_NAMES: + return STANDARD_LEVEL_NAMES[level_str] + else: + print(f"Warning: Invalid log level string '{raw_value}'. " + f"Expected one of {list(STANDARD_LEVEL_NAMES.keys())} or 0~4. Using default '{default}'.") + + # Fallback to default + fallback_level_str = str(default).upper() + if fallback_level_str in STANDARD_LEVEL_NAMES: + return STANDARD_LEVEL_NAMES[fallback_level_str] + elif default.isdigit(): + fallback_num = int(default) + if fallback_num in CUSTOM_NUMBER_TO_LOGGING: + return CUSTOM_NUMBER_TO_LOGGING[fallback_num] + # Final fallback + return logging.INFO + + +# Custom log level mapping: standard level names to custom integers (0~4) +CUSTOM_LEVEL_MAP = {'DEBUG': 0, 'INFO': 1, 'WARNING': 2, 'ERROR': 3, 'CRITICAL': 4} + + +def log_level_name_to_custom_number(level_name: str) -> int: + """ + Convert a standard log level name (e.g., 'INFO') to a custom integer (0~4). + Case-insensitive. + """ + level_name = level_name.upper() + if level_name not in CUSTOM_LEVEL_MAP: + raise ValueError(f"Unsupported log level: {level_name}") + return CUSTOM_LEVEL_MAP[level_name] + + +def logger_to_custom_level_number(logger) -> int: + """ + Get the effective log level of the given logger and convert it to a custom integer (0~4). + """ + effective_level = logger.getEffectiveLevel() + level_name = logging.getLevelName(effective_level) + + # Handle non-standard or unrecognized levels + if not isinstance(level_name, str) or level_name.startswith("Level "): + raise ValueError(f"Unrecognized log level value: {effective_level}") + + return log_level_name_to_custom_number(level_name) + + +def log_at_current_level(logger, message): + current_level = logger.getEffectiveLevel() + if current_level <= logging.DEBUG: + logger.debug(message) + elif current_level <= logging.INFO: + logger.info(message) + elif current_level <= logging.WARNING: + logger.warning(message) + elif current_level <= logging.ERROR: + logger.error(message) + else: + logger.critical(message) + + +def setup_logger(name='tsingmicro'): + """ + Set up and return a unified logger instance. + Log level is controlled by the LOG_LEVEL environment variable (default: INFO). + Logs are output to both console and file. + """ + log_file = f"{name}.log" + logger = logging.getLogger(name) + + if not logger.handlers: + log_level = get_log_level_from_env() + + # Set logger to lowest level; actual filtering is done by handlers + logger.setLevel(log_level) + + formatter = logging.Formatter(fmt='[%(asctime)s.%(msecs)03d][%(levelname)s]%(name)s:%(message)s', + datefmt='%Y%m%d %H:%M:%S') + + # File handler + file_handler = logging.FileHandler(log_file, encoding='utf-8') + file_handler.setLevel(log_level) + file_handler.setFormatter(formatter) + + # Console handler + console_handler = logging.StreamHandler(sys.stdout) + console_handler.setLevel(log_level) + console_handler.setFormatter(formatter) + + logger.addHandler(file_handler) + logger.addHandler(console_handler) + + return logger diff --git a/third_party/tsingmicro/backend/txda_tools.py b/third_party/tsingmicro/backend/txda_tools.py index 31c35b7265..b4d87a4fb8 100644 --- a/third_party/tsingmicro/backend/txda_tools.py +++ b/third_party/tsingmicro/backend/txda_tools.py @@ -4,9 +4,12 @@ import hashlib from posixpath import dirname +from triton.backends.tsingmicro.logger_config import setup_logger + +logger = setup_logger("tsingmicro_launch") + _dump_dir_cache = None dump_cmd_count = 0 -kernel_run_file = "kernel_run.log" def _get_dump_env_path(): @@ -32,19 +35,12 @@ def get_dump_dir(): if not os.path.exists(full_path): os.makedirs(full_path) _dump_dir_cache = full_path # 缓存结果 - print(f"mkdir: {full_path}") + logger.debug(f"make dump dir:{full_path}") break index += 1 return _dump_dir_cache -def record_log(string: str): - record_file = os.path.join(get_dump_dir(), kernel_run_file) - with open(record_file, 'a', encoding='utf-8') as record_file: - record_file.write(string) - print(string) - - def runLoweringCmd(destFile: str, args: list): isAlwaysCompile = os.getenv("TRITON_ALWAYS_COMPILE", "0").lower() in ("1", "true", "yes") if isAlwaysCompile or not os.path.exists(destFile): @@ -53,18 +49,11 @@ def runLoweringCmd(destFile: str, args: list): else: subprocess.check_call(args, stdout=subprocess.DEVNULL) else: - print(f"Skip lowering {destFile}") + logger.debug(f"Skip lowering {destFile}") def is_use_profile(): - return os.getenv("USE_PROFILE", "0").lower() in ("1", "true", "yes") - - -def is_use_host_profile(): - if is_use_profile(): - return True - else: - return os.getenv("USE_HOST_PROFILE", "0").lower() in ("1", "true", "yes") + return os.getenv("ENABLE_PROFILING", "0").lower() in ("1", "true", "yes") def dump_ir_if_needed(files): @@ -137,10 +126,6 @@ def is_debug(): return debug_value in ["ON", "TRUE", "1", "YES"] -def record_key_v(key: str, v: str): - record_log(f"{key}:{v}\n") - - def calculate_str_md5(string: str): str_hash = hashlib.md5(string).hexdigest() return str_hash @@ -150,7 +135,3 @@ def calculate_file_md5(file_path): with open(file_path, 'rb') as f: file_bytes = f.read() return calculate_str_md5(file_bytes) - - -def record_file_hash(file: str): - record_log(f"{file}:{calculate_file_md5(file)}\n") diff --git a/third_party/tsingmicro/benchmark/benchmark.py b/third_party/tsingmicro/benchmark/benchmark.py new file mode 100755 index 0000000000..9bd144db59 --- /dev/null +++ b/third_party/tsingmicro/benchmark/benchmark.py @@ -0,0 +1,1668 @@ +#!/usr/bin/env python3 +""" +使用triton的do_bench来测量性能,支持各种GPU设备 + +使用方法: + export PYTHONPATH=/your/workspace/FlagGems/src + bash third_party/tsingmicro/scripts/run_tsingmicro.sh pytest test_abs_cuda_time.py +""" + +import pytest +import torch +import flag_gems + +# 导入triton的do_bench,triton会自动处理不同设备的兼容性 +try: + import triton + _do_bench = triton.testing.do_bench +except ImportError: + raise ImportError("triton不可用,请安装triton以支持性能测试") + +# 测试配置常量 +BASE_SHAPES = [(256, 256), (4096, 4096), (16384, 16384)] +# BASE_SHAPES = [(256, 256)] +DTYPE = torch.float16 +WARMUP = 1 +REPETITION = 3 + +# 索引和形状限制常量(避免OOM) +MAX_INDEX_SIZE = 64 +MAX_BATCH_SIZE = 8 +MAX_SEQ_LEN = 256 +MAX_KRON_1D_SIZE = 32 +MAX_KRON_2D_SIZE = 16 +MAX_KRON_ND_SIZE = 8 +MAX_ISIN_TEST_ELEMENTS = 100 + +# 所有算子列表(从 op_list.md 提取,按字母顺序排序) +OP_LIST = [ + 'abs', + 'add', + 'all', + 'amax', + 'angle', + 'any', + 'argmax', + 'argmin', + 'arange', + 'bitwise_and', + 'bitwise_not', + 'bitwise_or', + 'cat', + 'concat_and_cache_mla', + 'contiguous', + 'cos', + 'count_nonzero', + 'cross_entropy_loss', + 'cumsum', + 'diag_embed', + 'diagonal', + 'div', + 'dot', + 'dropout', + 'elu', + 'embedding', + 'eq', + 'erf', + 'exp', + 'eye', + 'fill', + 'flash_attention_forward', + 'flash_mla', + 'flip', + 'floor_divide', + 'full', + 'full_like', + 'fused_add_rms_norm', + 'gather', + 'gelu', + 'gelu_and_mul', + 'ge', + 'glu', + 'gt', + 'hstack', + 'index', + 'index_put', + 'index_select', + 'isin', + 'isfinite', + 'isinf', + 'isclose', + 'isnan', + 'layer_norm', + 'le', + 'linspace', + 'log', + 'log_softmax', + 'logical_and', + 'logical_not', + 'logical_or', + 'logical_xor', + 'lt', + 'masked_fill', + 'masked_select', + 'max', + 'maximum', + 'mean', + 'min', + 'minimum', + 'mm', + 'mse_loss', + 'mul', + 'multinomial', + 'mv', + 'matmul', + 'nan_to_num', + 'ne', + 'neg', + 'nll_loss', + 'nonzero', + 'normal', + 'ones', + 'ones_like', + 'outer', + 'pad', + 'pow', + 'prod', + 'rand', + 'rand_like', + 'randn', + 'randn_like', + 'reciprocal', + 'relu', + 'remainder', + 'repeat_interleave', + 'reshape_and_cache', + 'reshape_and_cache_flash', + 'resolve_conj', + 'resolve_neg', + 'rms_norm', + 'rsqrt', + 'rsub', + 'scaled_dot_product_attention', + 'scatter', + 'select', + 'sigmoid', + 'silu', + 'silu_and_mul', + 'slice_scatter', + 'softmax', + 'sort', + 'stack', + 'sub', + 'sum', + 'tanh', + 'threshold', + 'tile', + 'to', + 'topk', + 'triu', + 'unique', + 'vector_norm', + 'vdot', + 'vstack', + 'where', + 'zeros', + 'zeros_like', + #'batch_norm','bitwise_xor', 'conv1d', 'conv2d', 'conv_depthwise2d','cummax', 'cummin', 'diag', 'group_norm', + #, 'index_add','kron', 'lerp', 'log_sigmoid', 'polar', 'quantile', 'randperm', 'var_mean', +] + +# 全局标志,确保只启用一次 +_flag_gems_enabled = False + + +def _get_do_bench(): + """获取do_bench函数,triton会自动处理不同设备的兼容性""" + return _do_bench + + +def _format_shape_str(op_name, inputs, config, shape): + """格式化shape字符串用于显示和报告""" + if config['type'] == 'matrix': + if op_name == 'addmm': + return f"bias{inputs[0].shape},mat1{inputs[1].shape},mat2{inputs[2].shape}" + elif op_name == 'bmm': + return f"mat1{inputs[0].shape},mat2{inputs[1].shape}" + elif op_name == 'mv': + return f"mat{inputs[0].shape},vec{inputs[1].shape}" + else: + return f"mat1{inputs[0].shape},mat2{inputs[1].shape}" + elif config['type'] == 'ternary': + if op_name == 'lerp': + weight_str = str(inputs[2]) if isinstance(inputs[2], (int, float)) else str(inputs[2].shape) + return f"inp{inputs[0].shape},end{inputs[1].shape},weight{weight_str}" + else: + return f"inp1{inputs[0].shape},inp2{inputs[1].shape},inp3{inputs[2].shape}" + elif config['type'] == 'binary': + return f"{inputs[0].shape}" + elif config['type'] == 'multi_input': + return f"inputs{len(inputs[0])}x{inputs[0][0].shape if inputs[0] else '()'}" + elif config['type'] == 'special_constructor': + if op_name == 'arange': + end = shape[0] if isinstance(shape, tuple) and len(shape) > 0 else ( + shape if isinstance(shape, int) else 256) + return f"arange(0, {end})" + elif op_name == 'linspace': + steps = shape[0] if isinstance(shape, tuple) and len(shape) > 0 else ( + shape if isinstance(shape, int) else 256) + return f"linspace(0, 1, {steps})" + else: + return str(shape) + else: + return str(inputs.shape if hasattr(inputs, 'shape') else shape) + + +def _store_result(op_name, shape_str, config, avg_time_ms, avg_time_us, elapsed_time_ms, error=None): + """存储测试结果到pytest._op_perf_results""" + if not hasattr(pytest, '_op_perf_results'): + pytest._op_perf_results = [] + + result = { + 'op': op_name, + 'shape': shape_str, + 'config': str(config.get('extra_args', {})) if config.get('extra_args') else '', + 'dtype': str(DTYPE).split('.')[-1], + 'avg_time_ms': avg_time_ms, + 'avg_time_us': avg_time_us, + 'elapsed_time_ms': elapsed_time_ms, + 'type': config['type'], + } + + if error: + result['error'] = str(error) + + pytest._op_perf_results.append(result) + + +def _ensure_flag_gems_enabled(): + """确保FlagGems已启用,避免重复注册""" + global _flag_gems_enabled + if not _flag_gems_enabled: + try: + flag_gems.enable() + _flag_gems_enabled = True + except RuntimeError as e: + # 如果已经启用,忽略重复注册的错误 + if "already a kernel registered" not in str(e): + raise + _flag_gems_enabled = True + + +def parse_op_list(): + """解析并返回所有算子列表""" + return OP_LIST + + +def get_op_config(op_name): + """ + 为每个算子返回测试配置(shape适配和调用方式) + + Args: + op_name: 算子名称 + + Returns: + dict: 包含type、shapes、extra_args、dtype等配置的字典 + """ + op_name_lower = op_name.lower() + + # 真正需要跳过的算子(不适合性能测试) + skip_ops = { + 'fused_add_rms_norm', # 复合算子 + 'concat_and_cache_mla', # 特殊缓存算子 + 'reshape_and_cache', # 特殊缓存算子 + 'reshape_and_cache_flash', # 特殊缓存算子 + 'flash_attention_forward', # 需要很多参数 + 'flash_mla', # 需要很多参数 + 'scaled_dot_product_attention', # 需要很多参数 + 'nonzero', # 返回indices,不适合性能测试 + 'unique', # 返回indices,不适合性能测试 + 'to', # dtype转换,不适合性能测试 + 'contiguous', # 内存操作,不适合性能测试 + 'resolve_neg', # 特殊算子 + 'resolve_conj', # 特殊算子 + 'gelu_and_mul', # 复合算子 + 'silu_and_mul', # 复合算子 + } + + # 需要固定配置的特殊算子(参考FlagGems/tests中的测试用例) + special_fixed_config_ops = { + 'batch_norm', # 需要 weight, bias, running_mean, running_var 等 + 'layer_norm', # 需要 normalized_shape, weight, bias 等 + 'group_norm', # 需要 num_groups, weight, bias 等 + 'rms_norm', # 可能需要特殊参数 + 'cross_entropy_loss', # 需要 target + 'nll_loss', # 需要 target + 'mse_loss', # 需要 target + 'embedding', # 需要 indices + 'gather', # 需要 indices + 'index', # 需要 indices + 'index_add', # 需要 indices + 'index_put', # 需要 indices + 'index_select', # 需要 index + 'scatter', # 需要 index + 'slice_scatter', # 需要 index + 'multinomial', # 需要 num_samples + 'topk', # 需要 k + 'quantile', # 需要 q + 'where', # 需要 condition + 'masked_fill', # 需要 mask + 'masked_select', # 需要 mask + 'select', # 需要 dim 和 index + 'diagonal', # 需要 offset, dim1, dim2 + 'diag', # 需要 offset + 'diag_embed', # 需要 offset, dim1, dim2 + 'pad', # 需要 pad 参数 + 'tile', # 需要 dims 参数 + 'repeat_interleave', # 需要 repeats 参数 + 'kron', # 需要两个输入 + 'outer', # 需要两个输入 + 'vdot', # 需要两个输入 + 'dot', # 需要两个输入 + 'polar', # 需要两个输入 + 'atan2', # 需要两个输入 + 'hypot', # 需要两个输入 + 'fmod', # 需要两个输入 + 'isin', # 需要test_elements参数 + 'conv1d', # 需要 weight, bias, stride, padding 等 + 'conv2d', # 需要 weight, bias, stride, padding 等 + 'conv_depthwise2d', # 需要 weight, bias, stride, padding 等 + 'fill', # 需要 value 参数 + } + + if op_name_lower in skip_ops: + return {'type': 'skip', 'reason': '不适合性能测试'} + + # 初始化config字典 + config = { + 'type': 'unary', # default + 'shapes': BASE_SHAPES, + 'extra_args': {}, + 'call_func': None, + } + + if op_name_lower in special_fixed_config_ops: + config['type'] = 'special_fixed_config' + # 为每个特殊算子设置固定配置 + if op_name_lower == 'batch_norm': + # batch_norm: (N, C, H, W) -> (N, C, H, W) + config['shapes'] = [(16, 32, 32, 32), (32, 64, 64, 64), (64, 128, 128, 128)] + config['extra_args'] = {'eps': 1e-5, 'momentum': 0.1, 'training': True} + elif op_name_lower == 'layer_norm': + # layer_norm: (N, C, H, W) -> (N, C, H, W) + config['shapes'] = [(16, 32, 32, 32), (32, 64, 64, 64), (64, 128, 128, 128)] + config['extra_args'] = {'eps': 1e-5} + elif op_name_lower == 'group_norm': + # group_norm: (N, C, H, W) -> (N, C, H, W) + config['shapes'] = [(16, 32, 32, 32), (32, 64, 64, 64), (64, 128, 128, 128)] + config['extra_args'] = {'num_groups': 8, 'eps': 1e-5} + elif op_name_lower == 'rms_norm': + # rms_norm: (N, C) -> (N, C) + config['shapes'] = [(256, 512), (4096, 4096), (16384, 16384)] + config['extra_args'] = {'eps': 1e-5} + elif op_name_lower == 'conv1d': + # conv1d: (N, C, L) -> (N, C_out, L_out) + config['shapes'] = [(16, 32, 128), (32, 64, 256), (64, 128, 512)] + config['extra_args'] = {'stride': 1, 'padding': 1, 'dilation': 1} + elif op_name_lower == 'conv2d': + # conv2d: (N, C, H, W) -> (N, C_out, H_out, W_out) + config['shapes'] = [(16, 32, 32, 32), (32, 64, 64, 64), (64, 128, 128, 128)] + config['extra_args'] = {'stride': 1, 'padding': 1, 'dilation': 1, 'groups': 1} + elif op_name_lower == 'conv_depthwise2d': + # conv_depthwise2d: (N, C, H, W) -> (N, C, H_out, W_out) + config['shapes'] = [(16, 32, 32, 32), (32, 64, 64, 64), (64, 128, 128, 128)] + config['extra_args'] = {'stride': 1, 'padding': 1, 'dilation': 1} + elif op_name_lower == 'embedding': + # embedding: (num_embeddings, embedding_dim), indices -> (indices_shape, embedding_dim) + config['shapes'] = [(256, 512), (4096, 4096), (16384, 16384)] + config['extra_args'] = {'num_embeddings': 10000, 'embedding_dim': 512} + elif op_name_lower == 'gather': + # gather: input, dim, index -> output + config['shapes'] = BASE_SHAPES + config['extra_args'] = {'dim': 0} + elif op_name_lower == 'index_select': + # index_select: input, dim, index -> output + config['shapes'] = BASE_SHAPES + config['extra_args'] = {'dim': 0} + elif op_name_lower == 'scatter': + # scatter: input, dim, index, src -> output + config['shapes'] = BASE_SHAPES + config['extra_args'] = {'dim': 0} + elif op_name_lower == 'topk': + # topk: input, k -> (values, indices) + config['shapes'] = BASE_SHAPES + config['extra_args'] = {'k': 10, 'dim': -1} + elif op_name_lower == 'multinomial': + # multinomial: input, num_samples -> indices + config['shapes'] = BASE_SHAPES + config['extra_args'] = {'num_samples': 10} + elif op_name_lower == 'quantile': + # quantile: input, q -> output + # quantile不支持float16,需要使用float32或float64 + config['shapes'] = BASE_SHAPES + config['extra_args'] = {'q': 0.5, 'dim': -1} + config['dtype'] = torch.float32 + elif op_name_lower == 'where': + # where: condition, x, y -> output + config['shapes'] = BASE_SHAPES + config['extra_args'] = {} + elif op_name_lower == 'masked_fill': + # masked_fill: input, mask, value -> output + config['shapes'] = BASE_SHAPES + config['extra_args'] = {'value': 0.0} + elif op_name_lower == 'masked_select': + # masked_select: input, mask -> output (1D) + config['shapes'] = BASE_SHAPES + config['extra_args'] = {} + elif op_name_lower == 'select': + # select: input, dim, index -> output + config['shapes'] = BASE_SHAPES + config['extra_args'] = {'dim': 0, 'index': 0} + elif op_name_lower == 'diagonal': + # diagonal: input, offset, dim1, dim2 -> output + config['shapes'] = BASE_SHAPES + config['extra_args'] = {'offset': 0, 'dim1': 0, 'dim2': 1} + elif op_name_lower == 'diag': + # diag: input, diagonal -> output + config['shapes'] = BASE_SHAPES + config['extra_args'] = {'diagonal': 0} + elif op_name_lower == 'diag_embed': + # diag_embed: input, offset, dim1, dim2 -> output + config['shapes'] = [(256, ), (4096, ), (16384, )] + config['extra_args'] = {'offset': 0, 'dim1': -2, 'dim2': -1} + elif op_name_lower == 'pad': + # pad: input, pad -> output + config['shapes'] = BASE_SHAPES + config['extra_args'] = {'pad': (1, 1, 1, 1), 'mode': 'constant', 'value': 0.0} + elif op_name_lower == 'tile': + # tile: input, dims -> output + config['shapes'] = BASE_SHAPES + config['extra_args'] = {'dims': (2, 2)} + elif op_name_lower == 'repeat_interleave': + # repeat_interleave: input, repeats -> output + config['shapes'] = BASE_SHAPES + config['extra_args'] = {'repeats': 2, 'dim': -1} + elif op_name_lower == 'kron': + # kron: input1, input2 -> output + config['shapes'] = BASE_SHAPES + config['extra_args'] = {} + elif op_name_lower == 'outer': + # outer: input1, input2 -> output + config['shapes'] = [(256, ), (4096, ), (16384, )] + config['extra_args'] = {} + elif op_name_lower == 'vdot': + # vdot: input1, input2 -> scalar + config['shapes'] = [(256, ), (4096, ), (16384, )] + config['extra_args'] = {} + elif op_name_lower == 'dot': + # dot: input1, input2 -> scalar or 1D + config['shapes'] = [(256, ), (4096, ), (16384, )] + config['extra_args'] = {} + elif op_name_lower == 'polar': + # polar: abs, angle -> complex + # polar不支持float16,需要使用float32 + config['shapes'] = BASE_SHAPES + config['extra_args'] = {} + config['dtype'] = torch.float32 + elif op_name_lower == 'atan2': + # atan2: input1, input2 -> output + config['shapes'] = BASE_SHAPES + config['extra_args'] = {} + elif op_name_lower == 'hypot': + # hypot: input1, input2 -> output + config['shapes'] = BASE_SHAPES + config['extra_args'] = {} + elif op_name_lower == 'fmod': + # fmod: input1, input2 -> output + config['shapes'] = BASE_SHAPES + config['extra_args'] = {} + elif op_name_lower == 'slice_scatter': + # slice_scatter: input, dim, src, start, end, step -> output + config['shapes'] = BASE_SHAPES + config['extra_args'] = {'dim': 0, 'start': 0, 'end': 128, 'step': 1} + elif op_name_lower == 'isin': + # isin: elements, test_elements -> output + config['shapes'] = [(256, ), (4096, ), (16384, )] + config['extra_args'] = {} + elif op_name_lower == 'fill': + # fill: input, value -> output + config['shapes'] = BASE_SHAPES + config['extra_args'] = {'value': 1.0} + elif op_name_lower == 'cross_entropy_loss': + # cross_entropy_loss: input, target -> loss + config['shapes'] = [(256, 10), (4096, 100), (16384, 1000)] + config['extra_args'] = {} + elif op_name_lower == 'nll_loss': + # nll_loss: input, target -> loss + config['shapes'] = [(256, 10), (4096, 100), (16384, 1000)] + config['extra_args'] = {} + elif op_name_lower == 'mse_loss': + # mse_loss: input, target -> loss + config['shapes'] = BASE_SHAPES + config['extra_args'] = {} + return config + + # 规约算子:保持原shape,默认不带dim(测试全局规约性能) + reduction_ops = { + 'sum', 'mean', 'max', 'min', 'prod', 'all', 'any', 'amax', 'amin', 'argmax', 'argmin', 'std', 'var', 'var_mean', + 'count_nonzero', 'norm', 'vector_norm' + } + + # 累积算子:需要dim参数 + cumulative_ops = {'cummax', 'cummin', 'cumsum'} + + # 位运算算子:需要整数类型 + bitwise_ops = {'bitwise_and', 'bitwise_or', 'bitwise_not', 'bitwise_xor'} + + # 需要特殊数据类型的二元算子 + binary_ops_special_dtype = { + 'floor_divide': torch.float32, # floor_divide不支持float16,需要使用float32 + 'polar': torch.float32, # polar不支持float16,需要使用float32或float64 + } + + # 需要特殊数据类型的构造函数算子 + constructor_ops_special_dtype = { + 'randperm': torch.int64, # randperm只支持整数类型(int16/int32/int64),使用int64 + } + + # 二元算子:需要两个相同shape的输入 + binary_ops = { + 'add', 'sub', 'mul', 'div', 'pow', 'maximum', 'minimum', 'eq', 'ne', 'lt', 'le', 'gt', 'ge', 'remainder', + 'logical_and', 'logical_or', 'logical_xor', 'fmod', 'atan2', 'hypot', 'rsub', # rsub是二元算子(reverse subtract) + 'isclose', # isclose是二元算子,需要两个输入和rtol/atol参数 + } + + # 三元算子:需要三个输入 + ternary_ops = { + 'lerp', # lerp(input, end, weight) 需要三个输入,weight可以是标量或张量 + } + + # 一元逻辑算子 + unary_logical_ops = { + 'logical_not', # logical_not是一元算子 + } + + # 矩阵运算:需要特定shape + # mm: (M, K) x (K, N) -> (M, N) + # bmm: (B, M, K) x (B, K, N) -> (B, M, N) + # addmm: bias + (M, K) x (K, N) -> (M, N) + # mv: (M, N) x (N,) -> (M,) + matrix_ops = { + 'mm': lambda s: (s[0], s[0]), # 返回 (M, N),内部会创建 (M, K) 和 (K, N) + 'bmm': lambda s: (8, s[0], s[0]), # 返回 batch size + 'addmm': lambda s: (s[0], s[0]), # 返回 (M, N) + 'mv': lambda s: (s[0], s[0]), # 返回 (M, N) + 'matmul': lambda s: (s[0], s[0]), # 返回 (M, N) + } + + # 需要特殊参数的算子 + special_ops = { + 'softmax': {'dim': -1}, 'log_softmax': {'dim': -1}, 'dropout': {'p': 0.5, 'train': + True}, # dropout使用train参数,不是training + 'elu': {'alpha': 1.0}, # elu需要alpha参数 + 'flip': {'dims': (0, )}, # flip需要dims参数,默认在dim=0上翻转 + 'gelu': {}, 'silu': {}, 'relu': {}, 'sigmoid': {}, 'tanh': {}, 'exp': {}, 'log': {}, 'sqrt': {}, 'rsqrt': {}, + 'abs': {}, 'neg': {}, 'reciprocal': {}, 'cos': {}, 'sin': {}, 'erf': {}, 'angle': {}, # 一元算子 + 'glu': {}, # 一元算子 + 'log_sigmoid': {}, # 一元算子 + 'isfinite': {}, # 一元算子 + 'isinf': {}, # 一元算子 + 'isnan': {}, # 一元算子 + 'nan_to_num': {}, # 一元算子 + 'logical_not': {}, # 一元逻辑算子 + 'threshold': {'threshold': 0.5, 'value': 0.0}, # threshold需要threshold和value参数 + 'triu': {'diagonal': 0}, # triu需要diagonal参数 + 'sort': {'dim': -1}, # sort需要dim参数 + 'isclose': {'rtol': 1e-5, 'atol': 1e-8}, # isclose需要rtol和atol参数 + } + + # 需要多个输入的算子 + multi_input_ops = { + 'cat': lambda s: ([torch.randn(s, dtype=DTYPE, device=flag_gems.device) for _ in range(3)], {'dim': 0}), + 'stack': lambda s: ([torch.randn(s, dtype=DTYPE, device=flag_gems.device) for _ in range(3)], {'dim': 0}), + 'hstack': lambda s: ([torch.randn(s, dtype=DTYPE, device=flag_gems.device) for _ in range(3)], {}), + 'vstack': lambda s: ([torch.randn(s, dtype=DTYPE, device=flag_gems.device) for _ in range(3)], {}), + } + + # 构造函数算子(不需要输入) + constructor_ops = { + 'ones', 'zeros', 'eye', 'rand', 'randn', 'full', 'empty', 'ones_like', 'zeros_like', 'rand_like', 'randn_like', + 'full_like', 'empty_like', 'normal', # normal需要mean和std参数,但可以设置默认值 + 'randperm', # randperm需要n参数 + } + + # 需要特殊参数的构造函数算子 + special_constructor_ops = {'arange', 'linspace'} + + # 根据算子类型设置配置 + if op_name_lower in reduction_ops: + config['type'] = 'reduction' + # 默认不带dim参数,测试全局规约性能(与原始测试用例 test_accuracy_sum_without_dim 一致) + # vector_norm需要ord参数,默认使用2 + if op_name_lower == 'vector_norm': + config['extra_args'] = {'ord': 2} + else: + config['extra_args'] = {} + elif op_name_lower in cumulative_ops: + config['type'] = 'cumulative' + # 累积算子需要dim参数,默认使用最后一个维度 + config['extra_args'] = {'dim': -1} + elif op_name_lower in bitwise_ops: + config['type'] = 'bitwise' + # 位运算需要整数类型,使用int16 + config['dtype'] = torch.int16 + elif op_name_lower in binary_ops_special_dtype: + config['type'] = 'binary' + # 需要特殊数据类型的二元算子 + config['dtype'] = binary_ops_special_dtype[op_name_lower] + elif op_name_lower in ternary_ops: + config['type'] = 'ternary' + # 三元算子需要三个输入 + elif op_name_lower in binary_ops: + config['type'] = 'binary' + # 如果isclose在special_ops中,需要合并extra_args + if op_name_lower == 'isclose': + config['extra_args'] = special_ops.get('isclose', {}) + elif op_name_lower in unary_logical_ops: + config['type'] = 'unary' + # 一元逻辑算子使用special_ops中的配置 + if op_name_lower in special_ops: + config['extra_args'] = special_ops[op_name_lower] + elif op_name_lower in matrix_ops: + config['type'] = 'matrix' + config['shapes'] = [matrix_ops[op_name_lower](s) for s in BASE_SHAPES] + elif op_name_lower in special_ops: + config['type'] = 'unary' + config['extra_args'] = special_ops[op_name_lower] + elif op_name_lower in multi_input_ops: + config['type'] = 'multi_input' + config['call_func'] = multi_input_ops[op_name_lower] + elif op_name_lower in special_constructor_ops: + config['type'] = 'special_constructor' + # arange和linspace需要特殊参数,不使用shape + config['shapes'] = BASE_SHAPES + elif op_name_lower in constructor_ops: + config['type'] = 'constructor' + # 构造函数使用shape作为输出shape + config['shapes'] = BASE_SHAPES + # 如果构造函数需要特殊数据类型,设置它 + if op_name_lower in constructor_ops_special_dtype: + config['dtype'] = constructor_ops_special_dtype[op_name_lower] + + return config + + +def create_test_inputs(op_name, shape, config): + """ + 根据算子类型和配置创建测试输入 + + Args: + op_name: 算子名称 + shape: 输入shape + config: 算子配置字典 + + Returns: + tensor或tuple: 测试输入,根据算子类型可能是单个tensor或tuple + """ + # device = flag_gems.device + device = "cpu" + # 获取数据类型,如果配置中指定了则使用配置的,否则使用默认的DTYPE + dtype = config.get('dtype', DTYPE) + + if config['type'] == 'bitwise': + # 位运算需要整数类型,使用randint生成 + if op_name == 'bitwise_not': + # 一元位运算 + return torch.randint(low=-0x7FFF, high=0x7FFF, size=shape, dtype=dtype, device=device).to(flag_gems.device) + else: + # 二元位运算 + return (torch.randint(low=-0x7FFF, high=0x7FFF, size=shape, dtype=dtype, + device=device).to(flag_gems.device), + torch.randint(low=-0x7FFF, high=0x7FFF, size=shape, dtype=dtype, + device=device).to(flag_gems.device)) + elif config['type'] == 'ternary': + # 三元算子需要三个输入 + if op_name == 'lerp': + # lerp(input, end, weight) - weight可以是标量或张量,这里使用标量 + return (torch.randn(shape, dtype=dtype, device=device).to(flag_gems.device), + torch.randn(shape, dtype=dtype, device=device).to(flag_gems.device), 0.5 # weight作为标量 + ) + else: + # 其他三元算子 + return (torch.randn(shape, dtype=dtype, device=device).to(flag_gems.device), + torch.randn(shape, dtype=dtype, + device=device).to(flag_gems.device), torch.randn(shape, dtype=dtype, + device=device).to(flag_gems.device)) + elif config['type'] == 'binary': + return (torch.randn(shape, dtype=dtype, + device=device).to(flag_gems.device), torch.randn(shape, dtype=dtype, + device=device).to(flag_gems.device)) + elif config['type'] == 'matrix': + if op_name == 'mm': + # mm: (M, K) x (K, N) -> (M, N) + M, N = shape + K = N # 使用N作为K + return (torch.randn((M, K), dtype=DTYPE, + device=device).to(flag_gems.device), torch.randn((K, N), dtype=DTYPE, + device=device).to(flag_gems.device)) + elif op_name == 'bmm': + # bmm: (B, M, K) x (B, K, N) -> (B, M, N) + B, M, N = shape + K = N + return (torch.randn((B, M, K), dtype=DTYPE, + device=device).to(flag_gems.device), torch.randn((B, K, N), dtype=DTYPE, + device=device).to(flag_gems.device)) + elif op_name == 'addmm': + # addmm: bias + (M, K) x (K, N) -> (M, N) + M, N = shape + K = N + return (torch.randn((M, ), dtype=DTYPE, device=device).to(flag_gems.device), # bias + torch.randn((M, K), dtype=DTYPE, device=device).to(flag_gems.device), # mat1 + torch.randn((K, N), dtype=DTYPE, device=device).to(flag_gems.device) # mat2 + ) + elif op_name == 'mv': + # mv: (M, N) x (N,) -> (M,) + M, N = shape + return (torch.randn((M, N), dtype=DTYPE, + device=device).to(flag_gems.device), torch.randn((N, ), dtype=DTYPE, + device=device).to(flag_gems.device)) + elif op_name == 'matmul': + # matmul: 同mm + M, N = shape + K = N + return (torch.randn((M, K), dtype=DTYPE, + device=device).to(flag_gems.device), torch.randn((K, N), dtype=DTYPE, + device=device).to(flag_gems.device)) + elif config['type'] == 'multi_input': + if config['call_func']: + inputs, kwargs = config['call_func'](shape) + return (inputs, kwargs) + elif config['type'] == 'constructor': + # 构造函数:shape作为输出shape参数 + return shape + elif config['type'] == 'special_constructor': + # arange和linspace需要特殊参数,返回None表示需要特殊处理 + return None + elif config['type'] == 'special_fixed_config': + # 特殊固定配置算子,根据算子类型创建相应的输入 + return _create_special_fixed_config_inputs(op_name, shape, config, dtype, device=device) + + # 默认:一元算子或规约算子 + dtype = config.get('dtype', DTYPE) + return torch.randn(shape, dtype=dtype, device=device).to(flag_gems.device) + + +def _create_special_fixed_config_inputs(op_name, shape, config, dtype, device): + """为特殊固定配置算子创建测试输入""" + if op_name == 'batch_norm': + # batch_norm: input (N, C, H, W), weight (C,), bias (C,), running_mean (C,), running_var (C,) + N, C, H, W = shape + return (torch.randn(shape, dtype=dtype, device=device).to(flag_gems.device), # input + torch.randn((C, ), dtype=dtype, device=device).to(flag_gems.device), # weight + torch.randn((C, ), dtype=dtype, device=device).to(flag_gems.device), # bias + torch.randn((C, ), dtype=dtype, device=device).to(flag_gems.device), # running_mean + torch.randn((C, ), dtype=dtype, device=device).to(flag_gems.device), # running_var + ) + elif op_name == 'layer_norm': + # layer_norm: input (N, C, H, W), normalized_shape (C, H, W), weight (C*H*W,), bias (C*H*W,) + N, C, H, W = shape + normalized_shape = (C, H, W) + norm_size = C * H * W + return (torch.randn(shape, dtype=dtype, device=device).to(flag_gems.device), # input + normalized_shape, torch.randn((norm_size, ), dtype=dtype, device=device).to(flag_gems.device), # weight + torch.randn((norm_size, ), dtype=dtype, device=device).to(flag_gems.device), # bias + ) + elif op_name == 'group_norm': + # group_norm: input (N, C, H, W), num_groups, weight (C,), bias (C,) + N, C, H, W = shape + num_groups = config['extra_args'].get('num_groups', 8) + return (torch.randn(shape, dtype=dtype, device=device).to(flag_gems.device), # input + num_groups, torch.randn((C, ), dtype=dtype, device=device).to(flag_gems.device), # weight + torch.randn((C, ), dtype=dtype, device=device).to(flag_gems.device), # bias + ) + elif op_name == 'rms_norm': + # rms_norm: input (N, C), weight (C,) + N, C = shape + return (torch.randn(shape, dtype=dtype, device=device).to(flag_gems.device), # input + torch.randn((C, ), dtype=dtype, device=device).to(flag_gems.device), # weight + ) + elif op_name == 'conv1d': + # conv1d: input (N, C, L), weight (C_out, C, K), bias (C_out,) + N, C, L = shape + C_out = C # 输出通道数等于输入通道数 + K = 3 # 卷积核大小 + return (torch.randn(shape, dtype=dtype, device=device).to(flag_gems.device), # input + torch.randn((C_out, C, K), dtype=dtype, device=device).to(flag_gems.device), # weight + torch.randn((C_out, ), dtype=dtype, device=device).to(flag_gems.device), # bias (可选) + ) + elif op_name == 'conv2d': + # conv2d: input (N, C, H, W), weight (C_out, C, K, K), bias (C_out,) + N, C, H, W = shape + C_out = C # 输出通道数等于输入通道数 + K = 3 # 卷积核大小 + return (torch.randn(shape, dtype=dtype, device=device).to(flag_gems.device), # input + torch.randn((C_out, C, K, K), dtype=dtype, device=device).to(flag_gems.device), # weight + torch.randn((C_out, ), dtype=dtype, device=device).to(flag_gems.device), # bias (可选) + ) + elif op_name == 'conv_depthwise2d': + # conv_depthwise2d: input (N, C, H, W), weight (C, 1, K, K), bias (C,) + N, C, H, W = shape + K = 3 # 卷积核大小 + return (torch.randn(shape, dtype=dtype, device=device).to(flag_gems.device), # input + torch.randn((C, 1, K, K), dtype=dtype, device=device).to(flag_gems.device), # weight + torch.randn((C, ), dtype=dtype, device=device).to(flag_gems.device), # bias (可选) + ) + elif op_name == 'embedding': + # embedding: weight (num_embeddings, embedding_dim), indices (Batch, M) + # 根据测试用例,indices应该是较小的2D shape,如(Batch, M),而不是直接使用shape + num_embeddings = config['extra_args'].get('num_embeddings', 4096) + embedding_dim = config['extra_args'].get('embedding_dim', 512) + # 限制indices的shape,避免内存溢出 + # 使用较小的batch和sequence length,参考测试用例中的(Batch, M)格式 + batch_size = min(shape[0] if len(shape) > 0 else 4, MAX_BATCH_SIZE) + seq_len = min(shape[-1] if len(shape) > 1 else 128, MAX_SEQ_LEN) + indices_shape = (batch_size, seq_len) + return (torch.randn((num_embeddings, embedding_dim), dtype=dtype, device=device).to(flag_gems.device), # weight + torch.randint(0, num_embeddings, size=indices_shape, dtype=torch.long, + device=device).to(flag_gems.device), # indices + ) + elif op_name == 'gather': + # gather: input, dim, index -> output + dim = config['extra_args'].get('dim', 0) + index_shape = list(shape) + index_shape[dim] = min(shape[dim], MAX_INDEX_SIZE) + return (torch.randn(shape, dtype=dtype, device=device).to(flag_gems.device), # input + dim, torch.randint(0, shape[dim], size=tuple(index_shape), dtype=torch.long, + device=device).to(flag_gems.device), # index + ) + elif op_name == 'index_select': + # index_select: input, dim, index -> output + dim = config['extra_args'].get('dim', 0) + index_size = min(shape[dim], MAX_INDEX_SIZE) + return (torch.randn(shape, dtype=dtype, device=device).to(flag_gems.device), # input + dim, torch.randint(0, shape[dim], size=(index_size, ), dtype=torch.long, + device=device).to(flag_gems.device), # index + ) + elif op_name == 'scatter': + # scatter: input, dim, index, src -> output + dim = config['extra_args'].get('dim', 0) + index_shape = list(shape) + index_shape[dim] = min(shape[dim], 64) + return (torch.randn(shape, dtype=dtype, device=device).to(flag_gems.device), # input + dim, torch.randint(0, shape[dim], size=tuple(index_shape), dtype=torch.long, + device=device).to(flag_gems.device), # index + torch.randn(tuple(index_shape), dtype=dtype, device=device).to(flag_gems.device), # src + ) + elif op_name == 'slice_scatter': + # slice_scatter: input, dim, src, start, end, step -> output + # slice_scatter(input, dim=dim, src=src, start=start, end=end, step=step) + dim = config['extra_args'].get('dim', 0) + start = config['extra_args'].get('start', 0) + end = config['extra_args'].get('end', shape[dim]) + step = config['extra_args'].get('step', 1) + # 计算 src 的形状 + size = shape[dim] + start = start % size + end = end % (size + 1) + if end < start: + end, start = start, end + elif end == start: + end = size + src_size = (end - start + step - 1) // step + src_shape = list(shape) + src_shape[dim] = src_size + return ( + torch.randn(shape, dtype=dtype, device=device).to(flag_gems.device), # input + dim, + torch.randn(tuple(src_shape), dtype=dtype, device=device).to(flag_gems.device), # src + start, + end, + step, + ) + elif op_name == 'index': + # index: input, indices -> output + # indices 是一个列表,包含多个索引张量,每个对应输入的一个维度 + # 为了简化,我们为每个维度创建一个索引张量 + indices = [] + for i, dim_size in enumerate(shape): + # 为每个维度创建一个较小的索引张量 + index_size = min(dim_size, MAX_INDEX_SIZE) + indices.append( + torch.randint(0, dim_size, size=(index_size, ), dtype=torch.long, device=device).to(flag_gems.device)) + return (torch.randn(shape, dtype=dtype, device=device).to(flag_gems.device), # input + indices, # indices list + ) + elif op_name == 'index_add': + # index_add: input, dim, index, source -> output + # index_add(input, dim, index, source, alpha=1) + dim = config['extra_args'].get('dim', 0) + index_max = shape[dim] + index_len = min(index_max, MAX_INDEX_SIZE) + index = torch.randperm(index_len, device=device).to(flag_gems.device) # 1D索引张量 + src_shape = list(shape) + src_shape[dim] = index_len + source = torch.randn(tuple(src_shape), dtype=dtype, device=device).to(flag_gems.device) + return (torch.randn(shape, dtype=dtype, device=device).to(flag_gems.device), # input + dim, index, # index tensor + source, # source tensor + ) + elif op_name == 'index_put': + # index_put: input, indices, values -> output + # index_put(input, indices, values, accumulate=False) + # indices 是一个列表,包含多个索引张量,每个对应输入的一个维度 + indices = [] + for i, dim_size in enumerate(shape): + # 为每个维度创建一个较小的索引张量 + index_size = min(dim_size, MAX_INDEX_SIZE) + indices.append( + torch.randint(0, dim_size, size=(index_size, ), dtype=torch.long, device=device).to(flag_gems.device)) + # values 的形状需要与 indices 广播后的形状匹配 + # 简化处理:使用第一个索引张量的形状作为 values 的形状 + if len(indices) > 0: + values_shape = indices[0].shape + else: + values_shape = (MAX_INDEX_SIZE, ) + values = torch.randn(values_shape, dtype=dtype, device=device) + return (torch.randn(shape, dtype=dtype, device=device), # input + indices, # indices list + values, # values tensor + ) + elif op_name == 'topk': + # topk: input, k, dim -> (values, indices) + k = config['extra_args'].get('k', 10) + dim = config['extra_args'].get('dim', -1) + return torch.randn(shape, dtype=dtype, device=device).to(flag_gems.device) # input + elif op_name == 'multinomial': + # multinomial: input (probs), num_samples -> indices + # input需要是概率分布,每行和为1 + probs = torch.rand(shape, dtype=dtype, device=device).to(flag_gems.device) + probs = probs / probs.sum(dim=-1, keepdim=True) # 归一化为概率分布 + return probs + elif op_name == 'quantile': + # quantile: input, q, dim -> output + return torch.randn(shape, dtype=dtype, device=device).to(flag_gems.device) # input + elif op_name == 'where': + # where: condition, x, y -> output + condition = torch.rand(shape, dtype=dtype, device=device).to(flag_gems.device) > 0.5 + return (condition, torch.randn(shape, dtype=dtype, device=device).to(flag_gems.device), # x + torch.randn(shape, dtype=dtype, device=device).to(flag_gems.device), # y + ) + elif op_name == 'masked_fill': + # masked_fill: input, mask, value -> output + value = config['extra_args'].get('value', 0.0) + return ( + torch.randn(shape, dtype=dtype, device=device).to(flag_gems.device), # input + torch.rand(shape, dtype=dtype, device=device).to(flag_gems.device) > 0.5, # mask + value, + ) + elif op_name == 'masked_select': + # masked_select: input, mask -> output (1D) + return (torch.randn(shape, dtype=dtype, device=device).to(flag_gems.device), # input + torch.rand(shape, dtype=dtype, device=device).to(flag_gems.device) > 0.5, # mask + ) + elif op_name == 'select': + # select: input, dim, index -> output + dim = config['extra_args'].get('dim', 0) + index = config['extra_args'].get('index', 0) + return ( + torch.randn(shape, dtype=dtype, device=device).to(flag_gems.device), # input + dim, + index, + ) + elif op_name == 'diagonal': + # diagonal: input, offset, dim1, dim2 -> output + return torch.randn(shape, dtype=dtype, device=device).to(flag_gems.device) # input + elif op_name == 'diag': + # diag: input, diagonal -> output + return torch.randn(shape, dtype=dtype, device=device).to(flag_gems.device) # input + elif op_name == 'diag_embed': + # diag_embed: input (1D), offset, dim1, dim2 -> output + return torch.randn(shape, dtype=dtype, device=device).to(flag_gems.device) # input + elif op_name == 'pad': + # pad: input, pad -> output + return torch.randn(shape, dtype=dtype, device=device).to(flag_gems.device) # input + elif op_name == 'tile': + # tile: input, dims -> output + return torch.randn(shape, dtype=dtype, device=device).to(flag_gems.device) # input + elif op_name == 'repeat_interleave': + # repeat_interleave: input, repeats, dim -> output + return torch.randn(shape, dtype=dtype, device=device).to(flag_gems.device) # input + elif op_name in ['kron', 'outer', 'vdot', 'dot', 'polar', 'atan2', 'hypot', 'fmod']: + # 二元算子:需要两个输入 + if op_name == 'kron': + # kron 的输出大小是输入大小的乘积,需要限制输入大小避免内存溢出 + # 参考测试用例中的 KRON_SHAPES,使用较小的形状 + if len(shape) == 1: + # 1D: 限制大小避免OOM + kron_shape1 = (min(shape[0], MAX_KRON_1D_SIZE), ) + kron_shape2 = (min(shape[0], MAX_KRON_1D_SIZE), ) + elif len(shape) == 2: + # 2D: 限制大小避免OOM + kron_shape1 = (min(shape[0], MAX_KRON_2D_SIZE), min(shape[1], MAX_KRON_2D_SIZE)) + kron_shape2 = (min(shape[0], MAX_KRON_2D_SIZE), min(shape[1], MAX_KRON_2D_SIZE)) + else: + # 高维: 限制每个维度大小避免OOM + kron_shape1 = tuple(min(s, MAX_KRON_ND_SIZE) for s in shape) + kron_shape2 = tuple(min(s, MAX_KRON_ND_SIZE) for s in shape) + return ( + torch.randn(kron_shape1, dtype=dtype, device=device).to(flag_gems.device), + torch.randn(kron_shape2, dtype=dtype, device=device).to(flag_gems.device), + ) + elif op_name in ['outer', 'vdot', 'dot']: + # 这些算子需要1D输入 + return ( + torch.randn(shape, dtype=dtype, device=device).to(flag_gems.device), + torch.randn(shape, dtype=dtype, device=device).to(flag_gems.device), + ) + else: + return ( + torch.randn(shape, dtype=dtype, device=device).to(flag_gems.device), + torch.randn(shape, dtype=dtype, device=device).to(flag_gems.device), + ) + elif op_name == 'isin': + # isin: elements, test_elements -> output + return (torch.randn(shape, dtype=dtype, device=device).to(flag_gems.device), # elements + torch.randn((min(MAX_ISIN_TEST_ELEMENTS, shape[0]), ), dtype=dtype, + device=device).to(flag_gems.device), # test_elements + ) + elif op_name == 'fill': + # fill: input, value -> output + value = config['extra_args'].get('value', 1.0) + return ( + torch.randn(shape, dtype=dtype, device=device).to(flag_gems.device), # input + value, + ) + elif op_name == 'cross_entropy_loss': + # cross_entropy_loss: input (N, C), target (N,) + N, C = shape + return (torch.randn(shape, dtype=dtype, device=device).to(flag_gems.device), # input + torch.randint(0, C, size=(N, ), dtype=torch.long, device=device).to(flag_gems.device), # target + ) + elif op_name == 'nll_loss': + # nll_loss: input (N, C), target (N,) + N, C = shape + return (torch.randn(shape, dtype=dtype, device=device).to(flag_gems.device), # input + torch.randint(0, C, size=(N, ), dtype=torch.long, device=device).to(flag_gems.device), # target + ) + elif op_name == 'mse_loss': + # mse_loss: input, target -> loss + return (torch.randn(shape, dtype=dtype, device=device).to(flag_gems.device), # input + torch.randn(shape, dtype=dtype, device=device).to(flag_gems.device), # target + ) + else: + # 默认:一元算子 + return torch.randn(shape, dtype=dtype, device=device).to(flag_gems.device) + + +def call_op(op_name, inputs, config, shape=None): + """ + 调用算子 + + 注意:由于性能测试需要多次调用,使用全局启用的flag_gems.enable() + 而不是每次调用use_gems(),避免重复注册错误。 + + Args: + op_name: 算子名称 + inputs: 输入tensor或tuple + config: 算子配置字典 + shape: 可选的shape参数(用于构造函数) + + Returns: + tensor: 算子输出 + """ + # 某些算子需要使用torch.nn.functional + nn_functional_ops = { + 'dropout': torch.nn.functional.dropout, 'elu': torch.nn.functional.elu, 'relu': torch.nn.functional.relu, + 'gelu': torch.nn.functional.gelu, 'silu': torch.nn.functional.silu, 'sigmoid': torch.nn.functional.sigmoid, + 'tanh': torch.nn.functional.tanh, 'log_sigmoid': + torch.nn.functional.logsigmoid, # log_sigmoid在torch.nn.functional中是logsigmoid + } + + if op_name in nn_functional_ops: + op_func = nn_functional_ops[op_name] + else: + # 特殊处理:vector_norm在torch.linalg中 + if op_name == 'vector_norm': + try: + op_func = torch.linalg.vector_norm + except AttributeError: + # 如果torch.linalg不存在,尝试torch.vector_norm + op_func = getattr(torch, 'vector_norm', None) + else: + op_func = getattr(torch, op_name, None) + if op_func is None: + # 尝试下划线版本 + op_func = getattr(torch, op_name + '_', None) + if op_func is None: + # 尝试torch.nn.functional + op_func = getattr(torch.nn.functional, op_name, None) + if op_func is None: + # 尝试torch.linalg(对于其他linalg算子) + if hasattr(torch, 'linalg'): + op_func = getattr(torch.linalg, op_name, None) + + # 对于special_fixed_config类型的算子,op_func可能为None(在_call_special_fixed_config_op中直接调用) + if op_func is None and config.get('type') != 'special_fixed_config': + raise ValueError(f"未找到算子: {op_name}") + + extra_args = config.get('extra_args', {}) + + # 直接调用,因为flag_gems已经全局启用 + return _call_op_impl(op_func, op_name, inputs, config, extra_args, shape) + + +def _call_special_fixed_config_op(op_func, op_name, inputs, config, extra_args, shape=None): + """调用特殊固定配置算子""" + if op_name == 'batch_norm': + # batch_norm(input, weight, bias, running_mean, running_var, ...) + return torch.nn.functional.batch_norm(inputs[0], inputs[1], inputs[2], inputs[3], inputs[4], **extra_args) + elif op_name == 'layer_norm': + # layer_norm(input, normalized_shape, weight, bias, ...) + return torch.layer_norm(inputs[0], inputs[1], weight=inputs[2], bias=inputs[3], **extra_args) + elif op_name == 'group_norm': + # group_norm(input, num_groups, weight, bias, ...) + # num_groups 已经通过位置参数传递,不应该再通过 extra_args 传递 + filtered_args = {k: v for k, v in extra_args.items() if k != 'num_groups'} + return torch.nn.functional.group_norm(inputs[0], inputs[1], weight=inputs[2], bias=inputs[3], **filtered_args) + elif op_name == 'rms_norm': + # rms_norm(input, weight, ...) + return torch.nn.functional.layer_norm(inputs[0], (inputs[0].shape[-1], ), weight=inputs[1], **extra_args) + elif op_name == 'conv1d': + # conv1d(input, weight, bias=None, ...) + return torch.nn.functional.conv1d(inputs[0], inputs[1], bias=inputs[2], **extra_args) + elif op_name == 'conv2d': + # conv2d(input, weight, bias=None, ...) + return torch.nn.functional.conv2d(inputs[0], inputs[1], bias=inputs[2], **extra_args) + elif op_name == 'conv_depthwise2d': + # conv_depthwise2d(input, weight, bias=None, ...) + return torch.nn.functional.conv2d(inputs[0], inputs[1], bias=inputs[2], groups=inputs[0].shape[1], **extra_args) + elif op_name == 'embedding': + # embedding(indices, weight, ...) + # num_embeddings 和 embedding_dim 只是用于创建 weight 的参数,不应该传递给 embedding 函数 + filtered_args = {k: v for k, v in extra_args.items() if k not in ['num_embeddings', 'embedding_dim']} + return torch.nn.functional.embedding(inputs[1], inputs[0], **filtered_args) + elif op_name == 'gather': + # gather(input, dim, index, ...) + # dim 已经通过位置参数传递,不应该再通过 extra_args 传递 + filtered_args = {k: v for k, v in extra_args.items() if k != 'dim'} + return torch.gather(inputs[0], inputs[1], inputs[2], **filtered_args) + elif op_name == 'index_select': + # index_select(input, dim, index, ...) + # dim 已经通过位置参数传递,不应该再通过 extra_args 传递 + filtered_args = {k: v for k, v in extra_args.items() if k != 'dim'} + return torch.index_select(inputs[0], inputs[1], inputs[2], **filtered_args) + elif op_name == 'scatter': + # scatter(input, dim, index, src, ...) + # dim 已经通过位置参数传递,不应该再通过 extra_args 传递 + filtered_args = {k: v for k, v in extra_args.items() if k != 'dim'} + return torch.scatter(inputs[0], inputs[1], inputs[2], inputs[3], **filtered_args) + elif op_name == 'slice_scatter': + # slice_scatter(input, dim=dim, src=src, start=start, end=end, step=step) + # dim, start, end, step 已经通过位置参数传递,不应该再通过 extra_args 传递 + filtered_args = {k: v for k, v in extra_args.items() if k not in ['dim', 'start', 'end', 'step']} + return torch.slice_scatter(inputs[0], dim=inputs[1], src=inputs[2], start=inputs[3], end=inputs[4], + step=inputs[5], **filtered_args) + elif op_name == 'index': + # index(input, indices) -> output + # indices 是一个列表,包含多个索引张量 + return torch.ops.aten.index(inputs[0], inputs[1]) + elif op_name == 'index_add': + # index_add(input, dim, index, source, alpha=1) -> output + # dim 已经通过位置参数传递,不应该再通过 extra_args 传递 + filtered_args = {k: v for k, v in extra_args.items() if k != 'dim'} + return torch.index_add(inputs[0], inputs[1], inputs[2], inputs[3], **filtered_args) + elif op_name == 'index_put': + # index_put(input, indices, values, accumulate=False) -> output + # indices 是一个列表,包含多个索引张量 + return torch.index_put(inputs[0], inputs[1], inputs[2], **extra_args) + elif op_name == 'topk': + # topk(input, k, dim, ...) -> (values, indices) + k = extra_args.get('k', 10) + dim = extra_args.get('dim', -1) + result = torch.topk(inputs, k, dim=dim) + return result.values # 只返回values用于性能测试 + elif op_name == 'multinomial': + # multinomial(input, num_samples, ...) + num_samples = extra_args.get('num_samples', 10) + return torch.multinomial(inputs, num_samples, **{k: v for k, v in extra_args.items() if k != 'num_samples'}) + elif op_name == 'quantile': + # quantile(input, q, dim, ...) + q = extra_args.get('q', 0.5) + dim = extra_args.get('dim', -1) + return torch.quantile(inputs, q, dim=dim, **{k: v for k, v in extra_args.items() if k not in ['q', 'dim']}) + elif op_name == 'where': + # where(condition, x, y, ...) + return torch.where(inputs[0], inputs[1], inputs[2], **extra_args) + elif op_name == 'masked_fill': + # masked_fill(input, mask, value, ...) + # value 已经通过位置参数传递,不应该再通过 extra_args 传递 + filtered_args = {k: v for k, v in extra_args.items() if k != 'value'} + return torch.masked_fill(inputs[0], inputs[1], inputs[2], **filtered_args) + elif op_name == 'masked_select': + # masked_select(input, mask, ...) + return torch.masked_select(inputs[0], inputs[1], **extra_args) + elif op_name == 'select': + # select(input, dim, index, ...) + # dim 和 index 已经通过位置参数传递,不应该再通过 extra_args 传递 + filtered_args = {k: v for k, v in extra_args.items() if k not in ['dim', 'index']} + return torch.select(inputs[0], inputs[1], inputs[2], **filtered_args) + elif op_name == 'diagonal': + # diagonal(input, offset, dim1, dim2, ...) + offset = extra_args.get('offset', 0) + dim1 = extra_args.get('dim1', 0) + dim2 = extra_args.get('dim2', 1) + return torch.diagonal(inputs, offset=offset, dim1=dim1, dim2=dim2) + elif op_name == 'diag': + # diag(input, diagonal, ...) + diagonal = extra_args.get('diagonal', 0) + return torch.diag(inputs, diagonal=diagonal) + elif op_name == 'diag_embed': + # diag_embed(input, offset, dim1, dim2, ...) + offset = extra_args.get('offset', 0) + dim1 = extra_args.get('dim1', -2) + dim2 = extra_args.get('dim2', -1) + return torch.diag_embed(inputs, offset=offset, dim1=dim1, dim2=dim2) + elif op_name == 'pad': + # pad(input, pad, mode, value, ...) + pad = extra_args.get('pad', (1, 1, 1, 1)) + mode = extra_args.get('mode', 'constant') + value = extra_args.get('value', 0.0) + return torch.nn.functional.pad(inputs, pad, mode=mode, value=value) + elif op_name == 'tile': + # tile(input, dims, ...) + dims = extra_args.get('dims', (2, 2)) + return torch.tile(inputs, dims) + elif op_name == 'repeat_interleave': + # repeat_interleave(input, repeats, dim, ...) + repeats = extra_args.get('repeats', 2) + dim = extra_args.get('dim', -1) + return torch.repeat_interleave(inputs, repeats, dim=dim) + elif op_name == 'kron': + # kron(input1, input2, ...) + return torch.kron(inputs[0], inputs[1], **extra_args) + elif op_name == 'outer': + # outer(input1, input2, ...) + return torch.outer(inputs[0], inputs[1], **extra_args) + elif op_name == 'vdot': + # vdot(input1, input2, ...) + return torch.vdot(inputs[0], inputs[1], **extra_args) + elif op_name == 'dot': + # dot(input1, input2, ...) + return torch.dot(inputs[0], inputs[1], **extra_args) + elif op_name == 'polar': + # polar(abs, angle, ...) + return torch.polar(inputs[0], inputs[1], **extra_args) + elif op_name == 'atan2': + # atan2(input1, input2, ...) + return torch.atan2(inputs[0], inputs[1], **extra_args) + elif op_name == 'hypot': + # hypot(input1, input2, ...) + return torch.hypot(inputs[0], inputs[1], **extra_args) + elif op_name == 'fmod': + # fmod(input1, input2, ...) + return torch.fmod(inputs[0], inputs[1], **extra_args) + elif op_name == 'isin': + # isin(elements, test_elements, ...) + return torch.isin(inputs[0], inputs[1], **extra_args) + elif op_name == 'fill': + # fill(input, value, ...) - 注意:fill是inplace操作 + result = inputs[0].clone() + result.fill_(inputs[1]) + return result + elif op_name == 'cross_entropy_loss': + # cross_entropy_loss(input, target, ...) + return torch.nn.functional.cross_entropy(inputs[0], inputs[1], **extra_args) + elif op_name == 'nll_loss': + # nll_loss(input, target, ...) + return torch.nn.functional.nll_loss(inputs[0], inputs[1], **extra_args) + elif op_name == 'mse_loss': + # mse_loss(input, target, ...) + return torch.nn.functional.mse_loss(inputs[0], inputs[1], **extra_args) + else: + raise ValueError(f"未知的特殊固定配置算子: {op_name}") + + +def _call_op_impl(op_func, op_name, inputs, config, extra_args, shape=None): + """实际的算子调用实现""" + if config['type'] == 'bitwise': + # 位运算:一元或二元 + if op_name == 'bitwise_not': + return op_func(inputs, **extra_args) + else: + return op_func(inputs[0], inputs[1], **extra_args) + elif config['type'] == 'ternary': + # 三元算子 + if op_name == 'lerp': + # lerp(input, end, weight) - weight可以是标量或张量 + return op_func(inputs[0], inputs[1], weight=inputs[2], **extra_args) + else: + return op_func(inputs[0], inputs[1], inputs[2], **extra_args) + elif config['type'] == 'binary': + return op_func(inputs[0], inputs[1], **extra_args) + elif config['type'] == 'matrix': + if op_name == 'addmm': + # addmm(bias, mat1, mat2, ...) + return op_func(inputs[0], inputs[1], inputs[2], **extra_args) + elif op_name in ['bmm', 'mm', 'matmul']: + return op_func(inputs[0], inputs[1], **extra_args) + elif op_name == 'mv': + return op_func(inputs[0], inputs[1], **extra_args) + elif config['type'] == 'multi_input': + input_list, kwargs = inputs + merged_kwargs = {**extra_args, **kwargs} + return op_func(input_list, **merged_kwargs) + elif config['type'] == 'constructor': + # 构造函数使用shape作为参数 + if op_name == 'full_like': + # full_like(input, fill_value) 需要参考tensor和fill_value + ref_tensor = torch.randn(shape, dtype=DTYPE, device=flag_gems.device) + fill_value = 1.0 + return op_func(ref_tensor, fill_value, dtype=DTYPE, device=flag_gems.device, **extra_args) + elif 'like' in op_name: + # ones_like等需要参考tensor(inputs就是shape,需要创建参考tensor) + ref_tensor = torch.randn(shape, dtype=DTYPE, device=flag_gems.device) + return op_func(ref_tensor, **extra_args) + elif op_name == 'eye': + # eye需要矩阵大小,使用shape的第一个维度 + n = shape[0] if isinstance(shape, tuple) and len(shape) > 0 else shape + m = shape[1] if isinstance(shape, tuple) and len(shape) > 1 else n + return op_func(n, m, dtype=DTYPE, device=flag_gems.device, **extra_args) + elif op_name == 'normal': + # normal(mean, std, size) 需要mean和std参数 + mean = 0.0 + std = 1.0 + return op_func(mean, std, shape, dtype=DTYPE, device=flag_gems.device, **extra_args) + elif op_name == 'randperm': + # randperm(n) 需要n参数,使用shape的第一个维度 + # randperm只支持整数类型(int16/int32/int64),使用配置中的dtype或默认int64 + dtype = config.get('dtype', torch.int64) + n = shape[0] if isinstance(shape, tuple) and len(shape) > 0 else (shape if isinstance(shape, int) else 256) + return op_func(n, dtype=dtype, device=flag_gems.device, **extra_args) + elif op_name == 'full': + # full(size, fill_value) 需要fill_value参数 + fill_value = 1.0 + return op_func(shape, fill_value, dtype=DTYPE, device=flag_gems.device, **extra_args) + else: + # ones, zeros, rand, randn, empty等直接使用shape + return op_func(shape, dtype=DTYPE, device=flag_gems.device, **extra_args) + elif config['type'] == 'special_constructor': + # arange和linspace需要特殊参数 + if op_name == 'arange': + # arange(start, end, step, ...) + # 使用shape的第一个维度作为end值 + end = shape[0] if isinstance(shape, tuple) and len(shape) > 0 else ( + shape if isinstance(shape, int) else 256) + return op_func(0, end, dtype=DTYPE, device=flag_gems.device, **extra_args) + elif op_name == 'linspace': + # linspace(start, end, steps, ...) + # 使用shape的第一个维度作为steps值 + steps = shape[0] if isinstance(shape, tuple) and len(shape) > 0 else ( + shape if isinstance(shape, int) else 256) + return op_func(0.0, 1.0, steps, dtype=DTYPE, device=flag_gems.device, **extra_args) + else: + raise ValueError(f"未知的特殊构造函数算子: {op_name}") + elif config['type'] == 'cumulative': + # 累积算子:cummax和cummin返回命名元组(values, indices),cumsum返回tensor + result = op_func(inputs, **extra_args) + if op_name in ['cummax', 'cummin']: + # 返回values部分用于性能测试 + return result.values + return result + elif config['type'] == 'special_fixed_config': + # 特殊固定配置算子 + return _call_special_fixed_config_op(op_func, op_name, inputs, config, extra_args, shape) + elif op_name == 'dropout': + # dropout需要位置参数:dropout(input, p, train) + # 使用torch.nn.functional.dropout + p = extra_args.get('p', 0.5) + train = extra_args.get('train', True) + return torch.nn.functional.dropout(inputs, p, train) + elif op_name == 'flip': + # flip需要位置参数:flip(input, dims) + dims = extra_args.get('dims', (0, )) + return op_func(inputs, dims) + elif op_name == 'threshold': + # threshold需要位置参数:threshold(input, threshold, value) + threshold = extra_args.get('threshold', 0.5) + value = extra_args.get('value', 0.0) + return op_func(inputs, threshold, value) + elif op_name == 'triu': + # triu需要位置参数:triu(input, diagonal) + diagonal = extra_args.get('diagonal', 0) + return op_func(inputs, diagonal) + elif op_name == 'sort': + # sort需要位置参数:sort(input, dim) + dim = extra_args.get('dim', -1) + result = op_func(inputs, dim=dim) + # sort返回(values, indices),只返回values用于性能测试 + return result.values if hasattr(result, 'values') else result[0] + elif op_name == 'vector_norm': + # vector_norm需要ord参数,dim和keepdim可选 + ord = extra_args.get('ord', 2) + dim = extra_args.get('dim', None) + keepdim = extra_args.get('keepdim', False) + if dim is not None: + return op_func(inputs, ord=ord, dim=dim, keepdim=keepdim) + else: + return op_func(inputs, ord=ord, keepdim=keepdim) + elif op_name == 'isclose': + # isclose需要两个输入和rtol/atol参数 + if config['type'] == 'binary': + rtol = extra_args.get('rtol', 1e-5) + atol = extra_args.get('atol', 1e-8) + return op_func(inputs[0], inputs[1], rtol=rtol, atol=atol) + else: + # 如果只有一个输入,创建第二个输入 + inp2 = torch.randn_like(inputs) + rtol = extra_args.get('rtol', 1e-5) + atol = extra_args.get('atol', 1e-8) + return op_func(inputs, inp2, rtol=rtol, atol=atol) + else: + # 一元算子或规约算子 + return op_func(inputs, **extra_args) + + +def test_op_performance(op_name, shape, config): + """ + 测试单个算子的性能 + + 使用triton的do_bench来测量性能,支持各种GPU设备。 + 注意:使用全局启用的flag_gems.enable(),避免每次调用use_gems()导致的重复注册。 + + Args: + op_name: 算子名称 + shape: 测试shape + config: 算子配置字典 + + Returns: + float: 平均耗时(毫秒) + """ + do_bench = _get_do_bench() + _ensure_flag_gems_enabled() + + try: + # 跳过需要特殊参数的算子 + if config.get('type') == 'skip': + pytest.skip(f"算子 {op_name}: {config.get('reason', '需要特殊参数')}") + + # 创建测试输入 + inputs = create_test_inputs(op_name, shape, config) + + # 定义要测试的函数 + if inputs is None and config['type'] == 'special_constructor': + + def test_fn(): + return call_op(op_name, None, config, shape=shape) + + # special_constructor的shape格式化 + if op_name == 'arange': + end = shape[0] if isinstance(shape, tuple) and len(shape) > 0 else ( + shape if isinstance(shape, int) else 256) + shape_str = f"arange(0, {end})" + elif op_name == 'linspace': + steps = shape[0] if isinstance(shape, tuple) and len(shape) > 0 else ( + shape if isinstance(shape, int) else 256) + shape_str = f"linspace(0, 1, {steps})" + else: + shape_str = str(shape) + else: + + def test_fn(): + return call_op(op_name, inputs, config, shape=shape) + + shape_str = _format_shape_str(op_name, inputs, config, shape) + + # 使用do_bench测量性能(返回中位数,单位:毫秒) + avg_time_ms = do_bench(test_fn, warmup=WARMUP, rep=REPETITION, return_mode="median") + avg_time_us = avg_time_ms * 1000 + elapsed_time_ms = avg_time_ms * REPETITION + + # 存储结果 + _store_result(op_name, shape_str, config, avg_time_ms, avg_time_us, elapsed_time_ms) + + # 运行时打印 + print(f" shape={shape_str:<35} avg_time={avg_time_us:>10.2f} us ({avg_time_ms:>8.4f} ms)", flush=True) + + return avg_time_ms + + except Exception as e: + # 记录失败的测试 + shape_str = str(shape) + _store_result(op_name, shape_str, config, None, None, None, error=str(e)) + print(f" shape={shape_str:<35} ERROR: {str(e)}", flush=True) + pytest.skip(f"算子 {op_name} 测试失败: {e}") + + +# 为了pytest兼容性,创建一个通用的测试函数 +@pytest.mark.parametrize("op_name", []) +def test_op_performance_pytest(op_name): + """pytest版本的测试函数(通过parametrize动态生成)""" + ops = parse_op_list() + if op_name not in ops: + pytest.skip(f"未知算子: {op_name}") + + config = get_op_config(op_name) + # 使用第一个shape作为默认测试 + shape = config['shapes'][0] if config['shapes'] else BASE_SHAPES[0] + test_op_performance(op_name, shape, config) + + +def print_summary_report(): + """ + 打印性能测试总结报告 + + 按算子分组显示所有测试结果,包括shape、配置、耗时等信息。 + """ + if not hasattr(pytest, '_op_perf_results') or not pytest._op_perf_results: + return + + results = pytest._op_perf_results + + # 按算子分组 + op_groups = {} + for r in results: + op = r['op'] + if op not in op_groups: + op_groups[op] = [] + op_groups[op].append(r) + + print("\n" + "=" * 120) + print(" " * 40 + "算子性能测试报告") + print("=" * 120) + + # 打印每个算子的结果 + for op_name in sorted(op_groups.keys()): + op_results = op_groups[op_name] + print(f"\n算子: {op_name.upper()}") + print("=" * 120) + print(f"{'Shape':<35} {'Config':<20} {'Avg Time (us)':<18} {'Avg Time (ms)':<18} {'Status':<15}") + print("=" * 120) + + for r in op_results: + shape_str = r['shape'][:34] if len(r['shape']) > 34 else r['shape'] + config_str = r['config'][:19] if len(r['config']) > 19 else r['config'] + + if r.get('error'): + print(f"{shape_str:<35} {config_str:<20} {'N/A':<18} {'N/A':<18} {'ERROR':<15}") + print(f" Error: {r['error']}") + elif r['avg_time_us'] is not None: + print( + f"{shape_str:<35} {config_str:<20} {r['avg_time_us']:>15.2f} {r['avg_time_ms']:>15.4f} {'OK':<15}" + ) + else: + print(f"{shape_str:<35} {config_str:<20} {'N/A':<18} {'N/A':<18} {'SKIPPED':<15}") + + print("\n" + "=" * 120) + print(f"测试配置:") + print(f" * Warmup: {WARMUP}") + print(f" * Repetition: {REPETITION}") + print(f" * Dtype: {DTYPE}") + _ensure_flag_gems_enabled() + print(f" * Device: {flag_gems.device}") + print(f" * Benchmark Method: triton.do_bench") + print("=" * 120 + "\n") + + +def print_csv_report(filename='performance_report.csv'): + """ + 将性能测试报告输出为CSV格式 + + Args: + filename: CSV文件名,默认为'performance_report.csv' + + Note: + 使用&作为分隔符,避免shape字段中的逗号导致格式混乱。 + """ + import csv + import os + + if not hasattr(pytest, '_op_perf_results') or not pytest._op_perf_results: + print("没有性能测试结果可输出") + return + + results = pytest._op_perf_results + + # CSV文件路径 + csv_path = os.path.join(os.path.dirname(__file__), filename) + + # 写入CSV文件 + with open(csv_path, 'w', newline='', encoding='utf-8') as csvfile: + fieldnames = [ + 'Operator', 'Shape', 'Config', 'Dtype', 'Type', 'Avg Time (us)', 'Avg Time (ms)', 'Elapsed Time (ms)', + 'Status', 'Error' + ] + # 使用&作为分隔符,避免shape字段中的逗号导致格式混乱 + writer = csv.DictWriter(csvfile, fieldnames=fieldnames, delimiter='&', quoting=csv.QUOTE_MINIMAL) + + # 写入表头 + writer.writeheader() + + # 写入数据 + for r in results: + status = 'OK' + if r.get('error'): + status = 'ERROR' + elif r['avg_time_us'] is None: + status = 'SKIPPED' + + # 将所有字段转换为字符串,确保CSV格式正确 + # csv模块会自动为包含逗号的字段添加引号 + writer.writerow({ + 'Operator': + str(r['op']), 'Shape': + str(r['shape']), # 包含逗号的字段会被自动用引号括起来 + 'Config': + str(r['config']), # 包含逗号的字段会被自动用引号括起来 + 'Dtype': + str(r['dtype']), 'Type': + str(r['type']), 'Avg Time (us)': + f"{r['avg_time_us']:.2f}" if r['avg_time_us'] is not None else 'N/A', 'Avg Time (ms)': + f"{r['avg_time_ms']:.4f}" if r['avg_time_ms'] is not None else 'N/A', 'Elapsed Time (ms)': + f"{r['elapsed_time_ms']:.4f}" if r.get('elapsed_time_ms') is not None else 'N/A', 'Status': + str(status), 'Error': + str(r.get('error', '')) + }) + + print(f"\n性能测试报告已保存为CSV格式: {csv_path}") + + +@pytest.hookimpl(trylast=True) +def pytest_sessionfinish(session, exitstatus): + """pytest session结束时打印总结报告并输出CSV""" + print_summary_report() + print_csv_report() + + +if __name__ == "__main__": + # 启用FlagGems(会自动检测可用的GPU设备) + _ensure_flag_gems_enabled() + + # 初始化结果列表 + pytest._op_perf_results = [] + + # 获取所有算子 + ops = parse_op_list() + print(f"找到 {len(ops)} 个算子,开始性能测试...\n") + + # 运行所有测试 + for op_name in ops: + config = get_op_config(op_name) + print(f"测试算子: {op_name} (类型: {config['type']})...") + + # 跳过需要特殊参数的算子 + if config.get('type') == 'skip': + print(f" 跳过: {config.get('reason', '需要特殊参数')}") + continue + + for shape in config.get('shapes', []): + try: + test_op_performance(op_name, shape, config) + except Exception as e: + print(f" 警告: {op_name} shape={shape} 测试失败: {e}") + + # 打印性能测试报告 + print("\n" + "=" * 100) + print_summary_report() + + # 输出CSV格式报告 + print_csv_report() diff --git a/third_party/tsingmicro/crt/CMakeLists.txt b/third_party/tsingmicro/crt/CMakeLists.txt index 14ad6eb241..f879f691f1 100644 --- a/third_party/tsingmicro/crt/CMakeLists.txt +++ b/third_party/tsingmicro/crt/CMakeLists.txt @@ -106,10 +106,10 @@ else() -DCONFIG_TX8_KERNEL_PRINTF_SUPPORT=1 ) - if (USE_PROFILE) + if (ENABLE_PROFILING) # 只需要定义,不需要链接,编译kernel.so的时候链接profile set(RISCV_COMPILE_OPTIONS ${RISCV_COMPILE_OPTIONS} - -DUSE_PROFILE + -DENABLE_PROFILING ) endif() @@ -119,6 +119,12 @@ else() ) endif() + if (ENABLE_SYNCHRONOUS_INTRINSIC) + set(RISCV_COMPILE_OPTIONS ${RISCV_COMPILE_OPTIONS} + -DENABLE_SYNCHRONOUS_INTRINSIC + ) + endif() + if(CMAKE_BUILD_TYPE STREQUAL "Debug") set(RISCV_COMPILE_OPTIONS ${RISCV_COMPILE_OPTIONS} -DRT_USING_DEBUG diff --git a/third_party/tsingmicro/crt/include/Tx81/tx81.h b/third_party/tsingmicro/crt/include/Tx81/tx81.h index 36ce6eb902..48ffde8674 100644 --- a/third_party/tsingmicro/crt/include/Tx81/tx81.h +++ b/third_party/tsingmicro/crt/include/Tx81/tx81.h @@ -17,8 +17,6 @@ #include #include -#include "profiler.h" - typedef enum { UNKNOWN = 0, SPM = 1, @@ -91,4 +89,10 @@ void RT_ASSERT(bool value); #define INTRNISIC_RUN_SWITCH #endif +#ifdef ENABLE_SYNCHRONOUS_INTRINSIC +#define SYNCHRONOUS_INTRINSIC_SWITCH TsmWaitfinish() +#else +#define SYNCHRONOUS_INTRINSIC_SWITCH +#endif + #endif // CRT_TARGET_TX81_H diff --git a/third_party/tsingmicro/crt/lib/Tx81/abs.c b/third_party/tsingmicro/crt/lib/Tx81/abs.c index 2465c10d8f..06bd1f2872 100644 --- a/third_party/tsingmicro/crt/lib/Tx81/abs.c +++ b/third_party/tsingmicro/crt/lib/Tx81/abs.c @@ -27,6 +27,7 @@ void __AbsVV(uint64_t *src, uint64_t *dst, uint32_t elem_count, uint16_t fmt) { // Dispatch the command to accelerator TsmExecute(&inst); + SYNCHRONOUS_INTRINSIC_SWITCH; // Destroy the command buffer. } diff --git a/third_party/tsingmicro/crt/lib/Tx81/arith.c b/third_party/tsingmicro/crt/lib/Tx81/arith.c index 8d88114dc5..2bc6020d64 100644 --- a/third_party/tsingmicro/crt/lib/Tx81/arith.c +++ b/third_party/tsingmicro/crt/lib/Tx81/arith.c @@ -51,6 +51,7 @@ void __SubVV(uint64_t *src0, uint64_t *src1, uint64_t *dst, uint32_t elem_count, // Dispatch the command to accelerator TsmExecute(&inst); + SYNCHRONOUS_INTRINSIC_SWITCH; // Destroy the command buffer. } @@ -95,6 +96,7 @@ void __DivVV(uint64_t *src0, uint64_t *src1, uint64_t *dst, uint32_t elem_count, // Dispatch the command to accelerator TsmExecute(&inst); + SYNCHRONOUS_INTRINSIC_SWITCH; // Destroy the command buffer. } @@ -117,6 +119,7 @@ void __AddVS(uint64_t *src0, uint32_t src1, uint64_t *dst, uint32_t elem_count, // Dispatch the command to accelerator TsmExecute(&inst); + SYNCHRONOUS_INTRINSIC_SWITCH; // Destroy the command buffer. } @@ -139,6 +142,7 @@ void __SubVS(uint64_t *src0, uint32_t src1, uint64_t *dst, uint32_t elem_count, // Dispatch the command to accelerator TsmExecute(&inst); + SYNCHRONOUS_INTRINSIC_SWITCH; // Destroy the command buffer. } @@ -161,6 +165,7 @@ void __MulVS(uint64_t *src0, uint32_t src1, uint64_t *dst, uint32_t elem_count, // Dispatch the command to accelerator TsmExecute(&inst); + SYNCHRONOUS_INTRINSIC_SWITCH; // Destroy the command buffer. } @@ -183,6 +188,7 @@ void __DivVS(uint64_t *src0, uint32_t src1, uint64_t *dst, uint32_t elem_count, // Dispatch the command to accelerator TsmExecute(&inst); + SYNCHRONOUS_INTRINSIC_SWITCH; // Destroy the command buffer. } @@ -205,6 +211,7 @@ void __MaxVV(uint64_t *src0, uint64_t *src1, uint64_t *dst, uint32_t elem_count, // Dispatch the command to accelerator TsmExecute(&inst); + SYNCHRONOUS_INTRINSIC_SWITCH; // Destroy the command buffer. } @@ -227,6 +234,7 @@ void __MinVV(uint64_t *src0, uint64_t *src1, uint64_t *dst, uint32_t elem_count, // Dispatch the command to accelerator TsmExecute(&inst); + SYNCHRONOUS_INTRINSIC_SWITCH; // Destroy the command buffer. } diff --git a/third_party/tsingmicro/crt/lib/Tx81/atomic_barrier_in.c b/third_party/tsingmicro/crt/lib/Tx81/atomic_barrier_in.c index e3b2139bb8..8f1abbeb15 100644 --- a/third_party/tsingmicro/crt/lib/Tx81/atomic_barrier_in.c +++ b/third_party/tsingmicro/crt/lib/Tx81/atomic_barrier_in.c @@ -16,5 +16,6 @@ void __AtomicBarrierIn() { #ifdef USE_SIM_MODE #else atomic_barrier_in(); + SYNCHRONOUS_INTRINSIC_SWITCH; #endif } diff --git a/third_party/tsingmicro/crt/lib/Tx81/atomic_barrier_out.c b/third_party/tsingmicro/crt/lib/Tx81/atomic_barrier_out.c index 81e5bb3d6a..c326260af8 100644 --- a/third_party/tsingmicro/crt/lib/Tx81/atomic_barrier_out.c +++ b/third_party/tsingmicro/crt/lib/Tx81/atomic_barrier_out.c @@ -15,5 +15,6 @@ void __AtomicBarrierOut() { #ifdef USE_SIM_MODE #else atomic_barrier_out(); + SYNCHRONOUS_INTRINSIC_SWITCH; #endif } diff --git a/third_party/tsingmicro/crt/lib/Tx81/bf16_fp16.c b/third_party/tsingmicro/crt/lib/Tx81/bf16_fp16.c index abaec8c6be..b9a0fdf3e2 100644 --- a/third_party/tsingmicro/crt/lib/Tx81/bf16_fp16.c +++ b/third_party/tsingmicro/crt/lib/Tx81/bf16_fp16.c @@ -26,6 +26,7 @@ void __BF16_FP16(uint64_t *src, uint64_t *dst, uint32_t elem_count) { // Dispatch the command to accelerator TsmExecute(&inst); + SYNCHRONOUS_INTRINSIC_SWITCH; // Destroy the command buffer. } diff --git a/third_party/tsingmicro/crt/lib/Tx81/bf16_int16.c b/third_party/tsingmicro/crt/lib/Tx81/bf16_int16.c index 0b78bf116a..aca8a4db08 100644 --- a/third_party/tsingmicro/crt/lib/Tx81/bf16_int16.c +++ b/third_party/tsingmicro/crt/lib/Tx81/bf16_int16.c @@ -27,6 +27,7 @@ void __BF16_INT16(uint64_t *src, uint64_t *dst, uint32_t elem_count, // Dispatch the command to accelerator TsmExecute(&inst); + SYNCHRONOUS_INTRINSIC_SWITCH; // Destroy the command buffer. } diff --git a/third_party/tsingmicro/crt/lib/Tx81/bf16_int32.c b/third_party/tsingmicro/crt/lib/Tx81/bf16_int32.c index 8038c9cda5..fe10ed82be 100644 --- a/third_party/tsingmicro/crt/lib/Tx81/bf16_int32.c +++ b/third_party/tsingmicro/crt/lib/Tx81/bf16_int32.c @@ -27,6 +27,7 @@ void __BF16_INT32(uint64_t *src, uint64_t *dst, uint32_t elem_count, // Dispatch the command to accelerator TsmExecute(&inst); + SYNCHRONOUS_INTRINSIC_SWITCH; // Destroy the command buffer. } diff --git a/third_party/tsingmicro/crt/lib/Tx81/bf16_int8.c b/third_party/tsingmicro/crt/lib/Tx81/bf16_int8.c index e7d18b3a0e..837d1fbbf0 100644 --- a/third_party/tsingmicro/crt/lib/Tx81/bf16_int8.c +++ b/third_party/tsingmicro/crt/lib/Tx81/bf16_int8.c @@ -26,6 +26,7 @@ void __BF16_INT8(uint64_t *src, uint64_t *dst, uint32_t elem_count) { // Dispatch the command to accelerator TsmExecute(&inst); + SYNCHRONOUS_INTRINSIC_SWITCH; // Destroy the command buffer. } diff --git a/third_party/tsingmicro/crt/lib/Tx81/bf16_tf32.c b/third_party/tsingmicro/crt/lib/Tx81/bf16_tf32.c index 2f7987124e..ddf4a90ac6 100644 --- a/third_party/tsingmicro/crt/lib/Tx81/bf16_tf32.c +++ b/third_party/tsingmicro/crt/lib/Tx81/bf16_tf32.c @@ -26,6 +26,7 @@ void __BF16_TF32(uint64_t *src, uint64_t *dst, uint32_t elem_count) { // Dispatch the command to accelerator TsmExecute(&inst); + SYNCHRONOUS_INTRINSIC_SWITCH; // Destroy the command buffer. } diff --git a/third_party/tsingmicro/crt/lib/Tx81/bilinear.c b/third_party/tsingmicro/crt/lib/Tx81/bilinear.c index 92257c8683..4f3210b1d0 100644 --- a/third_party/tsingmicro/crt/lib/Tx81/bilinear.c +++ b/third_party/tsingmicro/crt/lib/Tx81/bilinear.c @@ -34,6 +34,7 @@ void __Bilinear(uint64_t *src, uint64_t *dst, uint16_t src_n, uint16_t src_h, // Dispatch the command to accelerator TsmExecute(&inst); + SYNCHRONOUS_INTRINSIC_SWITCH; // Destroy the command buffer. } diff --git a/third_party/tsingmicro/crt/lib/Tx81/bit2fp.c b/third_party/tsingmicro/crt/lib/Tx81/bit2fp.c index b9d7abe7b6..6aeb4023a3 100644 --- a/third_party/tsingmicro/crt/lib/Tx81/bit2fp.c +++ b/third_party/tsingmicro/crt/lib/Tx81/bit2fp.c @@ -31,6 +31,7 @@ void __Bit2Fp(uint64_t *src, uint64_t *target, uint32_t elem_count, // Dispatch the command to accelerator TsmExecute(&inst); + SYNCHRONOUS_INTRINSIC_SWITCH; // Destroy the command buffer. } diff --git a/third_party/tsingmicro/crt/lib/Tx81/concat.c b/third_party/tsingmicro/crt/lib/Tx81/concat.c index e13c7dc716..1c6db4ae1b 100644 --- a/third_party/tsingmicro/crt/lib/Tx81/concat.c +++ b/third_party/tsingmicro/crt/lib/Tx81/concat.c @@ -35,6 +35,7 @@ void __Concat(uint64_t *src1, uint16_t src1_n, uint16_t src1_h, uint16_t src1_w, // Dispatch the command to accelerator TsmExecute(&inst); + SYNCHRONOUS_INTRINSIC_SWITCH; // Destroy the command buffer. } diff --git a/third_party/tsingmicro/crt/lib/Tx81/conv.c b/third_party/tsingmicro/crt/lib/Tx81/conv.c index 0a5669e275..9b66c842fe 100644 --- a/third_party/tsingmicro/crt/lib/Tx81/conv.c +++ b/third_party/tsingmicro/crt/lib/Tx81/conv.c @@ -61,6 +61,7 @@ void __Conv(int64_t opType, int64_t *srcAct, int64_t *srcActDims, // Dispatch the command to accelerator TsmExecute(&inst); + SYNCHRONOUS_INTRINSIC_SWITCH; // Destroy the command buffer. } diff --git a/third_party/tsingmicro/crt/lib/Tx81/cos.c b/third_party/tsingmicro/crt/lib/Tx81/cos.c index c545ff7389..4bc06c6f50 100644 --- a/third_party/tsingmicro/crt/lib/Tx81/cos.c +++ b/third_party/tsingmicro/crt/lib/Tx81/cos.c @@ -27,6 +27,7 @@ void __Cos(uint64_t *src, uint64_t *dst, uint32_t elem_count, uint16_t fmt) { // Dispatch the command to accelerator TsmExecute(&inst); + SYNCHRONOUS_INTRINSIC_SWITCH; // Destroy the command buffer. } diff --git a/third_party/tsingmicro/crt/lib/Tx81/count.c b/third_party/tsingmicro/crt/lib/Tx81/count.c index 1c6e58c97b..4bd6fbd2f2 100644 --- a/third_party/tsingmicro/crt/lib/Tx81/count.c +++ b/third_party/tsingmicro/crt/lib/Tx81/count.c @@ -28,6 +28,7 @@ void __Count(uint64_t *src, uint32_t elem_count, uint16_t fmt) { // Dispatch the command to accelerator TsmExecute(&inst); + SYNCHRONOUS_INTRINSIC_SWITCH; // Destroy the command buffer. } diff --git a/third_party/tsingmicro/crt/lib/Tx81/exp.c b/third_party/tsingmicro/crt/lib/Tx81/exp.c index eb3090fd8c..8b0d691132 100644 --- a/third_party/tsingmicro/crt/lib/Tx81/exp.c +++ b/third_party/tsingmicro/crt/lib/Tx81/exp.c @@ -27,6 +27,7 @@ void __Exp(uint64_t *src, uint64_t *dst, uint32_t elem_count, uint16_t fmt) { // Dispatch the command to accelerator TsmExecute(&inst); + SYNCHRONOUS_INTRINSIC_SWITCH; // Destroy the command buffer. } diff --git a/third_party/tsingmicro/crt/lib/Tx81/explp.c b/third_party/tsingmicro/crt/lib/Tx81/explp.c index 337fbdabd0..9219babb98 100644 --- a/third_party/tsingmicro/crt/lib/Tx81/explp.c +++ b/third_party/tsingmicro/crt/lib/Tx81/explp.c @@ -27,6 +27,7 @@ void __Explp(uint64_t *src, uint64_t *dst, uint32_t elem_count, uint16_t fmt) { // Dispatch the command to accelerator TsmExecute(&inst); + SYNCHRONOUS_INTRINSIC_SWITCH; // Destroy the command buffer. } diff --git a/third_party/tsingmicro/crt/lib/Tx81/fp16_bf16.c b/third_party/tsingmicro/crt/lib/Tx81/fp16_bf16.c index 8262aa01e7..eb872a9e1c 100644 --- a/third_party/tsingmicro/crt/lib/Tx81/fp16_bf16.c +++ b/third_party/tsingmicro/crt/lib/Tx81/fp16_bf16.c @@ -27,6 +27,7 @@ void __FP16_BF16(uint64_t *src, uint64_t *dst, uint32_t elem_count, // Dispatch the command to accelerator TsmExecute(&inst); + SYNCHRONOUS_INTRINSIC_SWITCH; // Destroy the command buffer. } diff --git a/third_party/tsingmicro/crt/lib/Tx81/fp16_int16.c b/third_party/tsingmicro/crt/lib/Tx81/fp16_int16.c index 7849c8687f..d8b87fab99 100644 --- a/third_party/tsingmicro/crt/lib/Tx81/fp16_int16.c +++ b/third_party/tsingmicro/crt/lib/Tx81/fp16_int16.c @@ -27,6 +27,7 @@ void __FP16_INT16(uint64_t *src, uint64_t *dst, uint32_t elem_count, // Dispatch the command to accelerator TsmExecute(&inst); + SYNCHRONOUS_INTRINSIC_SWITCH; // Destroy the command buffer. } diff --git a/third_party/tsingmicro/crt/lib/Tx81/fp16_int32.c b/third_party/tsingmicro/crt/lib/Tx81/fp16_int32.c index 1cda5410b2..b4f6357594 100644 --- a/third_party/tsingmicro/crt/lib/Tx81/fp16_int32.c +++ b/third_party/tsingmicro/crt/lib/Tx81/fp16_int32.c @@ -27,6 +27,7 @@ void __FP16_INT32(uint64_t *src, uint64_t *dst, uint32_t elem_count, // Dispatch the command to accelerator TsmExecute(&inst); + SYNCHRONOUS_INTRINSIC_SWITCH; // Destroy the command buffer. } diff --git a/third_party/tsingmicro/crt/lib/Tx81/fp16_int8.c b/third_party/tsingmicro/crt/lib/Tx81/fp16_int8.c index 663691b9ec..deebeabf97 100644 --- a/third_party/tsingmicro/crt/lib/Tx81/fp16_int8.c +++ b/third_party/tsingmicro/crt/lib/Tx81/fp16_int8.c @@ -27,6 +27,6 @@ void __FP16_INT8(uint64_t *src, uint64_t *dst, uint32_t elem_count, // Dispatch the command to accelerator TsmExecute(&inst); - + SYNCHRONOUS_INTRINSIC_SWITCH; // Destroy the command buffer. } diff --git a/third_party/tsingmicro/crt/lib/Tx81/fp16_tf32.c b/third_party/tsingmicro/crt/lib/Tx81/fp16_tf32.c index c65169b7d6..1ad8b5eb00 100644 --- a/third_party/tsingmicro/crt/lib/Tx81/fp16_tf32.c +++ b/third_party/tsingmicro/crt/lib/Tx81/fp16_tf32.c @@ -26,6 +26,7 @@ void __FP16_TF32(uint64_t *src, uint64_t *dst, uint32_t elem_count) { // Dispatch the command to accelerator TsmExecute(&inst); + SYNCHRONOUS_INTRINSIC_SWITCH; // Destroy the command buffer. } diff --git a/third_party/tsingmicro/crt/lib/Tx81/fp32_bf16.c b/third_party/tsingmicro/crt/lib/Tx81/fp32_bf16.c index 1cb2e71c71..bdb1574f08 100644 --- a/third_party/tsingmicro/crt/lib/Tx81/fp32_bf16.c +++ b/third_party/tsingmicro/crt/lib/Tx81/fp32_bf16.c @@ -27,6 +27,7 @@ void __FP32_BF16(uint64_t *src, uint64_t *dst, uint32_t elem_count, // Dispatch the command to accelerator TsmExecute(&inst); + SYNCHRONOUS_INTRINSIC_SWITCH; // Destroy the command buffer. } diff --git a/third_party/tsingmicro/crt/lib/Tx81/fp32_int16.c b/third_party/tsingmicro/crt/lib/Tx81/fp32_int16.c index f771070ecd..9d5e06c47f 100644 --- a/third_party/tsingmicro/crt/lib/Tx81/fp32_int16.c +++ b/third_party/tsingmicro/crt/lib/Tx81/fp32_int16.c @@ -27,6 +27,7 @@ void __FP32_INT16(uint64_t *src, uint64_t *dst, uint32_t elem_count, // Dispatch the command to accelerator TsmExecute(&inst); + SYNCHRONOUS_INTRINSIC_SWITCH; // Destroy the command buffer. } diff --git a/third_party/tsingmicro/crt/lib/Tx81/fp32_int8.c b/third_party/tsingmicro/crt/lib/Tx81/fp32_int8.c index 82443f9d4f..477f32d36f 100644 --- a/third_party/tsingmicro/crt/lib/Tx81/fp32_int8.c +++ b/third_party/tsingmicro/crt/lib/Tx81/fp32_int8.c @@ -27,6 +27,7 @@ void __FP32_INT8(uint64_t *src, uint64_t *dst, uint32_t elem_count, // Dispatch the command to accelerator TsmExecute(&inst); + SYNCHRONOUS_INTRINSIC_SWITCH; // Destroy the command buffer. } diff --git a/third_party/tsingmicro/crt/lib/Tx81/gelu_none.c b/third_party/tsingmicro/crt/lib/Tx81/gelu_none.c index ff9cff697c..1537a2f287 100644 --- a/third_party/tsingmicro/crt/lib/Tx81/gelu_none.c +++ b/third_party/tsingmicro/crt/lib/Tx81/gelu_none.c @@ -16,4 +16,5 @@ void __GeluNone(uint64_t *src, uint64_t *dst, uint32_t elem_count, uint16_t fmt) { INTRNISIC_RUN_SWITCH; op_gelu_none(src, dst, elem_count, (Data_Format)fmt); + SYNCHRONOUS_INTRINSIC_SWITCH; } diff --git a/third_party/tsingmicro/crt/lib/Tx81/gelu_tanh.c b/third_party/tsingmicro/crt/lib/Tx81/gelu_tanh.c index 9582f1a20f..28b1bd64ad 100644 --- a/third_party/tsingmicro/crt/lib/Tx81/gelu_tanh.c +++ b/third_party/tsingmicro/crt/lib/Tx81/gelu_tanh.c @@ -16,4 +16,5 @@ void __GeluTanh(uint64_t *src, uint64_t *imm, uint64_t *dst, uint32_t elem_count, uint16_t fmt) { INTRNISIC_RUN_SWITCH; op_gelu_tanh(src, imm, dst, elem_count, fmt); + SYNCHRONOUS_INTRINSIC_SWITCH; } diff --git a/third_party/tsingmicro/crt/lib/Tx81/gemm.c b/third_party/tsingmicro/crt/lib/Tx81/gemm.c index dfc521d5ce..f3892c3851 100644 --- a/third_party/tsingmicro/crt/lib/Tx81/gemm.c +++ b/third_party/tsingmicro/crt/lib/Tx81/gemm.c @@ -54,4 +54,5 @@ void __Gemm(int64_t *srcA, int64_t *srcB, int64_t *srcBias, int64_t *dst, // Dispatch the command to accelerator TsmExecute(&inst); + SYNCHRONOUS_INTRINSIC_SWITCH; } diff --git a/third_party/tsingmicro/crt/lib/Tx81/img2col.c b/third_party/tsingmicro/crt/lib/Tx81/img2col.c index d5001df424..046cd322bd 100644 --- a/third_party/tsingmicro/crt/lib/Tx81/img2col.c +++ b/third_party/tsingmicro/crt/lib/Tx81/img2col.c @@ -37,6 +37,7 @@ void __Img2col(uint64_t *src, uint16_t src_n, uint16_t src_h, uint16_t src_w, // Dispatch the command to accelerator TsmExecute(&inst); + SYNCHRONOUS_INTRINSIC_SWITCH; // Destroy the command buffer. } diff --git a/third_party/tsingmicro/crt/lib/Tx81/int16_bf16.c b/third_party/tsingmicro/crt/lib/Tx81/int16_bf16.c index 3c303e0798..4d90ca8fe2 100644 --- a/third_party/tsingmicro/crt/lib/Tx81/int16_bf16.c +++ b/third_party/tsingmicro/crt/lib/Tx81/int16_bf16.c @@ -27,6 +27,7 @@ void __INT16_BF16(uint64_t *src, uint64_t *dst, uint32_t elem_count, // Dispatch the command to accelerator TsmExecute(&inst); + SYNCHRONOUS_INTRINSIC_SWITCH; // Destroy the command buffer. } diff --git a/third_party/tsingmicro/crt/lib/Tx81/int16_fp16.c b/third_party/tsingmicro/crt/lib/Tx81/int16_fp16.c index 44e8017160..4696fd7d4f 100644 --- a/third_party/tsingmicro/crt/lib/Tx81/int16_fp16.c +++ b/third_party/tsingmicro/crt/lib/Tx81/int16_fp16.c @@ -26,6 +26,7 @@ void __INT16_FP16(uint64_t *src, uint64_t *dst, uint32_t elem_count) { // Dispatch the command to accelerator TsmExecute(&inst); + SYNCHRONOUS_INTRINSIC_SWITCH; // Destroy the command buffer. } diff --git a/third_party/tsingmicro/crt/lib/Tx81/int16_fp32.c b/third_party/tsingmicro/crt/lib/Tx81/int16_fp32.c index be3edd2b42..e7a362617a 100644 --- a/third_party/tsingmicro/crt/lib/Tx81/int16_fp32.c +++ b/third_party/tsingmicro/crt/lib/Tx81/int16_fp32.c @@ -27,6 +27,6 @@ void __INT16_FP32(uint64_t *src, uint64_t *dst, uint32_t elem_count, // Dispatch the command to accelerator TsmExecute(&inst); - + SYNCHRONOUS_INTRINSIC_SWITCH; // Destroy the command buffer. } diff --git a/third_party/tsingmicro/crt/lib/Tx81/int16_tf32.c b/third_party/tsingmicro/crt/lib/Tx81/int16_tf32.c index 44b9ef8f8c..eaf7fc8f1b 100644 --- a/third_party/tsingmicro/crt/lib/Tx81/int16_tf32.c +++ b/third_party/tsingmicro/crt/lib/Tx81/int16_tf32.c @@ -27,6 +27,6 @@ void __INT16_TF32(uint64_t *src, uint64_t *dst, uint32_t elem_count, // Dispatch the command to accelerator TsmExecute(&inst); - + SYNCHRONOUS_INTRINSIC_SWITCH; // Destroy the command buffer. } diff --git a/third_party/tsingmicro/crt/lib/Tx81/int32_bf16.c b/third_party/tsingmicro/crt/lib/Tx81/int32_bf16.c index ce4ab6e85c..aaca19a6be 100644 --- a/third_party/tsingmicro/crt/lib/Tx81/int32_bf16.c +++ b/third_party/tsingmicro/crt/lib/Tx81/int32_bf16.c @@ -27,6 +27,6 @@ void __INT32_BF16(uint64_t *src, uint64_t *dst, uint32_t elem_count, // Dispatch the command to accelerator TsmExecute(&inst); - + SYNCHRONOUS_INTRINSIC_SWITCH; // Destroy the command buffer. } diff --git a/third_party/tsingmicro/crt/lib/Tx81/int32_fp16.c b/third_party/tsingmicro/crt/lib/Tx81/int32_fp16.c index 82b239ca87..87c51b7757 100644 --- a/third_party/tsingmicro/crt/lib/Tx81/int32_fp16.c +++ b/third_party/tsingmicro/crt/lib/Tx81/int32_fp16.c @@ -28,6 +28,6 @@ void __INT32_FP16(uint64_t *src, uint64_t *dst, uint32_t elem_count, // Dispatch the command to accelerator TsmExecute(&inst); - + SYNCHRONOUS_INTRINSIC_SWITCH; // Destroy the command buffer. } diff --git a/third_party/tsingmicro/crt/lib/Tx81/int32_tf32.c b/third_party/tsingmicro/crt/lib/Tx81/int32_tf32.c index f8a8debd0d..c8b8888209 100644 --- a/third_party/tsingmicro/crt/lib/Tx81/int32_tf32.c +++ b/third_party/tsingmicro/crt/lib/Tx81/int32_tf32.c @@ -27,6 +27,6 @@ void __INT32_TF32(uint64_t *src, uint64_t *dst, uint32_t elem_count, // Dispatch the command to accelerator TsmExecute(&inst); - + SYNCHRONOUS_INTRINSIC_SWITCH; // Destroy the command buffer. } diff --git a/third_party/tsingmicro/crt/lib/Tx81/int8_bf16.c b/third_party/tsingmicro/crt/lib/Tx81/int8_bf16.c index 79000b83f6..f7c6798871 100644 --- a/third_party/tsingmicro/crt/lib/Tx81/int8_bf16.c +++ b/third_party/tsingmicro/crt/lib/Tx81/int8_bf16.c @@ -27,6 +27,7 @@ void __INT8_BF16(uint64_t *src, uint64_t *dst, uint32_t zp, // Dispatch the command to accelerator TsmExecute(&inst); + SYNCHRONOUS_INTRINSIC_SWITCH; // Destroy the command buffer. } diff --git a/third_party/tsingmicro/crt/lib/Tx81/int8_fp16.c b/third_party/tsingmicro/crt/lib/Tx81/int8_fp16.c index 3931d1c6a2..ad7019a3be 100644 --- a/third_party/tsingmicro/crt/lib/Tx81/int8_fp16.c +++ b/third_party/tsingmicro/crt/lib/Tx81/int8_fp16.c @@ -27,6 +27,7 @@ void __INT8_FP16(uint64_t *src, uint64_t *dst, uint32_t zp, // Dispatch the command to accelerator TsmExecute(&inst); + SYNCHRONOUS_INTRINSIC_SWITCH; // Destroy the command buffer. } diff --git a/third_party/tsingmicro/crt/lib/Tx81/int8_tf32.c b/third_party/tsingmicro/crt/lib/Tx81/int8_tf32.c index 00c2cb629e..a9174b7a52 100644 --- a/third_party/tsingmicro/crt/lib/Tx81/int8_tf32.c +++ b/third_party/tsingmicro/crt/lib/Tx81/int8_tf32.c @@ -27,6 +27,7 @@ void __INT8_TF32(uint64_t *src, uint64_t *dst, uint32_t zp, // Dispatch the command to accelerator TsmExecute(&inst); + SYNCHRONOUS_INTRINSIC_SWITCH; // Destroy the command buffer. } diff --git a/third_party/tsingmicro/crt/lib/Tx81/leakyrelu.c b/third_party/tsingmicro/crt/lib/Tx81/leakyrelu.c index c678da722c..378f7cd08c 100644 --- a/third_party/tsingmicro/crt/lib/Tx81/leakyrelu.c +++ b/third_party/tsingmicro/crt/lib/Tx81/leakyrelu.c @@ -29,6 +29,6 @@ void __Leakyrelu(uint64_t *src, uint64_t *dst, uint32_t elem_count, // Dispatch the command to accelerator TsmExecute(&inst); - + SYNCHRONOUS_INTRINSIC_SWITCH; // Destroy the command buffer. } diff --git a/third_party/tsingmicro/crt/lib/Tx81/ln.c b/third_party/tsingmicro/crt/lib/Tx81/ln.c index 20f537388b..172a332b02 100644 --- a/third_party/tsingmicro/crt/lib/Tx81/ln.c +++ b/third_party/tsingmicro/crt/lib/Tx81/ln.c @@ -27,6 +27,6 @@ void __Ln(uint64_t *src, uint64_t *dst, uint32_t elem_count, uint16_t fmt) { // Dispatch the command to accelerator TsmExecute(&inst); - + SYNCHRONOUS_INTRINSIC_SWITCH; // Destroy the command buffer. } diff --git a/third_party/tsingmicro/crt/lib/Tx81/log2.c b/third_party/tsingmicro/crt/lib/Tx81/log2.c index 72baa50eac..5a46fcb4e3 100644 --- a/third_party/tsingmicro/crt/lib/Tx81/log2.c +++ b/third_party/tsingmicro/crt/lib/Tx81/log2.c @@ -27,6 +27,6 @@ void __Log2(uint64_t *src, uint64_t *dst, uint32_t elem_count, uint16_t fmt) { // Dispatch the command to accelerator TsmExecute(&inst); - + SYNCHRONOUS_INTRINSIC_SWITCH; // Destroy the command buffer. } diff --git a/third_party/tsingmicro/crt/lib/Tx81/logic.c b/third_party/tsingmicro/crt/lib/Tx81/logic.c index f9e6dd152c..9b9c6dff29 100644 --- a/third_party/tsingmicro/crt/lib/Tx81/logic.c +++ b/third_party/tsingmicro/crt/lib/Tx81/logic.c @@ -29,7 +29,7 @@ void __AndVV(uint64_t *src0, uint64_t *src1, uint64_t *dst, uint32_t elem_count, // Dispatch the command to accelerator TsmExecute(&inst); - + SYNCHRONOUS_INTRINSIC_SWITCH; // Destroy the command buffer. } @@ -51,6 +51,7 @@ void __OrVV(uint64_t *src0, uint64_t *src1, uint64_t *dst, uint32_t elem_count, // Dispatch the command to accelerator TsmExecute(&inst); + SYNCHRONOUS_INTRINSIC_SWITCH; // Destroy the command buffer. } @@ -73,6 +74,7 @@ void __XorVV(uint64_t *src0, uint64_t *src1, uint64_t *dst, uint32_t elem_count, // Dispatch the command to accelerator TsmExecute(&inst); + SYNCHRONOUS_INTRINSIC_SWITCH; // Destroy the command buffer. } diff --git a/third_party/tsingmicro/crt/lib/Tx81/lut16.c b/third_party/tsingmicro/crt/lib/Tx81/lut16.c index 151dfc7324..fb8f30b631 100644 --- a/third_party/tsingmicro/crt/lib/Tx81/lut16.c +++ b/third_party/tsingmicro/crt/lib/Tx81/lut16.c @@ -30,6 +30,6 @@ void __Lut16(uint64_t *src, uint64_t *dst, uint64_t *lut16, // Dispatch the command to accelerator TsmExecute(&inst); - + SYNCHRONOUS_INTRINSIC_SWITCH; // Destroy the command buffer. } diff --git a/third_party/tsingmicro/crt/lib/Tx81/lut32.c b/third_party/tsingmicro/crt/lib/Tx81/lut32.c index 05dcdf006c..ca7c26c656 100644 --- a/third_party/tsingmicro/crt/lib/Tx81/lut32.c +++ b/third_party/tsingmicro/crt/lib/Tx81/lut32.c @@ -30,6 +30,6 @@ void __Lut32(uint64_t *src, uint64_t *dst, uint64_t *lut32, // Dispatch the command to accelerator TsmExecute(&inst); - + SYNCHRONOUS_INTRINSIC_SWITCH; // Destroy the command buffer. } diff --git a/third_party/tsingmicro/crt/lib/Tx81/mask_move.c b/third_party/tsingmicro/crt/lib/Tx81/mask_move.c index 8d4b050966..856a35ad6a 100644 --- a/third_party/tsingmicro/crt/lib/Tx81/mask_move.c +++ b/third_party/tsingmicro/crt/lib/Tx81/mask_move.c @@ -27,4 +27,5 @@ void __MaskMove(uint64_t *src, uint64_t *target, uint32_t elem_count, elem_count, (Data_Format)fmt); TsmExecute(&inst); + SYNCHRONOUS_INTRINSIC_SWITCH; } diff --git a/third_party/tsingmicro/crt/lib/Tx81/mirror.c b/third_party/tsingmicro/crt/lib/Tx81/mirror.c index 3e5b26b285..896dc16d9e 100644 --- a/third_party/tsingmicro/crt/lib/Tx81/mirror.c +++ b/third_party/tsingmicro/crt/lib/Tx81/mirror.c @@ -32,6 +32,6 @@ void __Mirror(uint64_t *src, uint16_t src_n, uint16_t src_h, uint16_t src_w, // Dispatch the command to accelerator TsmExecute(&inst); - + SYNCHRONOUS_INTRINSIC_SWITCH; // Destroy the command buffer. } diff --git a/third_party/tsingmicro/crt/lib/Tx81/mxfp_bf16.c b/third_party/tsingmicro/crt/lib/Tx81/mxfp_bf16.c index 98b9bdfecd..cc57b39003 100644 --- a/third_party/tsingmicro/crt/lib/Tx81/mxfp_bf16.c +++ b/third_party/tsingmicro/crt/lib/Tx81/mxfp_bf16.c @@ -70,6 +70,7 @@ void __FP8E5M2_BF16(uint8_t *src, uint16_t *dst, uint32_t elem_count) { bf16_exponent | // Exponent at bits 14-7 (mantissa << 5); // Mantissa at bits 6-5 (bits 4-0 zero) } + SYNCHRONOUS_INTRINSIC_SWITCH; } /** @@ -130,6 +131,7 @@ void __FP8E4M3_BF16(uint8_t *src, uint16_t *dst, uint32_t elem_count) { (bf16_exponent << 7) | // Exponent at bits 14-7 (mantissa << 4); // Mantissa at bits 6-4 (bits 3-0 zero) } + SYNCHRONOUS_INTRINSIC_SWITCH; } /** @@ -285,4 +287,5 @@ void __FP4E2M1_BF16(uint8_t *src, uint16_t *dst, uint32_t elem_count) { dst[2 * i] = result; } } + SYNCHRONOUS_INTRINSIC_SWITCH; } diff --git a/third_party/tsingmicro/crt/lib/Tx81/mxfp_scale_bf16.c b/third_party/tsingmicro/crt/lib/Tx81/mxfp_scale_bf16.c index 3b4e5ca9c1..57d0954f64 100644 --- a/third_party/tsingmicro/crt/lib/Tx81/mxfp_scale_bf16.c +++ b/third_party/tsingmicro/crt/lib/Tx81/mxfp_scale_bf16.c @@ -63,5 +63,6 @@ void __mxfpScaleBF16(uint16_t *value, uint8_t *scale, uint16_t *dst, (uint64_t)block_dst, scaling_block_size, RND_NEAREST_EVEN, Fmt_BF16); TsmExecute(&inst); + SYNCHRONOUS_INTRINSIC_SWITCH; } } diff --git a/third_party/tsingmicro/crt/lib/Tx81/mxfp_scale_fp16.c b/third_party/tsingmicro/crt/lib/Tx81/mxfp_scale_fp16.c index 365605f3cb..ce193c54ff 100644 --- a/third_party/tsingmicro/crt/lib/Tx81/mxfp_scale_fp16.c +++ b/third_party/tsingmicro/crt/lib/Tx81/mxfp_scale_fp16.c @@ -86,5 +86,6 @@ void __mxfpScaleFP16(uint16_t *value, uint8_t *scale, uint16_t *dst, (uint64_t)block_dst, scaling_block_size, RND_NEAREST_EVEN, Fmt_FP16); TsmExecute(&inst); + SYNCHRONOUS_INTRINSIC_SWITCH; } } diff --git a/third_party/tsingmicro/crt/lib/Tx81/nchw2nhwc.c b/third_party/tsingmicro/crt/lib/Tx81/nchw2nhwc.c index dd4227c209..46ab557c4f 100644 --- a/third_party/tsingmicro/crt/lib/Tx81/nchw2nhwc.c +++ b/third_party/tsingmicro/crt/lib/Tx81/nchw2nhwc.c @@ -31,6 +31,7 @@ void __Nchw2nhwc(uint64_t *src, uint64_t *dst, int32_t *src_shape, // Dispatch the command to accelerator TsmExecute(&inst); + SYNCHRONOUS_INTRINSIC_SWITCH; // Destroy the command buffer. } diff --git a/third_party/tsingmicro/crt/lib/Tx81/nhwc2nchw.c b/third_party/tsingmicro/crt/lib/Tx81/nhwc2nchw.c index 6c31bb3fbc..022dd7b0de 100644 --- a/third_party/tsingmicro/crt/lib/Tx81/nhwc2nchw.c +++ b/third_party/tsingmicro/crt/lib/Tx81/nhwc2nchw.c @@ -31,6 +31,7 @@ void __Nhwc2nchw(uint64_t *src, uint64_t *dst, int32_t *src_shape, // Dispatch the command to accelerator TsmExecute(&inst); + SYNCHRONOUS_INTRINSIC_SWITCH; // Destroy the command buffer. } diff --git a/third_party/tsingmicro/crt/lib/Tx81/op_gelu.c b/third_party/tsingmicro/crt/lib/Tx81/op_gelu.c index ba6aed6a2d..42e6498af4 100644 --- a/third_party/tsingmicro/crt/lib/Tx81/op_gelu.c +++ b/third_party/tsingmicro/crt/lib/Tx81/op_gelu.c @@ -509,6 +509,7 @@ void op_gelu_none(uint64_t *src, uint64_t *dst, uint32_t elem_count, uint8_t *out_ddr = (uint8_t *)get_spm_memory_mapping((uint64_t)(dst)); get_erf_value(in_ddr, out_ddr, elem_count, fmt); + SYNCHRONOUS_INTRINSIC_SWITCH; #ifdef USING_RISCV csi_dcache_clean_range((uint64_t *)out_ddr, elem_count * fmt); @@ -519,4 +520,5 @@ void op_gelu_tanh(uint64_t *src, uint64_t *imm, uint64_t *dst, uint32_t elem_count, uint16_t fmt) { TsmWaitfinish(); get_tanh_value(src, imm, dst, elem_count, fmt); + SYNCHRONOUS_INTRINSIC_SWITCH; } diff --git a/third_party/tsingmicro/crt/lib/Tx81/op_reduce_mul_impl.c b/third_party/tsingmicro/crt/lib/Tx81/op_reduce_mul_impl.c index 264801a547..b1f572f06d 100644 --- a/third_party/tsingmicro/crt/lib/Tx81/op_reduce_mul_impl.c +++ b/third_party/tsingmicro/crt/lib/Tx81/op_reduce_mul_impl.c @@ -52,6 +52,7 @@ void op_reduce_mul_impl(void *in, void *out, Data_Shape shape, // Dispatch the command to accelerator TsmExecute(&data_move_inst); + SYNCHRONOUS_INTRINSIC_SWITCH; // Destroy the command buffer. TsmDeleteDataMove(data_move); @@ -83,6 +84,7 @@ void op_reduce_mul_impl(void *in, void *out, Data_Shape shape, arith->MulVV(&arithIns, arith_src0_addr, arith_src1_addr, arith_dst_addr, align_val, RND_NEAREST_EVEN, fmt); TsmExecute(&arithIns); + SYNCHRONOUS_INTRINSIC_SWITCH; } } @@ -98,6 +100,7 @@ void op_reduce_mul_impl(void *in, void *out, Data_Shape shape, arith->MulVV(&arithIns, arith_src0_addr, arith_src1_addr, arith_dst_addr, stride, RND_NEAREST_EVEN, fmt); TsmExecute(&arithIns); + SYNCHRONOUS_INTRINSIC_SWITCH; } } } else if (reduce_dim == 1) { @@ -132,6 +135,7 @@ void op_reduce_mul_impl(void *in, void *out, Data_Shape shape, // Dispatch the command to accelerator TsmExecute(&data_move_inst); + SYNCHRONOUS_INTRINSIC_SWITCH; for (int32_t w_index = 1; w_index < w; w_index++) { uint64_t arith_src0_addr = dst_out_addr; @@ -144,6 +148,7 @@ void op_reduce_mul_impl(void *in, void *out, Data_Shape shape, arith->MulVV(&arithIns, arith_src0_addr, arith_src1_addr, arith_dst_addr, cx_align, RND_NEAREST_EVEN, fmt); TsmExecute(&arithIns); + SYNCHRONOUS_INTRINSIC_SWITCH; } } // Destroy the command buffer. diff --git a/third_party/tsingmicro/crt/lib/Tx81/pad.c b/third_party/tsingmicro/crt/lib/Tx81/pad.c index 4ee7c352b2..a82f0c3689 100644 --- a/third_party/tsingmicro/crt/lib/Tx81/pad.c +++ b/third_party/tsingmicro/crt/lib/Tx81/pad.c @@ -34,6 +34,7 @@ void __Pad(uint64_t *src, uint16_t src_n, uint16_t src_h, uint16_t src_w, // Dispatch the command to accelerator TsmExecute(&inst); + SYNCHRONOUS_INTRINSIC_SWITCH; // Destroy the command buffer. } diff --git a/third_party/tsingmicro/crt/lib/Tx81/pow2.c b/third_party/tsingmicro/crt/lib/Tx81/pow2.c index 2b131ab60c..d15fcace2a 100644 --- a/third_party/tsingmicro/crt/lib/Tx81/pow2.c +++ b/third_party/tsingmicro/crt/lib/Tx81/pow2.c @@ -27,6 +27,7 @@ void __Pow2(uint64_t *src, uint64_t *dst, uint32_t elem_count, uint16_t fmt) { // Dispatch the command to accelerator TsmExecute(&inst); + SYNCHRONOUS_INTRINSIC_SWITCH; // Destroy the command buffer. } diff --git a/third_party/tsingmicro/crt/lib/Tx81/print.c b/third_party/tsingmicro/crt/lib/Tx81/print.c index ae5b4085d8..8bab5aaf40 100644 --- a/third_party/tsingmicro/crt/lib/Tx81/print.c +++ b/third_party/tsingmicro/crt/lib/Tx81/print.c @@ -19,7 +19,7 @@ void __Print(const char *__restrict fmt, ...) { // FIXME: va_list memory layout is specific to the platform. #ifndef USE_SIM_MODE - tsm_ep_log(__FILE__, __func__, __LINE__, KCORE_LOG_ERROR, fmt, args); + _tsm_ep_log(__FILE__, __func__, __LINE__, KCORE_LOG_ERROR, fmt, args); #else vprintf(fmt, args); #endif diff --git a/third_party/tsingmicro/crt/lib/Tx81/randgen.c b/third_party/tsingmicro/crt/lib/Tx81/randgen.c index 77f1c4ea6e..2e5ce59c8f 100644 --- a/third_party/tsingmicro/crt/lib/Tx81/randgen.c +++ b/third_party/tsingmicro/crt/lib/Tx81/randgen.c @@ -30,6 +30,7 @@ void __RandGen(uint64_t *src0, uint64_t *src1, uint64_t *dst0, uint64_t *dst1, // Dispatch the command to accelerator TsmExecute(&inst); + SYNCHRONOUS_INTRINSIC_SWITCH; // Destroy the command buffer. } diff --git a/third_party/tsingmicro/crt/lib/Tx81/relation.c b/third_party/tsingmicro/crt/lib/Tx81/relation.c index 57d5a8d1cf..3ecc729ea2 100644 --- a/third_party/tsingmicro/crt/lib/Tx81/relation.c +++ b/third_party/tsingmicro/crt/lib/Tx81/relation.c @@ -29,6 +29,7 @@ void __BoolEqualVV(uint64_t *src0, uint64_t *src1, uint64_t *dst, // Dispatch the command to accelerator TsmExecute(&inst); + SYNCHRONOUS_INTRINSIC_SWITCH; // Destroy the command buffer. } @@ -95,6 +96,7 @@ void __BoolGreaterVV(uint64_t *src0, uint64_t *src1, uint64_t *dst, // Dispatch the command to accelerator TsmExecute(&inst); + SYNCHRONOUS_INTRINSIC_SWITCH; // Destroy the command buffer. } @@ -117,6 +119,7 @@ void __BoolLessEqualVV(uint64_t *src0, uint64_t *src1, uint64_t *dst, // Dispatch the command to accelerator TsmExecute(&inst); + SYNCHRONOUS_INTRINSIC_SWITCH; // Destroy the command buffer. } @@ -139,6 +142,7 @@ void __BoolLessThenVV(uint64_t *src0, uint64_t *src1, uint64_t *dst, // Dispatch the command to accelerator TsmExecute(&inst); + SYNCHRONOUS_INTRINSIC_SWITCH; // Destroy the command buffer. } @@ -161,6 +165,7 @@ void __EqualVV(uint64_t *src0, uint64_t *src1, uint64_t *dst, // Dispatch the command to accelerator TsmExecute(&inst); + SYNCHRONOUS_INTRINSIC_SWITCH; // Destroy the command buffer. } @@ -183,6 +188,7 @@ void __UnEqualVV(uint64_t *src0, uint64_t *src1, uint64_t *dst, // Dispatch the command to accelerator TsmExecute(&inst); + SYNCHRONOUS_INTRINSIC_SWITCH; // Destroy the command buffer. } @@ -205,6 +211,7 @@ void __GreaterEqualVV(uint64_t *src0, uint64_t *src1, uint64_t *dst, // Dispatch the command to accelerator TsmExecute(&inst); + SYNCHRONOUS_INTRINSIC_SWITCH; // Destroy the command buffer. } @@ -227,6 +234,7 @@ void __GreaterVV(uint64_t *src0, uint64_t *src1, uint64_t *dst, // Dispatch the command to accelerator TsmExecute(&inst); + SYNCHRONOUS_INTRINSIC_SWITCH; // Destroy the command buffer. } @@ -249,6 +257,7 @@ void __LessEqualVV(uint64_t *src0, uint64_t *src1, uint64_t *dst, // Dispatch the command to accelerator TsmExecute(&inst); + SYNCHRONOUS_INTRINSIC_SWITCH; // Destroy the command buffer. } @@ -271,6 +280,7 @@ void __LessThenVV(uint64_t *src0, uint64_t *src1, uint64_t *dst, // Dispatch the command to accelerator TsmExecute(&inst); + SYNCHRONOUS_INTRINSIC_SWITCH; // Destroy the command buffer. } @@ -293,6 +303,7 @@ void __BoolEqualVS(uint64_t *src0, uint32_t src1, uint64_t *dst, // Dispatch the command to accelerator TsmExecute(&inst); + SYNCHRONOUS_INTRINSIC_SWITCH; // Destroy the command buffer. } @@ -315,6 +326,7 @@ void __BoolUnEqualVS(uint64_t *src0, uint32_t src1, uint64_t *dst, // Dispatch the command to accelerator TsmExecute(&inst); + SYNCHRONOUS_INTRINSIC_SWITCH; // Destroy the command buffer. } @@ -337,6 +349,7 @@ void __BoolGreaterEqualVS(uint64_t *src0, uint32_t src1, uint64_t *dst, // Dispatch the command to accelerator TsmExecute(&inst); + SYNCHRONOUS_INTRINSIC_SWITCH; // Destroy the command buffer. } @@ -359,6 +372,7 @@ void __BoolGreaterVS(uint64_t *src0, uint32_t src1, uint64_t *dst, // Dispatch the command to accelerator TsmExecute(&inst); + SYNCHRONOUS_INTRINSIC_SWITCH; // Destroy the command buffer. } @@ -381,6 +395,7 @@ void __BoolLessEqualVS(uint64_t *src0, uint32_t src1, uint64_t *dst, // Dispatch the command to accelerator TsmExecute(&inst); + SYNCHRONOUS_INTRINSIC_SWITCH; // Destroy the command buffer. } @@ -403,6 +418,7 @@ void __BoolLessThenVS(uint64_t *src0, uint32_t src1, uint64_t *dst, // Dispatch the command to accelerator TsmExecute(&inst); + SYNCHRONOUS_INTRINSIC_SWITCH; // Destroy the command buffer. } @@ -425,6 +441,7 @@ void __EqualVS(uint64_t *src0, uint32_t src1, uint64_t *dst, // Dispatch the command to accelerator TsmExecute(&inst); + SYNCHRONOUS_INTRINSIC_SWITCH; // Destroy the command buffer. } @@ -447,6 +464,7 @@ void __UnEqualVS(uint64_t *src0, uint32_t src1, uint64_t *dst, // Dispatch the command to accelerator TsmExecute(&inst); + SYNCHRONOUS_INTRINSIC_SWITCH; // Destroy the command buffer. } @@ -469,6 +487,7 @@ void __GreaterEqualVS(uint64_t *src0, uint32_t src1, uint64_t *dst, // Dispatch the command to accelerator TsmExecute(&inst); + SYNCHRONOUS_INTRINSIC_SWITCH; // Destroy the command buffer. } @@ -491,6 +510,7 @@ void __GreaterVS(uint64_t *src0, uint32_t src1, uint64_t *dst, // Dispatch the command to accelerator TsmExecute(&inst); + SYNCHRONOUS_INTRINSIC_SWITCH; // Destroy the command buffer. } @@ -513,6 +533,7 @@ void __LessEqualVS(uint64_t *src0, uint32_t src1, uint64_t *dst, // Dispatch the command to accelerator TsmExecute(&inst); + SYNCHRONOUS_INTRINSIC_SWITCH; // Destroy the command buffer. } @@ -535,6 +556,7 @@ void __LessThenVS(uint64_t *src0, uint32_t src1, uint64_t *dst, // Dispatch the command to accelerator TsmExecute(&inst); + SYNCHRONOUS_INTRINSIC_SWITCH; // Destroy the command buffer. } diff --git a/third_party/tsingmicro/crt/lib/Tx81/relu.c b/third_party/tsingmicro/crt/lib/Tx81/relu.c index 83ad18f3e5..8e4050d3ad 100644 --- a/third_party/tsingmicro/crt/lib/Tx81/relu.c +++ b/third_party/tsingmicro/crt/lib/Tx81/relu.c @@ -27,6 +27,7 @@ void __Relu(uint64_t *src, uint64_t *dst, uint32_t elem_count, uint16_t fmt) { // Dispatch the command to accelerator TsmExecute(&inst); + SYNCHRONOUS_INTRINSIC_SWITCH; // Destroy the command buffer. } diff --git a/third_party/tsingmicro/crt/lib/Tx81/rotate180.c b/third_party/tsingmicro/crt/lib/Tx81/rotate180.c index 2f90834144..ea7dad21dd 100644 --- a/third_party/tsingmicro/crt/lib/Tx81/rotate180.c +++ b/third_party/tsingmicro/crt/lib/Tx81/rotate180.c @@ -32,6 +32,7 @@ void __Rotate180(uint64_t *src, uint16_t src_n, uint16_t src_h, uint16_t src_w, // Dispatch the command to accelerator TsmExecute(&inst); + SYNCHRONOUS_INTRINSIC_SWITCH; // Destroy the command buffer. } diff --git a/third_party/tsingmicro/crt/lib/Tx81/rotate270.c b/third_party/tsingmicro/crt/lib/Tx81/rotate270.c index 197e83bc1c..f5698d31e8 100644 --- a/third_party/tsingmicro/crt/lib/Tx81/rotate270.c +++ b/third_party/tsingmicro/crt/lib/Tx81/rotate270.c @@ -32,6 +32,7 @@ void __Rotate270(uint64_t *src, uint16_t src_n, uint16_t src_h, uint16_t src_w, // Dispatch the command to accelerator TsmExecute(&inst); + SYNCHRONOUS_INTRINSIC_SWITCH; // Destroy the command buffer. } diff --git a/third_party/tsingmicro/crt/lib/Tx81/rotate90.c b/third_party/tsingmicro/crt/lib/Tx81/rotate90.c index 6415eec838..bf135686ee 100644 --- a/third_party/tsingmicro/crt/lib/Tx81/rotate90.c +++ b/third_party/tsingmicro/crt/lib/Tx81/rotate90.c @@ -32,6 +32,7 @@ void __Rotate90(uint64_t *src, uint16_t src_n, uint16_t src_h, uint16_t src_w, // Dispatch the command to accelerator TsmExecute(&inst); + SYNCHRONOUS_INTRINSIC_SWITCH; // Destroy the command buffer. } diff --git a/third_party/tsingmicro/crt/lib/Tx81/rsqrt.c b/third_party/tsingmicro/crt/lib/Tx81/rsqrt.c index 2deedf9830..e78d400b94 100644 --- a/third_party/tsingmicro/crt/lib/Tx81/rsqrt.c +++ b/third_party/tsingmicro/crt/lib/Tx81/rsqrt.c @@ -29,6 +29,7 @@ void __RsqrtVV(uint64_t *src, uint64_t *dst, uint32_t elem_count, // Dispatch the command to accelerator TsmExecute(&inst); + SYNCHRONOUS_INTRINSIC_SWITCH; // Destroy the command buffer. } diff --git a/third_party/tsingmicro/crt/lib/Tx81/satrelu.c b/third_party/tsingmicro/crt/lib/Tx81/satrelu.c index 99067f88c2..fd378f0460 100644 --- a/third_party/tsingmicro/crt/lib/Tx81/satrelu.c +++ b/third_party/tsingmicro/crt/lib/Tx81/satrelu.c @@ -29,6 +29,7 @@ void __Satrelu(uint64_t *src, uint64_t *dst, uint32_t elem_count, // Dispatch the command to accelerator TsmExecute(&inst); + SYNCHRONOUS_INTRINSIC_SWITCH; // Destroy the command buffer. } diff --git a/third_party/tsingmicro/crt/lib/Tx81/sigmoid.c b/third_party/tsingmicro/crt/lib/Tx81/sigmoid.c index 26e499e420..1021adafee 100644 --- a/third_party/tsingmicro/crt/lib/Tx81/sigmoid.c +++ b/third_party/tsingmicro/crt/lib/Tx81/sigmoid.c @@ -29,6 +29,7 @@ void __Sigmoid(uint64_t *src, uint64_t *dst, uint32_t elem_count, // Dispatch the command to accelerator TsmExecute(&inst); + SYNCHRONOUS_INTRINSIC_SWITCH; // Destroy the command buffer. } diff --git a/third_party/tsingmicro/crt/lib/Tx81/sin.c b/third_party/tsingmicro/crt/lib/Tx81/sin.c index dbd820a09b..67708a493e 100644 --- a/third_party/tsingmicro/crt/lib/Tx81/sin.c +++ b/third_party/tsingmicro/crt/lib/Tx81/sin.c @@ -27,6 +27,7 @@ void __Sin(uint64_t *src, uint64_t *dst, uint32_t elem_count, uint16_t fmt) { // Dispatch the command to accelerator TsmExecute(&inst); + SYNCHRONOUS_INTRINSIC_SWITCH; // Destroy the command buffer. } diff --git a/third_party/tsingmicro/crt/lib/Tx81/softplus.c b/third_party/tsingmicro/crt/lib/Tx81/softplus.c index 020bc76a9a..3edcb1a955 100644 --- a/third_party/tsingmicro/crt/lib/Tx81/softplus.c +++ b/third_party/tsingmicro/crt/lib/Tx81/softplus.c @@ -30,6 +30,7 @@ void __Softplus(uint64_t *src, uint64_t *dst, uint32_t elem_count, // Dispatch the command to accelerator TsmExecute(&inst); + SYNCHRONOUS_INTRINSIC_SWITCH; // Destroy the command buffer. } diff --git a/third_party/tsingmicro/crt/lib/Tx81/sqrt.c b/third_party/tsingmicro/crt/lib/Tx81/sqrt.c index ee84f37cad..dccd4c3344 100644 --- a/third_party/tsingmicro/crt/lib/Tx81/sqrt.c +++ b/third_party/tsingmicro/crt/lib/Tx81/sqrt.c @@ -28,6 +28,7 @@ void __SqrtVV(uint64_t *src, uint64_t *dst, uint32_t elem_count, uint16_t fmt) { // Dispatch the command to accelerator TsmExecute(&inst); + SYNCHRONOUS_INTRINSIC_SWITCH; // Destroy the command buffer. } diff --git a/third_party/tsingmicro/crt/lib/Tx81/tanh.c b/third_party/tsingmicro/crt/lib/Tx81/tanh.c index c9da38e682..02f7294d24 100644 --- a/third_party/tsingmicro/crt/lib/Tx81/tanh.c +++ b/third_party/tsingmicro/crt/lib/Tx81/tanh.c @@ -27,6 +27,7 @@ void __Tanh(uint64_t *src, uint64_t *dst, uint32_t elem_count, uint16_t fmt) { // Dispatch the command to accelerator TsmExecute(&inst); + SYNCHRONOUS_INTRINSIC_SWITCH; // Destroy the command buffer. } diff --git a/third_party/tsingmicro/crt/lib/Tx81/tensornorm.c b/third_party/tsingmicro/crt/lib/Tx81/tensornorm.c index 815b4b670c..f3d256ebdb 100644 --- a/third_party/tsingmicro/crt/lib/Tx81/tensornorm.c +++ b/third_party/tsingmicro/crt/lib/Tx81/tensornorm.c @@ -32,6 +32,7 @@ void __TensorNorm(uint64_t *src, uint16_t src_n, uint16_t src_h, uint16_t src_w, // Dispatch the command to accelerator TsmExecute(&inst); + SYNCHRONOUS_INTRINSIC_SWITCH; // Destroy the command buffer. } diff --git a/third_party/tsingmicro/crt/lib/Tx81/tf32_bf16.c b/third_party/tsingmicro/crt/lib/Tx81/tf32_bf16.c index 9c816c945b..1d5cb13729 100644 --- a/third_party/tsingmicro/crt/lib/Tx81/tf32_bf16.c +++ b/third_party/tsingmicro/crt/lib/Tx81/tf32_bf16.c @@ -27,6 +27,7 @@ void __TF32_BF16(uint64_t *src, uint64_t *dst, uint32_t elem_count, // Dispatch the command to accelerator TsmExecute(&inst); + SYNCHRONOUS_INTRINSIC_SWITCH; // Destroy the command buffer. } diff --git a/third_party/tsingmicro/crt/lib/Tx81/tf32_fp16.c b/third_party/tsingmicro/crt/lib/Tx81/tf32_fp16.c index d2dabeac07..326b1bfc98 100644 --- a/third_party/tsingmicro/crt/lib/Tx81/tf32_fp16.c +++ b/third_party/tsingmicro/crt/lib/Tx81/tf32_fp16.c @@ -26,6 +26,7 @@ void __TF32_FP16(uint64_t *src, uint64_t *dst, uint32_t elem_count) { // Dispatch the command to accelerator TsmExecute(&inst); + SYNCHRONOUS_INTRINSIC_SWITCH; // Destroy the command buffer. } diff --git a/third_party/tsingmicro/crt/lib/Tx81/tf32_fp32.c b/third_party/tsingmicro/crt/lib/Tx81/tf32_fp32.c index 2d922f1949..c460be2ac8 100644 --- a/third_party/tsingmicro/crt/lib/Tx81/tf32_fp32.c +++ b/third_party/tsingmicro/crt/lib/Tx81/tf32_fp32.c @@ -26,6 +26,7 @@ void __TF32_FP32(uint64_t *src, uint64_t *dst, uint32_t elem_count) { // Dispatch the command to accelerator TsmExecute(&inst); + SYNCHRONOUS_INTRINSIC_SWITCH; // Destroy the command buffer. } diff --git a/third_party/tsingmicro/crt/lib/Tx81/tf32_int16.c b/third_party/tsingmicro/crt/lib/Tx81/tf32_int16.c index 52aeab697f..141ee97e58 100644 --- a/third_party/tsingmicro/crt/lib/Tx81/tf32_int16.c +++ b/third_party/tsingmicro/crt/lib/Tx81/tf32_int16.c @@ -27,6 +27,7 @@ void __TF32_INT16(uint64_t *src, uint64_t *dst, uint32_t elem_count, // Dispatch the command to accelerator TsmExecute(&inst); + SYNCHRONOUS_INTRINSIC_SWITCH; // Destroy the command buffer. } diff --git a/third_party/tsingmicro/crt/lib/Tx81/tf32_int32.c b/third_party/tsingmicro/crt/lib/Tx81/tf32_int32.c index 1545670616..2f7c5b963c 100644 --- a/third_party/tsingmicro/crt/lib/Tx81/tf32_int32.c +++ b/third_party/tsingmicro/crt/lib/Tx81/tf32_int32.c @@ -27,6 +27,7 @@ void __TF32_INT32(uint64_t *src, uint64_t *dst, uint32_t elem_count, // Dispatch the command to accelerator TsmExecute(&inst); + SYNCHRONOUS_INTRINSIC_SWITCH; // Destroy the command buffer. } diff --git a/third_party/tsingmicro/crt/lib/Tx81/tf32_int8.c b/third_party/tsingmicro/crt/lib/Tx81/tf32_int8.c index d427e16e20..6fc47bdf90 100644 --- a/third_party/tsingmicro/crt/lib/Tx81/tf32_int8.c +++ b/third_party/tsingmicro/crt/lib/Tx81/tf32_int8.c @@ -27,6 +27,7 @@ void __TF32_INT8(uint64_t *src, uint64_t *dst, uint32_t elem_count, // Dispatch the command to accelerator TsmExecute(&inst); + SYNCHRONOUS_INTRINSIC_SWITCH; // Destroy the command buffer. } diff --git a/third_party/tsingmicro/crt/lib/Tx81/transpose.c b/third_party/tsingmicro/crt/lib/Tx81/transpose.c index ba33b97f35..fc13e082a3 100644 --- a/third_party/tsingmicro/crt/lib/Tx81/transpose.c +++ b/third_party/tsingmicro/crt/lib/Tx81/transpose.c @@ -31,6 +31,7 @@ void __Transpose(uint64_t *src, uint64_t *dst, int32_t *src_shape, // Dispatch the command to accelerator TsmExecute(&inst); + SYNCHRONOUS_INTRINSIC_SWITCH; // Destroy the command buffer. } diff --git a/third_party/tsingmicro/include/triton-shared/Analysis/MaskAnalysis.h b/third_party/tsingmicro/include/triton-shared/Analysis/MaskAnalysis.h index 8a8f131f3e..c9c5d76cce 100644 --- a/third_party/tsingmicro/include/triton-shared/Analysis/MaskAnalysis.h +++ b/third_party/tsingmicro/include/triton-shared/Analysis/MaskAnalysis.h @@ -90,6 +90,13 @@ struct MaskState { LogicalResult addStates(const MaskState &lhsState, const MaskState &rhsState, Location loc, OpBuilder &builder); + LogicalResult subStateScalar(const MaskState &state, + const OpFoldResult scalar, Location loc, + OpBuilder &builder); + + LogicalResult subStates(const MaskState &lhsState, const MaskState &rhsState, + Location loc, OpBuilder &builder); + LogicalResult minStateScalar(const MaskState &lhsState, const MaskState &rhsState, Location loc, OpBuilder &builder); @@ -117,6 +124,11 @@ struct MaskState { // and end, dims remains unchanged, and scalar is empty. LogicalResult parseAdd(arith::AddIOp addOp, const Location loc, OpBuilder &builder); + // Operand is the result of subi + // One and only one of the operands should be a scalar. Decrement both start + // and end, dims remains unchanged, and scalar is empty. + LogicalResult parseSub(arith::SubIOp subOp, const Location loc, + OpBuilder &builder); // Operand is the result of andi // Each of the result state dims is smaller of the two operands' dims. // Insert instruction if needed to get new dims. diff --git a/third_party/tsingmicro/include/triton-shared/AnalysisStructured/PtrAnalysis.h b/third_party/tsingmicro/include/triton-shared/AnalysisStructured/PtrAnalysis.h index b5505c7ad7..38f5018741 100644 --- a/third_party/tsingmicro/include/triton-shared/AnalysisStructured/PtrAnalysis.h +++ b/third_party/tsingmicro/include/triton-shared/AnalysisStructured/PtrAnalysis.h @@ -65,6 +65,10 @@ struct PtrState { LogicalResult addState(const PtrState &lhsState, const PtrState &rhsState, Operation *op, OpBuilder &builder); + // Process subtraction of two PtrStates + LogicalResult subState(const PtrState &lhsState, const PtrState &rhsState, + Operation *op, OpBuilder &builder); + // Process multiplication of two PtrStates LogicalResult mulState(const PtrState &lhsState, const PtrState &rhsState, Operation *op, OpBuilder &builder); @@ -135,6 +139,19 @@ class PtrAnalysis { LogicalResult visitOperandAdd(arith::AddIOp addOp, PtrState &state, const Location loc, OpBuilder &builder); + // Operand is the result of arith.subi. Process both arguments and insert any + // arith.subi instruction as needed. + // Main assumptions: + // Only one of lhsState and rhsState has source field set + // Current PtrState should be empty + // Expected result: + // source = lhsState.source ? lhsState.source : rhsState.source + // sizes[i] = lhsState.sizes[i] (which should match rhsState.sizes[i]) + // offsets[i] = lhsState.offsets[i] - rhsState.offsets[i] + // strides[i] = lhsState.strides[i] - rhsState.strides[i] + LogicalResult visitOperandSub(arith::SubIOp subOp, PtrState &state, + const Location loc, OpBuilder &builder); + // Operand is the result of arith.muli. Process both arguments and insert any // arith.muli instruction as needed. // Main assumptions: diff --git a/third_party/tsingmicro/include/triton-shared/Conversion/TritonArithToLinalg/ConversionPatterns.h b/third_party/tsingmicro/include/triton-shared/Conversion/TritonArithToLinalg/ConversionPatterns.h index 3f8c0b59cf..e32bf4eb84 100644 --- a/third_party/tsingmicro/include/triton-shared/Conversion/TritonArithToLinalg/ConversionPatterns.h +++ b/third_party/tsingmicro/include/triton-shared/Conversion/TritonArithToLinalg/ConversionPatterns.h @@ -825,9 +825,12 @@ struct MakeRangeConverter : public OpConversionPattern { ValueRange{init}, indexingMaps, getNParallelLoopsAttrs(1), [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange blockArgs) { + Value start = rewriter.create( + loc, op.getStart()); // start Value index = nestedBuilder.create(loc, 0); Value res = nestedBuilder.create( - loc, type.getElementType(), index); + loc, type.getElementType(), + rewriter.create(loc, index, start)); nestedBuilder.create(loc, res); }); @@ -1311,6 +1314,58 @@ struct ReduceConverter : public OpConversionPattern { }); } + LogicalResult + convertMultiOpsReduction(triton::ReduceOp op, + typename triton::ReduceOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const { + + auto loc = op->getLoc(); + int64_t axis = op.getAxis(); + + SmallVector initTensors; + for (auto result : op->getResultTypes()) { + SmallVector shape = + isa(result) + ? SmallVector( + cast(result).getShape().begin(), + cast(result).getShape().end()) + : SmallVector{}; + Type type = isa(result) + ? cast(result).getElementType() + : result; + initTensors.push_back(rewriter.create(loc, shape, type)); + } + auto reduceOp = rewriter.create( + loc, op.getSrcs(), initTensors, SmallVector{axis}, + [&](OpBuilder &opBuilder, Location loc, ValueRange inputs) { + auto reduceBlock = op.getBody(); + IRMapping mapping; + mapping.map(reduceBlock->getArguments(), inputs); + for (auto &innerOp : reduceBlock->without_terminator()) { + opBuilder.clone(innerOp, mapping); + } + auto yield = reduceBlock->getTerminator(); + auto results = + llvm::map_to_vector(yield->getOperands(), [&](Value val) { + return mapping.lookup(val); + }); + opBuilder.create(loc, results); + }); + SmallVector results; + for (int i = 0; i < reduceOp->getNumResults(); i++) { + Value finalResult = + (isa(op->getResultTypes()[i])) + ? reduceOp->getResults()[i] + : rewriter + .create(loc, op->getResultTypes()[i], + reduceOp->getResults()[i]) + ->getResults()[0]; + results.push_back(finalResult); + } + rewriter.replaceOp(op, results); + return success(); + } + LogicalResult convertToLinalgReduce(triton::ReduceOp op, typename triton::ReduceOp::Adaptor adaptor, @@ -1322,11 +1377,13 @@ struct ReduceConverter : public OpConversionPattern { auto loc = op.getLoc(); auto reductionOps = getRedOps(op); + if (reductionOps.size() != 1) + return convertMultiOpsReduction(op, adaptor, rewriter); + // Reduction of arbitrary operations isn't supported because using the first // element across the reduction dimension requires us to iterate over a // subview that skips over each first element. - if (reductionOps.size() != 1 || - !isTritonAllowedReductionOp(reductionOps.front())) { + if (!isTritonAllowedReductionOp(reductionOps.front())) { return rewriter.notifyMatchFailure( op, "Only support lowering reduction with body " "containing 1 max(i/f) or addf."); @@ -2199,17 +2256,19 @@ struct ScanOpConverter Region &combineOp = op.getRegion(); bool reverse = op.getReverse(); - auto inputType = cast(inputs[0].getType()); - auto shape = inputType.getShape(); - + auto shape = cast(inputs[0].getType()).getShape(); SmallVector res(inputs.size()); std::transform(inputs.begin(), inputs.end(), res.begin(), [&](auto val) { + auto inputType = cast(val.getType()); + assert(inputType.getShape() == shape && + "All 1D input tensors must have the same shape"); return rewriter.create(loc, shape, inputType.getElementType()); }); SmallVector acc(inputs.size()); std::transform(inputs.begin(), inputs.end(), acc.begin(), [&](auto val) { + auto inputType = cast(val.getType()); return rewriter.create( loc, rewriter.getZeroAttr(inputType.getElementType())); }); @@ -2296,8 +2355,7 @@ struct ScanOpConverter Region &combineOp = op.getRegion(); bool reverse = op.getReverse(); - auto inputType = cast(inputs[0].getType()); - auto shape = inputType.getShape(); + auto shape = cast(inputs[0].getType()).getShape(); SmallVector resTypes; for (const auto &resTy : op.getResultTypes()) { @@ -2308,16 +2366,21 @@ struct ScanOpConverter // Initialize result tensors SmallVector res(inputs.size()); std::transform(inputs.begin(), inputs.end(), res.begin(), [&](auto val) { - return rewriter.create(loc, shape, + auto inputType = cast(val.getType()); + auto valShape = inputType.getShape(); + assert(shape[0] == valShape[0] && + "All input tensors must have the same leading dimension"); + return rewriter.create(loc, valShape, inputType.getElementType()); }); // Initialize accumulators as empty tensors of shape [1, ...] - SmallVector accShape({1}); - accShape.insert(accShape.end(), shape.begin() + 1, shape.end()); - SmallVector acc(inputs.size()); std::transform(inputs.begin(), inputs.end(), acc.begin(), [&](auto val) { + auto inputType = cast(val.getType()); + auto valShape = inputType.getShape(); + SmallVector accShape({1}); + accShape.insert(accShape.end(), shape.begin() + 1, shape.end()); return rewriter.create(loc, accShape, inputType.getElementType()); }); @@ -2369,7 +2432,7 @@ struct ScanOpConverter loc, RankedTensorType::get( sizeVal, - cast(resTypes[0]).getElementType()), + cast(val.getType()).getElementType()), val, dynOffsets, /*sizes*/ ValueRange(), /*strides*/ ValueRange(), /*static_offsets*/ diff --git a/third_party/tsingmicro/include/utils/ReduceScanCommon.h b/third_party/tsingmicro/include/utils/ReduceScanCommon.h index bbda913f59..8d7a44a16c 100644 --- a/third_party/tsingmicro/include/utils/ReduceScanCommon.h +++ b/third_party/tsingmicro/include/utils/ReduceScanCommon.h @@ -1,3 +1,4 @@ +#include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Utils/IndexingUtils.h" #include "mlir/Transforms/DialectConversion.h" #include "triton/Dialect/Triton/IR/Dialect.h" @@ -114,9 +115,7 @@ struct ReduceScanOpConversionBase : public OpConversionPattern { if (isa(type)) { auto temp = cast(type).getShape(); shape.insert(shape.end(), temp.begin(), temp.end()); - } else { - shape.push_back(1); - } + } // else shape is empty for scalar types auto &block = combineOp.getBlocks().front(); IRMapping map; // Map block arguments to the current inputs and accumulators. @@ -161,21 +160,41 @@ struct ReduceScanOpConversionBase : public OpConversionPattern { Value lookupMappedValue(IRMapping &localMap, Value val, ArrayRef shape, OpBuilder &rewriter) const { - Value res = localMap.lookupOrNull(val); - if (!res) { - // If value is not found then it's an invariant defined in the outer - // region. We check if it has been already translated and add a splat - // operation if it hasn't. - res = invariantsMap.lookupOrNull(val); - if (!res) { - auto ip = rewriter.saveInsertionPoint(); - rewriter.setInsertionPointAfterValue(val); - res = rewriter.create( - val.getLoc(), RankedTensorType::get(shape, val.getType()), val); - invariantsMap.map(val, res); - rewriter.restoreInsertionPoint(ip); - } + // First check in the local mapping + if (Value localMapped = localMap.lookupOrNull(val)) { + return localMapped; } + + // Delete invariantsMap lookup: val needs to transform differents shape + // tensor. For example, 64->32 needs tensor<32Xf32> , 32->16 needs + // tensor<16xf32>. + // TODO: Profile it to improve performance. Beacause aboved cases(64->32, + // 32->16) maybe create different buffers. + + // Then, if the value is of the expected shape, return it directly + Type valueType = val.getType(); + if ((!isa(valueType) && shape.empty()) || + (isa(valueType) && + cast(valueType).getShape() == shape)) { + // TODO: Check rank tensor when shape is empty. If shape is empty, should + // add extract op. + return val; + } + + // Finally, if value is not found then it's an invariant defined in the + // outer region. We check if it has been already translated and add a + // linalg.fill operation if value shape is different. + auto ip = rewriter.saveInsertionPoint(); + rewriter.setInsertionPointAfterValue(val); + auto ty = isa(valueType) + ? cast(valueType).getElementType() + : valueType; + auto empty = rewriter.create(val.getLoc(), shape, ty); + Value res = rewriter + .create(val.getLoc(), ValueRange{val}, + ValueRange{empty}) + .getResult(0); + rewriter.restoreInsertionPoint(ip); return res; } @@ -233,8 +252,9 @@ struct ReduceScanOpConversionBase : public OpConversionPattern { auto outputShape = outputType.getShape(); SmallVector res(inputs.size()); std::transform(inputs.begin(), inputs.end(), res.begin(), [&](auto val) { + auto valType = cast(val.getType()); return rewriter.create(loc, outputShape, - inputType.getElementType()); + valType.getElementType()); }); SmallVector loops; @@ -263,10 +283,11 @@ struct ReduceScanOpConversionBase : public OpConversionPattern { inputReassociation] = tensorTransform(rewriter, loc, loopIndices, inputType, axis); for (size_t i = 0; i < inputs.size(); ++i) { + auto valueType = cast(inputs[i].getType()); auto extractTensor = rewriter.create( loc, RankedTensorType::get(inputStaticSize, - inputType.getElementType()), + valueType.getElementType()), inputs[i], inputDynamicIndices, /*sizes*/ ValueRange(), /*strides*/ ValueRange(), SmallVector(inputShape.size(), ShapedType::kDynamic), @@ -283,8 +304,9 @@ struct ReduceScanOpConversionBase : public OpConversionPattern { tensorTransform(rewriter, loc, loopIndices, outputType, axis); for (size_t i = 0; i < res.size(); ++i) { - auto targetType = RankedTensorType::get(outputStaticSize, - inputType.getElementType()); + auto resType = cast(res[i].getType()); + auto targetType = + RankedTensorType::get(outputStaticSize, resType.getElementType()); // {shape[axis],shape[axis+1],..shape[rank]} -> // {1,1,..shape[axis],shape[axis+1],..shape[rank]} Value reshaped = rewriter.create( diff --git a/third_party/tsingmicro/lib/Analysis/MaskAnalysis.cpp b/third_party/tsingmicro/lib/Analysis/MaskAnalysis.cpp index c1aa8a6428..23439325b4 100644 --- a/third_party/tsingmicro/lib/Analysis/MaskAnalysis.cpp +++ b/third_party/tsingmicro/lib/Analysis/MaskAnalysis.cpp @@ -50,6 +50,8 @@ LogicalResult MaskState::parse(Value operand, const Location loc, return this->parseLoopIterArg(operand, loc, builder); } else if (auto op = operand.getDefiningOp()) { return this->parseExtSI(op, loc, builder); + } else if (auto op = operand.getDefiningOp()) { + return this->parseSub(op, loc, builder); } else { return failure(); } @@ -193,6 +195,37 @@ LogicalResult MaskState::addStateScalar(const MaskState &state, return success(); } +LogicalResult MaskState::subStateScalar(const MaskState &state, + const OpFoldResult scalar, Location loc, + OpBuilder &builder) { + start = subOFRs(state.start, scalar, loc, builder); + end = subOFRs(state.end, scalar, loc, builder); + dims = state.dims; + return success(); +} + +LogicalResult MaskState::subStates(const MaskState &lhsState, + const MaskState &rhsState, Location loc, + OpBuilder &builder) { + if (lhsState.scalar && rhsState.scalar) { + InFlightDiagnostic diag = + emitError(loc) << "Unexpected case where both lhs and rhs are scalars"; + return failure(); + } + + if (!lhsState.scalar && !rhsState.scalar) { + InFlightDiagnostic diag = + emitError(loc) + << "Unsupported scenario where neither lhs nor rhs is a scalar"; + return failure(); + } + + if (lhsState.scalar) + return subStateScalar(rhsState, lhsState.scalar, loc, builder); + else + return subStateScalar(lhsState, rhsState.scalar, loc, builder); +} + LogicalResult MaskState::addStates(const MaskState &lhsState, const MaskState &rhsState, Location loc, OpBuilder &builder) { @@ -311,6 +344,21 @@ LogicalResult MaskState::parseAdd(arith::AddIOp addOp, const Location loc, return this->addStates(lhsState, rhsState, loc, builder); } +LogicalResult MaskState::parseSub(arith::SubIOp subOp, const Location loc, + OpBuilder &builder) { + assert(this->isEmpty()); + + MaskState lhsState; + if (failed(lhsState.parse(subOp.getLhs(), loc, builder))) + return failure(); + + MaskState rhsState; + if (failed(rhsState.parse(subOp.getRhs(), loc, builder))) + return failure(); + + return this->subStates(lhsState, rhsState, loc, builder); +} + LogicalResult MaskState::parseAnd(arith::AndIOp andOp, const Location loc, OpBuilder &builder) { assert(this->isEmpty()); diff --git a/third_party/tsingmicro/lib/AnalysisStructured/PtrAnalysis.cpp b/third_party/tsingmicro/lib/AnalysisStructured/PtrAnalysis.cpp index 3a9df9b9e7..bb1681eb9a 100644 --- a/third_party/tsingmicro/lib/AnalysisStructured/PtrAnalysis.cpp +++ b/third_party/tsingmicro/lib/AnalysisStructured/PtrAnalysis.cpp @@ -110,9 +110,9 @@ LogicalResult PtrState::addState(const PtrState &lhsState, auto loc = op->getLoc(); if (lhsState.source && rhsState.source) { - op->emitRemark( + LLVM_DEBUG(op->emitRemark( "PtrAnalysis: do not support adding two pointer states that both " - "have base pointers"); + "have base pointers")); return failure(); } @@ -140,8 +140,9 @@ LogicalResult PtrState::addState(const PtrState &lhsState, // AddPtr where both lhs and rhs containing modulo operators not supported if (lhsState.hasModulo() && rhsState.hasModulo()) { - op->emitRemark("PtrAnalysis: do not support adding two pointer states " - "that both have modulo"); + LLVM_DEBUG( + op->emitRemark("PtrAnalysis: do not support adding two pointer states " + "that both have modulo")); return failure(); } @@ -188,17 +189,17 @@ LogicalResult PtrState::addState(const PtrState &lhsState, } else if (i == 0 && lhs->getRank() == 2 && rhs->scalar) { shape.push_back(lhs->shape[1]); shape.push_back(lhs->shape[0]); - op->emitWarning( + LLVM_DEBUG(op->emitWarning( "PtrAnalysis: allowing adding pointer state with modulo in dim 0 to " "another pointer state with offset in dim 0.\nPlease verify the " "operand that contains a scalar is meant to increment pointers in " "dim1. If that is not the case it WILL LEAD TO WRONG COMPILATION " "RESULTS.\n\nTo avoid this warning, use expand_dims (instead of " - "splat) to explicitly specify which dimension contains the scalar."); + "splat) to explicitly specify which dimension contains the scalar.")); break; } else { - op->emitRemark( - "PtrAnalysis: do not support adding to operand with modulo"); + LLVM_DEBUG(op->emitRemark( + "PtrAnalysis: do not support adding to operand with modulo")); return failure(); } } @@ -215,15 +216,15 @@ void PtrState::dump() const { llvm::dbgs() << "scalar: " << scalar << "\n"; } - llvm::dbgs() << "offsets: "; + llvm::dbgs() << "offsets:\n"; llvm::interleave(offsets, llvm::dbgs(), "\n"); - llvm::dbgs() << "\nstrides: "; + llvm::dbgs() << "\nstrides:\n"; llvm::interleave(strides, llvm::dbgs(), "\n"); - llvm::dbgs() << "\nsizes: "; + llvm::dbgs() << "\nsizes:\n"; llvm::interleave(sizes, llvm::dbgs(), "\n"); - llvm::dbgs() << "\nshape: "; + llvm::dbgs() << "\nshape:\n"; llvm::interleave(shape, llvm::dbgs(), "\n"); - llvm::dbgs() << "\norder: "; + llvm::dbgs() << "\norder:\n"; llvm::interleave(order, llvm::dbgs(), "\n"); llvm::dbgs() << "\n"; } @@ -238,15 +239,16 @@ LogicalResult PtrState::mulState(const PtrState &lhsState, // neither lhs nor rhs should have source, since multiplying base pointer // does not make sense if (lhsState.source && rhsState.source) { - op->emitRemark("PtrAnalysis: do not support multiplying base pointers"); + LLVM_DEBUG(op->emitRemark( + "PtrAnalysis: do not support multiplying base pointers")); return failure(); } // currently do not support both tensors are effectively non-scalar if (!lhsState.scalar && !rhsState.scalar) { - op->emitRemark( + LLVM_DEBUG(op->emitRemark( "PtrAnalysis: only support multiplying pointer states when one of " - "them represent a scalar"); + "them represent a scalar")); return failure(); } @@ -276,15 +278,79 @@ LogicalResult PtrState::mulState(const PtrState &lhsState, } if (rhs->hasModulo()) { - op->emitRemark( + LLVM_DEBUG(op->emitRemark( "PtrAnalysis: do not support multiplying pointer states that has " - "modulos"); + "modulos")); return failure(); } return success(); } +LogicalResult PtrState::subState(const PtrState &lhsState, + const PtrState &rhsState, Operation *op, + OpBuilder &builder) { + assert(isEmpty() && lhsState.getRank() == rhsState.getRank()); + auto loc = op->getLoc(); + + if (lhsState.source && rhsState.source) { + if (lhsState.source != rhsState.source) { + op->emitRemark("PtrAnalysis: subtracting pointers from different bases " + "is not supported"); + return failure(); + } + + if (lhsState.scalar && rhsState.scalar) { + auto subOp = + builder.create(loc, lhsState.scalar, rhsState.scalar); + scalar = subOp.getResult(); + } else if (lhsState.scalar) { + scalar = lhsState.scalar; + } else if (rhsState.scalar) { + auto zero = builder.create( + loc, 0, rhsState.scalar.getType()); + auto negOp = builder.create(loc, zero, rhsState.scalar); + scalar = negOp.getResult(); + } + + source = nullptr; + return success(); + } + + if (!lhsState.source && rhsState.source) { + op->emitRemark("PtrAnalysis: scalar minus pointer is not meaningful"); + return failure(); + } + + source = lhsState.source ? lhsState.source : rhsState.source; + + if (lhsState.scalar && rhsState.scalar) { + auto subOp = + builder.create(loc, lhsState.scalar, rhsState.scalar); + scalar = subOp.getResult(); + } else if (lhsState.getRank() == 0) { // both lhs and rhs are scalars + scalar = lhsState.scalar ? lhsState.scalar : rhsState.scalar; + } + + for (uint64_t i = 0; i < lhsState.getRank(); i++) { + auto newOffset = + subOFRs(lhsState.offsets[i], rhsState.offsets[i], loc, builder); + offsets.push_back(newOffset); + + auto newStride = + subOFRs(lhsState.strides[i], rhsState.strides[i], loc, builder); + strides.push_back(newStride); + + sizes.push_back(lhsState.sizes[i]); + } + + for (uint64_t i = 0; i < lhsState.getRank(); i++) { + shape.push_back(lhsState.shape[i]); + } + + return success(); +} + tts::MakeTensorPtrOp PtrState::createTTSMakeTensorPtrOp(OpBuilder &builder, Location loc) { SmallVector staticSizes; @@ -319,14 +385,44 @@ LogicalResult PtrAnalysis::visitOperandAdd(arith::AddIOp addOp, PtrState &state, // Checking for higher dimension is done in addState below if ((lhsState.getRank() == 1 && lhsState.hasModulo()) || (rhsState.getRank() == 1 && rhsState.hasModulo())) { - addOp->emitRemark( - "PtrAnalysis: do not support this pattern: a + arange(0, K) % M"); + LLVM_DEBUG(addOp->emitRemark( + "PtrAnalysis: do not support this pattern: a + arange(0, K) % M")); return failure(); } return state.addState(lhsState, rhsState, addOp, builder); } +LogicalResult PtrAnalysis::visitOperandSub(arith::SubIOp subOp, PtrState &state, + const Location loc, + OpBuilder &builder) { + PtrState lhsState; + if (visitOperand(subOp.getLhs(), lhsState, loc, builder).failed()) { + return failure(); + } + + PtrState rhsState; + if (visitOperand(subOp.getRhs(), rhsState, loc, builder).failed()) { + return failure(); + } + + if (lhsState.hasModulo() || rhsState.hasModulo()) { + LLVM_DEBUG( + subOp->emitRemark("PtrAnalysis: do not support modulo for subi op\n")); + return failure(); + } + + // Checking for higher dimension is done in subState below + if ((lhsState.getRank() == 1 && lhsState.hasModulo()) || + (rhsState.getRank() == 1 && rhsState.hasModulo())) { + subOp->emitRemark( + "PtrAnalysis: do not support this pattern: a - arange(0, K) % M"); + return failure(); + } + + return state.subState(lhsState, rhsState, subOp, builder); +} + LogicalResult PtrAnalysis::visitOperandMul(arith::MulIOp mulOp, PtrState &state, const Location loc, OpBuilder &builder) { @@ -354,8 +450,9 @@ LogicalResult PtrAnalysis::visitOperandRem(arith::RemSIOp remOp, } if (!rhsState.scalar) { - remOp->emitRemark("PtrAnalysis: only support cases when rhs of remainder " - "contains scalar"); + LLVM_DEBUG(remOp->emitRemark( + "PtrAnalysis: only support cases when rhs of remainder " + "contains scalar")); return failure(); } @@ -367,8 +464,8 @@ LogicalResult PtrAnalysis::visitOperandRem(arith::RemSIOp remOp, // would have already populated the modulo states after visiting the lhs. // Assert that all the modulo states are empty. if (state.hasModulo()) { - remOp->emitRemark( - "PtrAnalysis: do not support multiple modulo within an expression"); + LLVM_DEBUG(remOp->emitRemark( + "PtrAnalysis: do not support multiple modulo within an expression")); return failure(); } @@ -392,13 +489,13 @@ LogicalResult PtrAnalysis::visitOperandRem(arith::RemSIOp remOp, } else if (shape[1] == 1) { state.shape[0] = rhsState.scalar; } else { - remOp->emitRemark( + LLVM_DEBUG(remOp->emitRemark( "PtrAnalysis: taking modulo on a 2D tensor with no singleton " - "dimension not supported"); + "dimension not supported")); return failure(); } } else { - remOp->emitRemark("PtrAnalysis: unsupported modulo pattern"); + LLVM_DEBUG(remOp->emitRemark("PtrAnalysis: unsupported modulo pattern")); return failure(); } return success(); @@ -456,9 +553,9 @@ PtrAnalysis::visitOperandExpandDims(triton::ExpandDimsOp expandDimsOp, state.shape.insert(state.shape.begin() + axis, builder.getIndexAttr(0)); if (state.hasModulo() && state.getRank() > 2) { - expandDimsOp->emitRemark( + LLVM_DEBUG(expandDimsOp->emitRemark( "PtrAnalysis: unsupported scenario where expand_dims result " - "has modulo and rank > 2"); + "has modulo and rank > 2")); return failure(); } @@ -475,7 +572,8 @@ PtrAnalysis::visitOperandBroadcast(triton::BroadcastOp broadcastOp, auto dst = broadcastOp.getResult(); if (!isa(src.getType())) { - broadcastOp->emitRemark("PtrAnalysis: Unsupported broadcast source type"); + LLVM_DEBUG(broadcastOp->emitRemark( + "PtrAnalysis: Unsupported broadcast source type")); return failure(); } @@ -523,7 +621,7 @@ LogicalResult PtrAnalysis::visitOperandSplat(triton::SplatOp splatOp, state.shape.push_back(builder.getIndexAttr(0)); } } else { - splatOp->emitRemark("PtrAnalysis: unsupported splat pattern"); + LLVM_DEBUG(splatOp->emitRemark("PtrAnalysis: unsupported splat pattern")); return failure(); } @@ -533,8 +631,9 @@ LogicalResult PtrAnalysis::visitOperandSplat(triton::SplatOp splatOp, state.offsets[0] = state.scalar; if (state.hasModulo() && state.getRank() > 2) { - splatOp->emitRemark("PtrAnalysis: unsupported scenario where splat result " - "has modulo and rank > 2"); + LLVM_DEBUG(splatOp->emitRemark( + "PtrAnalysis: unsupported scenario where splat result " + "has modulo and rank > 2")); return failure(); } @@ -627,8 +726,8 @@ PtrAnalysis::visitOperandMakeTensorPtr(triton::MakeTensorPtrOp makeTPtrOp, state.source = makeTPtrOp.getBase(); if (makeTPtrOp.getOrder().empty()) { - makeTPtrOp->emitRemark( - "PtrAnalysis: expect tt.make_tensor_ptr to have order field set"); + LLVM_DEBUG(makeTPtrOp->emitRemark( + "PtrAnalysis: expect tt.make_tensor_ptr to have order field set")); return failure(); } @@ -671,9 +770,9 @@ LogicalResult PtrAnalysis::visitOperandForOp(scf::ForOp forOp, Value operand, auto newState = getLoopResultPtrState(forOp, index); if (failed(newState)) { - forOp.emitError( + LLVM_DEBUG(forOp.emitError( "Rewrite for-op failed. Could not find PtrState returned by " - "the loop."); + "the loop.")); return failure(); } @@ -731,7 +830,7 @@ LogicalResult PtrAnalysis::visitOperand(Value operand, PtrState &state, state.source = selectOp.getResult(); return success(); } else { - op->emitRemark("Unexpected operand defining operation"); + LLVM_DEBUG(op->emitRemark("Unexpected operand defining operation")); return failure(); } } else { @@ -744,6 +843,8 @@ LogicalResult PtrAnalysis::visitOperand(Value operand, PtrState &state, return visitOperandAdd(op, state, loc, builder); } else if (auto op = operand.getDefiningOp()) { return visitOperandMul(op, state, loc, builder); + } else if (auto op = operand.getDefiningOp()) { + return visitOperandSub(op, state, loc, builder); } else if (auto op = operand.getDefiningOp()) { return visitOperandMakeRange(op, state, loc, builder); } else if (auto op = operand.getDefiningOp()) { @@ -773,9 +874,10 @@ LogicalResult PtrAnalysis::visitOperand(Value operand, PtrState &state, state = knownPtrs[operand]; return success(); } else { - llvm::dbgs() << "PtrAnalysis: encountered addptr operand produced by an " - "unsupported operation\n"; - operand.dump(); + LLVM_DEBUG(llvm::dbgs() + << "PtrAnalysis: encountered addptr operand produced by an " + "unsupported operation\n"; + operand.dump()); return failure(); } } @@ -851,7 +953,8 @@ LogicalResult PtrAnalysis::rewriteAdvanceOp(triton::AdvanceOp op) { PtrState state; if (visitOperand(op->getOperand(0), state, loc, builder).failed()) { - op->emitRemark("PtrAnalysis: Failed to analyze ptr of tt.advance"); + LLVM_DEBUG( + op->emitRemark("PtrAnalysis: Failed to analyze ptr of tt.advance")); return failure(); } assert(state.isBlockPtr() && @@ -1010,9 +1113,9 @@ LogicalResult PtrAnalysis::rewriteForOp(scf::ForOp op) { // considered structured by PtrAnalysis, failing to retrieve the PtrState // should not fail the rewrite process. // We emit an error for diagnostics and debugging purposes. - op->emitWarning( + LLVM_DEBUG(op->emitWarning( "Rewrite for-op failed. Could not find PtrState for iter-arg index " + - std::to_string(i)); + std::to_string(i))); continue; } @@ -1055,8 +1158,8 @@ LogicalResult PtrAnalysis::rewriteForOp(scf::ForOp op) { // Recursively rewrite the inner ops if (rewriteOp(op).failed()) { - op->emitRemark( - "PtrAnalysis: update loop body failed when rewriting for op"); + LLVM_DEBUG(op->emitRemark( + "PtrAnalysis: update loop body failed when rewriting for op")); return failure(); } @@ -1067,13 +1170,20 @@ LogicalResult PtrAnalysis::rewriteGetStructuredStateOp(tts::GetStructuredStateOp op) { auto tritonValue = op->getOperand(0); + OpBuilder builder(op); + // If this triton value isn't known, it means PtrAnalysis has failed to // analyze this pointer. In such cases, simply remap all uses of the // structured value back to its original triton value. if (!knownPtrs.contains(tritonValue)) { - op.emitRemark( - "Rewrite GetStructuredStateOp failed. Could not find PtrState."); - op.getResult(0).replaceAllUsesWith(tritonValue); + LLVM_DEBUG(op.emitRemark( + "Rewrite GetStructuredStateOp failed. Could not find PtrState.")); + auto numResults = op.getNumResults(); + SmallVector replacements( + numResults, builder.create(op.getLoc(), + builder.getIndexAttr(0))); + replacements.front() = tritonValue; + op.getResults().replaceAllUsesWith(replacements); return failure(); } @@ -1082,7 +1192,6 @@ PtrAnalysis::rewriteGetStructuredStateOp(tts::GetStructuredStateOp op) { ptrMap.contains(tritonValue) ? ptrMap.lookup(tritonValue) : tritonValue; SmallVector replacements{remappedValue}; - OpBuilder builder(op); if (state.getRank() == 0) { // For scalar pointers, the scalar contains the offset and is the only @@ -1133,14 +1242,16 @@ LogicalResult PtrAnalysis::rewriteLoadOp(triton::LoadOp op, auto loc = op.getLoc(); if (!ptr) { - op->emitRemark("PtrAnalysis: pointer is not replace with tts.make_tptr so " - "loadOp cannot be rewritten"); + LLVM_DEBUG(op->emitRemark( + "PtrAnalysis: pointer is not replace with tts.make_tptr so " + "loadOp cannot be rewritten")); return failure(); } auto ptrType = dyn_cast(ptr.getType()); if (ptrType && !isa(ptrType.getPointeeType())) { - op->emitRemark("PtrAnalysis: scalar loadOp will not be rewritten"); + LLVM_DEBUG( + op->emitRemark("PtrAnalysis: scalar loadOp will not be rewritten")); return failure(); } @@ -1153,7 +1264,7 @@ LogicalResult PtrAnalysis::rewriteLoadOp(triton::LoadOp op, // are moving. if (mask) { if (mstate.parse(mask, loc, builder).failed()) { - op->emitRemark("MaskAnalysis failed"); + LLVM_DEBUG(op->emitRemark("MaskAnalysis failed")); return failure(); } dims = mstate.dims; @@ -1171,8 +1282,8 @@ LogicalResult PtrAnalysis::rewriteLoadOp(triton::LoadOp op, scalarOther = triton::getScalarValue(other, loc, builder); if (!scalarOther) { - op->emitRemark("other value used in masked load produced by " - "unsupported instruction"); + LLVM_DEBUG(op->emitRemark("other value used in masked load produced by " + "unsupported instruction")); return failure(); } } @@ -1277,14 +1388,16 @@ LogicalResult PtrAnalysis::rewriteStoreOp(triton::StoreOp op, auto loc = op.getLoc(); if (!ptr) { - op->emitRemark("PtrAnalysis: pointer is not replace with tts.make_tptr so " - "storeOp cannot be rewritten"); + LLVM_DEBUG(op->emitRemark( + "PtrAnalysis: pointer is not replace with tts.make_tptr so " + "storeOp cannot be rewritten")); return failure(); } auto ptrType = dyn_cast(ptr.getType()); if (ptrType && !isa(ptrType.getPointeeType())) { - op->emitRemark("PtrAnalysis: scalar storeOp will not be rewritten"); + LLVM_DEBUG( + op->emitRemark("PtrAnalysis: scalar storeOp will not be rewritten")); return failure(); } @@ -1297,7 +1410,7 @@ LogicalResult PtrAnalysis::rewriteStoreOp(triton::StoreOp op, // are moving. if (mask) { if (mstate.parse(mask, loc, builder).failed()) { - op->emitRemark("MaskAnalysis failed"); + LLVM_DEBUG(op->emitRemark("MaskAnalysis failed")); return failure(); } dims = mstate.dims; @@ -1329,14 +1442,16 @@ LogicalResult PtrAnalysis::rewriteAtomicRMWOp(triton::AtomicRMWOp op, auto loc = op.getLoc(); if (!ptr) { - op->emitRemark("PtrAnalysis: pointer is not replace with tts.make_tptr so " - "AtomicCASOp cannot be rewritten"); + LLVM_DEBUG(op->emitRemark( + "PtrAnalysis: pointer is not replace with tts.make_tptr so " + "AtomicCASOp cannot be rewritten")); return failure(); } auto ptrType = dyn_cast(ptr.getType()); if (ptrType && !isa(ptrType.getPointeeType())) { - op->emitRemark("PtrAnalysis: scalar AtomicCASOp will not be rewritten"); + LLVM_DEBUG(op->emitRemark( + "PtrAnalysis: scalar AtomicCASOp will not be rewritten")); return failure(); } @@ -1349,7 +1464,7 @@ LogicalResult PtrAnalysis::rewriteAtomicRMWOp(triton::AtomicRMWOp op, // are moving. if (mask) { if (mstate.parse(mask, loc, builder).failed()) { - op->emitRemark("MaskAnalysis failed"); + LLVM_DEBUG(op->emitRemark("MaskAnalysis failed")); return failure(); } dims = mstate.dims; @@ -1377,14 +1492,16 @@ LogicalResult PtrAnalysis::rewriteAtomicCASOp(triton::AtomicCASOp op) { auto loc = op.getLoc(); if (!ptr) { - op->emitRemark("PtrAnalysis: pointer is not replace with tts.make_tptr so " - "AtomicCASOp cannot be rewritten"); + LLVM_DEBUG(op->emitRemark( + "PtrAnalysis: pointer is not replace with tts.make_tptr so " + "AtomicCASOp cannot be rewritten")); return failure(); } auto ptrType = dyn_cast(ptr.getType()); if (ptrType && !isa(ptrType.getPointeeType())) { - op->emitRemark("PtrAnalysis: scalar AtomicCASOp will not be rewritten"); + LLVM_DEBUG(op->emitRemark( + "PtrAnalysis: scalar AtomicCASOp will not be rewritten")); return failure(); } @@ -1417,53 +1534,60 @@ LogicalResult PtrAnalysis::rewriteOp(Operation *rootOp, bool useUnsafeMask) { return TypeSwitch(op) .Case([&](auto addptr) { if (rewriteAddptrOp(addptr).failed()) { - addptr->emitRemark("PtrAnalysis: Failed to rewrite AddPtrOp"); + LLVM_DEBUG( + addptr->emitRemark("PtrAnalysis: Failed to rewrite AddPtrOp")); } return WalkResult::advance(); }) .Case([&](auto bitcast) { if (rewriteBitcastOp(bitcast).failed()) { - bitcast->emitRemark("PtrAnalysis: Failed to rewrite BitcastOp"); + LLVM_DEBUG(bitcast->emitRemark( + "PtrAnalysis: Failed to rewrite BitcastOp")); } return WalkResult::advance(); }) .Case([&](auto maketptr) { if (rewriteMakeTensorPtrOp(maketptr).failed()) { - maketptr->emitRemark( - "PtrAnalysis: Failed to rewrite MakeTensorPtrOp"); + LLVM_DEBUG(maketptr->emitRemark( + "PtrAnalysis: Failed to rewrite MakeTensorPtrOp")); } return WalkResult::advance(); }) .Case([&](auto advance) { if (rewriteAdvanceOp(advance).failed()) { - advance->emitRemark("PtrAnalysis: Failed to rewrite AdvanceOp"); + LLVM_DEBUG(advance->emitRemark( + "PtrAnalysis: Failed to rewrite AdvanceOp")); } return WalkResult::advance(); }) .Case([&](auto load) { if (rewriteLoadOp(load, useUnsafeMask).failed()) { - load->emitRemark("PtrAnalysis: Failed to rewrite LoadOp"); + LLVM_DEBUG( + load->emitRemark("PtrAnalysis: Failed to rewrite LoadOp")); return WalkResult::advance(); } return WalkResult::skip(); }) .Case([&](auto store) { if (rewriteStoreOp(store, useUnsafeMask).failed()) { - store->emitRemark("PtrAnalysis: Failed to rewrite StoreOp"); + LLVM_DEBUG( + store->emitRemark("PtrAnalysis: Failed to rewrite StoreOp")); return WalkResult::advance(); } return WalkResult::skip(); }) .Case([&](auto atomicRMW) { if (rewriteAtomicRMWOp(atomicRMW, useUnsafeMask).failed()) { - atomicRMW->emitRemark("PtrAnalysis: Failed to rewrite AtomicRMWOp"); + LLVM_DEBUG(atomicRMW->emitRemark( + "PtrAnalysis: Failed to rewrite AtomicRMWOp")); return WalkResult::advance(); } return WalkResult::skip(); }) .Case([&](auto atomicCAS) { if (rewriteAtomicCASOp(atomicCAS).failed()) { - atomicCAS->emitRemark("PtrAnalysis: Failed to rewrite AtomicCASOp"); + LLVM_DEBUG(atomicCAS->emitRemark( + "PtrAnalysis: Failed to rewrite AtomicCASOp")); return WalkResult::advance(); } return WalkResult::skip(); @@ -1474,7 +1598,8 @@ LogicalResult PtrAnalysis::rewriteOp(Operation *rootOp, bool useUnsafeMask) { // that the the walk does not visit the for-op's child operations // the second time. if (rewriteForOp(forOp).failed()) { - forOp->emitRemark("PtrAnalysis: Failed to rewrite ForOp"); + LLVM_DEBUG( + forOp->emitRemark("PtrAnalysis: Failed to rewrite ForOp")); } return WalkResult::skip(); }) @@ -1498,8 +1623,9 @@ LogicalResult PtrAnalysis::rewriteOp(Operation *rootOp, bool useUnsafeMask) { getStateOp->getLoc(), b))) { knownPtrs[tritonValue] = state; } else { - getStateOp->emitRemark("PtrAnalysis: Failed to populate ptr " - "state for tensor of indices"); + LLVM_DEBUG(getStateOp->emitRemark( + "PtrAnalysis: Failed to populate ptr " + "state for tensor of indices")); } } diff --git a/third_party/tsingmicro/lib/Conversion/AllocateSharedMemory/AllocateSharedMemoryPass.cpp b/third_party/tsingmicro/lib/Conversion/AllocateSharedMemory/AllocateSharedMemoryPass.cpp index e50aa9de6a..ffe50045ed 100644 --- a/third_party/tsingmicro/lib/Conversion/AllocateSharedMemory/AllocateSharedMemoryPass.cpp +++ b/third_party/tsingmicro/lib/Conversion/AllocateSharedMemory/AllocateSharedMemoryPass.cpp @@ -32,7 +32,9 @@ struct AllocateSharedMemory assert(op && "Value has no defining op"); if (isa(op)) return op; - assert(isa(op) || isa(op)); + // Memref op which has result: ViewLikeOpInterface. Eg: + // memref::ExpandShapeOp + assert(isa(op)); return findAlignmentRestrictOpOperandBuffer(op->getOperand(0)); } diff --git a/third_party/tsingmicro/lib/Conversion/LegalizeTensorFormLoops/LegalizeTensorFormLoops.cpp b/third_party/tsingmicro/lib/Conversion/LegalizeTensorFormLoops/LegalizeTensorFormLoops.cpp index b85733aa72..aaca9bf509 100644 --- a/third_party/tsingmicro/lib/Conversion/LegalizeTensorFormLoops/LegalizeTensorFormLoops.cpp +++ b/third_party/tsingmicro/lib/Conversion/LegalizeTensorFormLoops/LegalizeTensorFormLoops.cpp @@ -35,11 +35,20 @@ struct ForOpRewrite : public OpRewritePattern { auto itArg = forOp.getRegionIterArgs()[op.index()]; if (!isa(val.getType()) || val == itArg) continue; - auto copyOp = dyn_cast(val.getDefiningOp()); - // TODO: Use BufferizableOpInterface to analyze whether the operand is - // equivalent to the corresponding iter bbArg. - if (!copyOp || copyOp.getOutputs()[0] != itArg) { + bool insertCopy = false; + auto defOp = val.getDefiningOp(); + if (defOp) { + // TODO: Use BufferizableOpInterface to analyze whether the operand is + // equivalent to the corresponding iter bbArg. + auto copyOp = dyn_cast(defOp); + insertCopy = !copyOp || copyOp.getOutputs()[0] != itArg; + } else { + // BlockArgument && val != itArg + insertCopy = true; + } + + if (insertCopy) { auto reduceVal = rewriter.create(forOp.getLoc(), val, itArg); yieldOp->setOperand(op.index(), reduceVal->getResult(0)); diff --git a/third_party/tsingmicro/lib/Conversion/LinalgToMK/LinalgToMK.cpp b/third_party/tsingmicro/lib/Conversion/LinalgToMK/LinalgToMK.cpp index 495f9fc336..9a6bdde414 100644 --- a/third_party/tsingmicro/lib/Conversion/LinalgToMK/LinalgToMK.cpp +++ b/third_party/tsingmicro/lib/Conversion/LinalgToMK/LinalgToMK.cpp @@ -690,7 +690,7 @@ struct LinalgReduceToMKReduceConversion ? SmallVector{1, 1, inputShape4D[2], 1} : SmallVector{1, 1, inputShape4D[1], inputShape4D.back()}; - if (4 <= lastDim && lastDim <= 128 && !lastDimReduce) + if (4 <= lastDim && lastDim <= alignBase && !lastDimReduce) return src; if (!lastDimReduce && lastDim > alignBase) { @@ -847,7 +847,13 @@ struct ScalarGlobalStoreRewrite : public OpRewritePattern { auto val = op.getValue(); if (!isa(val.getType())) { - val = rewriter.create(op.getLoc(), val); + // NOTE: tensor::FromElementsOp will optimize to arith::ConstantOp which + // has dense constant attribute. + auto empty = rewriter.create( + op.getLoc(), SmallVector{1}, val.getType()); + + val = rewriter.create(op.getLoc(), val, empty, + ValueRange{zero}); } auto storeOp = rewriter.replaceOpWithNewOp( @@ -1749,10 +1755,10 @@ struct PowFOpRewrite : public OpRewritePattern { ->getResult(0); // Convert boolean masks to integer for bitwise operations - auto isBaseNegativeInt = buildLinalgElementwise( - rewriter, loc, intResultType, ValueRange{isBaseNegative}); - auto isIntegerLikeExponentInt = buildLinalgElementwise( - rewriter, loc, intResultType, ValueRange{isIntegerLikeExponent}); + auto isBaseNegativeInt = rewriter.create( + loc, intResultType, ValueRange{isBaseNegative}); + auto isIntegerLikeExponentInt = rewriter.create( + loc, intResultType, ValueRange{isIntegerLikeExponent}); auto canTakeAbsolute = buildLinalgElementwise( rewriter, loc, intResultType, ValueRange{isBaseNegativeInt, isIntegerLikeExponentInt}); @@ -1760,16 +1766,16 @@ struct PowFOpRewrite : public OpRewritePattern { // Check if integer-like exponent is odd : b % 2 != 0 auto isOddIndicator = computeOddFloatIndicator(rewriter, truncatedExponent, loc); - auto isOddIndicatorInt = buildLinalgElementwise( - rewriter, loc, intResultType, ValueRange{isOddIndicator}); + auto isOddIndicatorInt = rewriter.create( + loc, intResultType, ValueRange{isOddIndicator}); // Final condition: a < 0 & b is IntergerLike & b % 2 != 0 auto isResultNegative = buildLinalgElementwise( rewriter, loc, intResultType, ValueRange{canTakeAbsolute, isOddIndicatorInt}); - return buildLinalgElementwise( - rewriter, loc, resultType, ValueRange{isResultNegative}); + return rewriter.create(loc, resultType, + ValueRange{isResultNegative}); } LogicalResult matchAndRewrite(linalg::GenericOp op, @@ -3172,18 +3178,24 @@ struct ReduceOpToElementwiseOpConverter return finalResult; } + bool isInputsIncludeI1Type(ValueRange inputs) const { + return llvm::any_of(inputs, [](Value input) { + auto inputType = dyn_cast(input.getType()); + return inputType && inputType.getElementType().isInteger(1); + }); + } + SmallVector lower1DInput(ValueRange inputs, linalg::ReduceOp op, ConversionPatternRewriter &rewriter) const override { auto loc = op.getLoc(); - auto inputType = cast(inputs[0].getType()); - auto shape = inputType.getShape(); + auto leadingInputType = cast(inputs[0].getType()); + auto shape = leadingInputType.getShape(); int32_t tileSize = shape[0] > 64 ? 64 : shape[0]; SmallVector lastRes(inputs.size()); - assert(inputs.size() == 1 && "Expected only one input for 1D reduction"); // NOTE: Use scf.while may exist dynamic shape problem. // shape > 64: tiling n * 64, reduction n dim @@ -3192,25 +3204,37 @@ struct ReduceOpToElementwiseOpConverter // Reshape to 2D tensor with shape [tile, N / tile] // Call lowering leading dimension reduction SmallVector tiledShape = {shape[0] >> 6, 64}; - Value reshape = rewriter.create( - loc, RankedTensorType::get(tiledShape, inputType.getElementType()), - inputs[0], ArrayRef{{0, 1}}); - lastRes = lowerLeadingDimension({reshape}, op, rewriter); + SmallVector reshapedInputs; + for (auto input : inputs) { + auto inputType = cast(input.getType()); + Value reshape = rewriter.create( + loc, RankedTensorType::get(tiledShape, inputType.getElementType()), + input, ArrayRef{{0, 1}}); + reshapedInputs.push_back(reshape); + } + lastRes = lowerLeadingDimension(reshapedInputs, op, rewriter); } else { lastRes = inputs; } - if (inputType.getElementType().isInteger(1)) { + if (inputs.size() == 1 && leadingInputType.getElementType().isInteger(1)) { // TODO: Can optimized to only 8 elements - return lowerBool1DInput(rewriter, loc, inputType.getElementType(), + return lowerBool1DInput(rewriter, loc, leadingInputType.getElementType(), lastRes, op); } + assert(!isInputsIncludeI1Type(inputs) && + "I1 type inputs not supported for multi-op reductions: " + "byte-unaligned element access requires special handling"); + // TODO: Implement i1 type support for reduction operations by handling + // byte-unaligned element access in address calculation(lowerBool1DInput). + Region &combineOp = op.getRegion(); auto createExtractSliceOp = [&](Value val, SmallVector static_offsets, SmallVector static_size, SmallVector static_stride) { + auto inputType = cast(val.getType()); return rewriter.create( loc, RankedTensorType::get(static_size, inputType.getElementType()), val, ValueRange(), /*sizes*/ ValueRange(), @@ -3220,17 +3244,23 @@ struct ReduceOpToElementwiseOpConverter for (int32_t i = tileSize >> 1; i >= 1; i >>= 1) { auto idx = rewriter.create(loc, i); - auto curRes = createExtractSliceOp( - lastRes.front(), SmallVector{0}, SmallVector{i}, - SmallVector{1}); - auto RHS = createExtractSliceOp(lastRes.front(), SmallVector{i}, - SmallVector{i}, - SmallVector{1}); - lastRes = accumulate({RHS}, {curRes}, combineOp, rewriter); + SmallVector binaryInputs, binaryAcc; + for (auto &val : lastRes) { + auto curRes = createExtractSliceOp(val, SmallVector{0}, + SmallVector{i}, + SmallVector{1}); + auto RHS = createExtractSliceOp(val, SmallVector{i}, + SmallVector{i}, + SmallVector{1}); + binaryInputs.push_back(RHS); + binaryAcc.push_back(curRes); + } + lastRes = accumulate(binaryInputs, binaryAcc, combineOp, rewriter); } // Collapse the shape of the last result to a scalar tensor std::transform( lastRes.begin(), lastRes.end(), lastRes.begin(), [&](auto val) { + auto inputType = cast(val.getType()); return rewriter.create( loc, RankedTensorType::get({}, inputType.getElementType()), val, ArrayRef{}); @@ -3245,8 +3275,8 @@ struct ReduceOpToElementwiseOpConverter auto loc = op.getLoc(); Region &combineOp = op.getRegion(); - auto inputType = cast(inputs[0].getType()); - auto shape = inputType.getShape(); + auto leadingInputType = cast(inputs[0].getType()); + auto shape = leadingInputType.getShape(); // Initialize accumulators as empty tensors of shape [shape[1], ..] SmallVector accShape(shape.begin() + 1, shape.end()); @@ -3273,6 +3303,7 @@ struct ReduceOpToElementwiseOpConverter } std::transform(inputs.begin(), inputs.end(), acc.begin(), [&](auto val) { + auto inputType = cast(val.getType()); auto extract_tensor = rewriter.create( loc, RankedTensorType::get(sizeVal, inputType.getElementType()), val, /*offset*/ ValueRange(), /*sizes*/ ValueRange(), @@ -3307,6 +3338,7 @@ struct ReduceOpToElementwiseOpConverter std::transform( inputs.begin(), inputs.end(), subInputs.begin(), [&](auto val) { + auto inputType = cast(val.getType()); auto extract_tensor = b.create( loc, RankedTensorType::get(sizeVal, inputType.getElementType()), @@ -3503,6 +3535,99 @@ struct I1ExtSIOpRewrite : public OpRewritePattern { } }; +struct I1ToF32Rewrite : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(linalg::GenericOp op, + PatternRewriter &rewriter) const override { + + auto regionOps = triton::getRegionOps(op); + + if (regionOps.size() != 1 || !isa(regionOps.front())) + return rewriter.notifyMatchFailure(op, "only rewrite i1 to f32 op\n"); + + auto siToFP = cast(regionOps.front()); + + if (!siToFP->getOperandTypes()[0].isInteger(1) || + !siToFP->getResultTypes()[0].isF32()) + return rewriter.notifyMatchFailure(op, "only rewrite i1 to f32 op\n"); + + Location loc = op.getLoc(); + + auto input = op.getInputs()[0]; + auto inputType = cast(input.getType()); + auto resultType = cast(op->getResultTypes()[0]); + + auto f32Type = + RankedTensorType::get(inputType.getShape(), rewriter.getF32Type()); + auto empty = rewriter.create(loc, resultType.getShape(), + rewriter.getF32Type()); + + auto f32Result = + rewriter.create(loc, f32Type, input, empty)->getResult(0); + + rewriter.replaceOp(op, f32Result); + return success(); + } +}; + +struct FP32ToI1Rewrite : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(linalg::GenericOp op, + PatternRewriter &rewriter) const override { + + auto regionOps = triton::getRegionOps(op); + + if (regionOps.size() != 1 || !isa(regionOps.front())) + return rewriter.notifyMatchFailure(op, "only rewrite f32 to i1 op\n"); + + auto fpToSI = cast(regionOps.front()); + + if (!fpToSI->getOperandTypes()[0].isF32() || + !fpToSI->getResultTypes()[0].isInteger(1)) + return rewriter.notifyMatchFailure(op, "only rewrite f32 to i1 op\n"); + + Location loc = op.getLoc(); + + auto input = op.getInputs()[0]; + auto inputType = cast(input.getType()); + auto resultType = cast(op->getResultTypes()[0]); + + auto rank = inputType.getRank(); + SmallVector identityMaps( + 3, rewriter.getMultiDimIdentityMap(rank)); + SmallVector iterators( + rank, mlir::utils::IteratorType::parallel); + + auto I1Empty = rewriter.create(loc, resultType.getShape(), + rewriter.getIntegerType(1)); + + Value zeroF32Const = rewriter.create( + loc, APFloat(0.0f), rewriter.getF32Type()); + + auto zeroTensor = + rewriter + .create( + loc, ValueRange{zeroF32Const}, + ValueRange{rewriter.create( + loc, inputType.getShape(), rewriter.getF32Type())}) + .getResult(0); + + auto result = rewriter.create( + loc, TypeRange{resultType}, ValueRange{input, zeroTensor}, + ValueRange{I1Empty}, identityMaps, iterators, + [&](OpBuilder &builder, Location loc, ValueRange args) { + Value cmp = builder.create( + loc, arith::CmpFPredicate::ONE, args[0], args[1]); + builder.create(loc, cmp); + }); + + rewriter.replaceOp(op, result->getResult(0)); + return success(); + } +}; + struct AssertOpConverter : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; @@ -3626,7 +3751,9 @@ void mlir::triton::populateLinalgToMKTypeConversionPatterns( patterns.add( patterns.getContext(), precisionPriority /* precisionPriority */); patterns.add(patterns.getContext()); - patterns.add(patterns.getContext()); + patterns + .add( + patterns.getContext()); // TODO: if need precision mode patterns.add, CastArgMinMaxOpIOToFloatPattern>( diff --git a/third_party/tsingmicro/lib/Conversion/LinalgToMK/LinalgToMKPass.cpp b/third_party/tsingmicro/lib/Conversion/LinalgToMK/LinalgToMKPass.cpp index 17992beecc..8da126f1b9 100644 --- a/third_party/tsingmicro/lib/Conversion/LinalgToMK/LinalgToMKPass.cpp +++ b/third_party/tsingmicro/lib/Conversion/LinalgToMK/LinalgToMKPass.cpp @@ -93,7 +93,7 @@ class LinalgToMKPass : public triton::impl::LinalgToMKBase { auto reduceOps = llvm::map_to_vector(regionBlock->without_terminator(), [](Operation &op) { return &op; }); if (reduceOps.size() != 1) - return true; + return false; // TODO: Config according backend // TODO: Optimize for i1 reduction. i1 reduction is not supported // because memref.subviews may cause the offset to be inside the byte. diff --git a/third_party/tsingmicro/lib/Conversion/MKToTx81/MKToTx81.cpp b/third_party/tsingmicro/lib/Conversion/MKToTx81/MKToTx81.cpp index aba2f65003..6a774bbc3a 100644 --- a/third_party/tsingmicro/lib/Conversion/MKToTx81/MKToTx81.cpp +++ b/third_party/tsingmicro/lib/Conversion/MKToTx81/MKToTx81.cpp @@ -1082,8 +1082,9 @@ struct MKRelationVSOpConversionPattern : public OpConversionPattern { auto outputType = cast(output.getType()); if (outputType.getElementType().isInteger(1)) { elemCount = ((elemCount + 7) / 8) * 8; - op->emitRemark() << "element count was expanded to a multiple of 8, may " - "access memory out of bounds!"; + LLVM_DEBUG(op->emitRemark() + << "element count was expanded to a multiple of 8, may " + "access memory out of bounds!"); } auto inputPtr = createAddressFromMemref(rewriter, loc, input); @@ -1425,6 +1426,37 @@ struct ElementwiseConversion : public OpConversionPattern { return success(); } + LogicalResult convertUIToFPOp(linalg::GenericOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + auto inputType = + dyn_cast(op.getInputs()[0].getType()).getElementType(); + auto outputType = + dyn_cast(op.getOutputs()[0].getType()).getElementType(); + if (inputType.isInteger(1) && (outputType.isF32() || outputType.isF16())) { + Location loc = op.getLoc(); + auto [inputPtr, sizes, strides] = + createMetadata(rewriter, loc, adaptor.getInputs()[0]); + auto outputPtr = + createAddressFromMemref(rewriter, loc, adaptor.getOutputs()[0]); + + auto elemCount = calculateElemCount(rewriter, op->getLoc(), sizes); + + auto outputType = dyn_cast(op.getOutputs()[0].getType()); + Data_Format srcFmt = getFormatCode(outputType); + + rewriter.create(loc, rewriter.getI64Type(), inputPtr, + outputPtr, elemCount, + rewriter.getI16IntegerAttr(srcFmt)); + rewriter.eraseOp(op); + + return success(); + } else { + return rewriter.notifyMatchFailure( + op, "Unsupported input/output type combination for integer to " + "FP conversion"); + } + } + LogicalResult matchAndRewrite(linalg::GenericOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { @@ -1592,6 +1624,9 @@ struct ElementwiseConversion : public OpConversionPattern { "FP conversion"); } }) + .Case([&](auto elemWiseOp) { + return convertUIToFPOp(op, adaptor, rewriter); + }) .Case([&](auto elemWiseOp) { // TODO: Need add more int to fp convert. auto inputType = dyn_cast(op.getInputs()[0].getType()) diff --git a/third_party/tsingmicro/lib/Conversion/StructuredToMemref/StructuredToMemref.cpp b/third_party/tsingmicro/lib/Conversion/StructuredToMemref/StructuredToMemref.cpp index 697b3567da..1ced59142f 100644 --- a/third_party/tsingmicro/lib/Conversion/StructuredToMemref/StructuredToMemref.cpp +++ b/third_party/tsingmicro/lib/Conversion/StructuredToMemref/StructuredToMemref.cpp @@ -63,6 +63,16 @@ static memref::SubViewOp getSubview(int rank, ArrayRef dims, offsets, dims, strides); } +static OpFoldResult accumulateTargetOffset(tts::MakeTensorPtrOp op, + OpBuilder &b) { + Location loc = op->getLoc(); + OpFoldResult targetOffset = b.getIndexAttr(0); + for (auto o : op.getMixedOffsets()) { + targetOffset = addOFRs(targetOffset, o, loc, b); + } + return targetOffset; +} + namespace { struct MakeTensorPtrConverter @@ -124,270 +134,6 @@ struct MakeTensorPtrConverter return strides; } - static OpFoldResult accumulateTargetOffset(tts::MakeTensorPtrOp op, - OpBuilder &b) { - Location loc = op->getLoc(); - OpFoldResult targetOffset = b.getIndexAttr(0); - for (auto o : op.getMixedOffsets()) { - targetOffset = addOFRs(targetOffset, o, loc, b); - } - return targetOffset; - } - - std::pair - createSideBySideCastOps(tts::MakeTensorPtrOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - auto loc = op->getLoc(); - auto resultShape = cast(op.getType()).getShape(); - - auto targetOffset = - ofrToIndexValue(accumulateTargetOffset(op, rewriter), loc, rewriter); - - //////////////////////////////////////////////////////////////////////////// - // - // Handling side-by-side wraparound - // - // Note: We do not support cases where the target has already overflown the - // number of columns! This is because in PtrAnalysis, the offset has already - // been collapsed into a single dimension, so it is ambiguous to determine - // whether the offset actually overflows or just refers to an element on the - // subsequent rows. - // - // Same limitations apply to the stacked wraparound case. - // - //////////////////////////////////////////////////////////////////////////// - // - // nextOffset - targetOffset = colSize - // d1 + d2 = colSize - // N - // x clampedOffset - // --------------------------*----------------*-----* - // | | nextOffset (might - // | targetOffset | overflow) - // y *----- *----------------| - // | | | | - // M |----- -----------------| - // | d2 d1 | - // -------------------------------------------- - // - // x = targetOffset % N - // nextOffset = x + colSize - // clampedOffset = min(nextOffset, N) - // d1 = clampedOffset - x - // - //////////////////////////////////////////////////////////////////////////// - - auto resultType = getResultMemrefType( - op, /* offset */ ShapedType::kDynamic, - /* staticStrides */ - SmallVector(resultShape.size(), ShapedType::kDynamic), - /* result shape */ - SmallVector{ - - // Row stays the same, but mlir doesn't allow this anymore. Put - // dynamic. - ShapedType::kDynamic, - - // Column is dynamic, in most cases, this - // should be the same as the original column. - // The last chunk may be smaller due to - // wrapping around. - ShapedType::kDynamic}); - - Value rowSize = rewriter.create( - loc, rewriter.getIndexAttr(op.getSizes()[0])); - Value colSize = rewriter.create( - loc, rewriter.getIndexAttr(op.getSizes()[1])); - - Value modN = ofrToIndexValue(op.getMixedShape()[1], loc, rewriter); - - Value x = rewriter.create(loc, targetOffset, modN); - Value y = rewriter.create(loc, targetOffset, x); - - SmallVector strideVals = - ofrsToIndexValues(op.getMixedStrides(), loc, rewriter); - - // First chunk - Value nextOffset = rewriter.create(loc, x, colSize); - Value clampedOffset = - rewriter.create(loc, nextOffset, modN); - Value d1 = rewriter.create(loc, clampedOffset, x); - SmallVector sizes1{rowSize, d1}; - - auto cast1 = rewriter.create( - loc, resultType, adaptor.getBase(), targetOffset, sizes1, strideVals); - - // Second chunk - Value d2 = rewriter.create(loc, colSize, d1); - SmallVector sizes2{rowSize, d2}; - - auto cast2 = rewriter.create( - loc, resultType, adaptor.getBase(), y, sizes2, strideVals); - - return {cast1, cast2}; - } - - std::pair - createStackedCastOps(tts::MakeTensorPtrOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - - auto loc = op->getLoc(); - auto resultShape = cast(op.getType()).getShape(); - - assert(resultShape.size() == 2); - - auto targetOffset = - ofrToIndexValue(accumulateTargetOffset(op, rewriter), loc, rewriter); - - //////////////////////////////////////////////////////////////////////////// - // - // Handling stacked wraparound - // - // We do not support cases where the target offset has already overflown the - // number of rows. See side-by-side wraparound for details. - // - //////////////////////////////////////////////////////////////////////////// - // We're loading a tensor of dim (rowSize, colSize) - // d1 + d2 = rowSize - // d2 is the number of rows that overflow - // - // cols - // - // wrappedAroundOff - // --------------*------------*-------- - // | d2 | | | - // | |------------| | - // rows| | - // | | - // | targetOffset | - // | *------------| | - // | | | | - // | d1 | | | - // | | clampedOff | | - // --------------*--------------------- - // | overflow | - // *------------- - // nextOff - // - // wrappedAroundOff = targetOffset % cols - // clampedOff = (rows * strideRows) + wrappedAroundOff - // ~~~~~~~~~~~~~~~~~ - // ^ - // | - // We have already computed - // rows * strideRows = modRow = shape[1] - // in TritonToStructured - // - // clampedOff - targetOffset - // d1 = -------------------- - // strideRows - // - //////////////////////////////////////////////////////////////////////////// - // cols - // - // wrappedAroundOff - // --------------*------------*-------- - // | | - // | targetOffset | - // | *------------| | - // | | | | - // | | | | - // rows| rowSize | | | - // | | | | - // | | | | - // | *------------| | - // | nextOff | - // | | - // | clampedOff | - // --------------*--------------------- - // - // d1 = rowSize - // - // d2 = 0 - - auto resultType = getResultMemrefType( - op, /* offset */ ShapedType::kDynamic, - /* staticStrides */ - SmallVector(resultShape.size(), ShapedType::kDynamic), - /* result shape */ - SmallVector{ - // Row is dynamic, in most cases, this should - // be the same as the original row. The last - // chunk may be smaller due to wrapping - // around. - ShapedType::kDynamic, - - // Col stays the same, which is resultShape[1], but mlir doesn't - // allow this anymore. So we put dynamic instead. - ShapedType::kDynamic}); - - Value rowSize = rewriter.create( - loc, rewriter.getIndexAttr(op.getSizes()[0])); - Value colSize = rewriter.create( - loc, rewriter.getIndexAttr(op.getSizes()[1])); - - Value strideRow = ofrToIndexValue(op.getMixedStrides()[0], loc, rewriter); - Value strideCol = ofrToIndexValue(op.getMixedStrides()[1], loc, rewriter); - - Value modRow = op.getShape()[0]; - - // First chunk - Value wrappedAroundOff = - rewriter.create(loc, targetOffset, strideRow); - Value clampedOff = - rewriter.create(loc, modRow, wrappedAroundOff); - Value d1 = rewriter.create(loc, clampedOff, targetOffset); - d1 = rewriter.create(loc, d1, strideRow); - d1 = rewriter.create(loc, d1, rowSize); - - SmallVector sizes1{d1, colSize}; - memref::ReinterpretCastOp cast1 = - rewriter.create( - loc, resultType, adaptor.getBase(), targetOffset, sizes1, - ValueRange{strideRow, strideCol}); - - // Second chunk - Value d2 = rewriter.create(loc, rowSize, d1); - SmallVector sizes2{d2, colSize}; - memref::ReinterpretCastOp cast2 = - rewriter.create( - loc, resultType, adaptor.getBase(), wrappedAroundOff, sizes2, - ValueRange{strideRow, strideCol}); - - return {cast1, cast2}; - } - - LogicalResult rewriteSplitPtr(tts::MakeTensorPtrOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - - auto parentShape = op.getStaticShape(); - - SmallVector casts; - StringRef wrapType; - - if (parentShape[0] == ShapedType::kDynamic) { - // Stacked case - assert(parentShape[1] == 0); - auto [cast1, cast2] = createStackedCastOps(op, adaptor, rewriter); - casts = {cast1.getResult(), cast2.getResult()}; - wrapType = WRAP_STACKED; - } else { - assert(parentShape[0] == 0); - auto [cast1, cast2] = createSideBySideCastOps(op, adaptor, rewriter); - casts = {cast1.getResult(), cast2.getResult()}; - wrapType = WRAP_SIDE_BY_SIDE; - } - - auto combinedCast = rewriter.create( - op.getLoc(), op.getType(), casts); - - combinedCast->setAttr(wrapType, rewriter.getUnitAttr()); - - rewriter.replaceOp(op, combinedCast); - - return success(); - } - LogicalResult rewritePtr(ArrayRef resultShape, bool isBlockPtr, tts::MakeTensorPtrOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { @@ -450,142 +196,473 @@ struct MakeTensorPtrConverter } if (op.isSplitPtr()) { - return rewriteSplitPtr(op, adaptor, rewriter); + return success(); } return failure(); } }; -struct LoadConverter : public OpConversionPattern { -private: - using OpConversionPattern::OpConversionPattern; +memref::SubViewOp createSubview(Value src, ArrayRef offsets, + ArrayRef sizes, + ArrayRef strides, Location loc, + ConversionPatternRewriter &rewriter) { + auto srcType = cast(src.getType()); + auto dstType = + memref::SubViewOp::inferResultType(srcType, offsets, sizes, strides); + return rewriter.create(loc, cast(dstType), src, + offsets, sizes, strides); +} - void createSideBySideCopies(Value block1, Value block2, Value dst, - Location loc, - ConversionPatternRewriter &rewriter) const { - - auto zero = - rewriter.create(loc, rewriter.getIndexAttr(0)); - - auto one = - rewriter.create(loc, rewriter.getIndexAttr(1)); - - Value block1Row = rewriter.create(loc, block1, 0); - Value block1Col = rewriter.create(loc, block1, 1); - - Value block2Row = rewriter.create(loc, block2, 0); - Value block2Col = rewriter.create(loc, block2, 1); - - auto block1Dst = - rewriter.create(loc, dst, /* offsets */ - ValueRange{zero, zero}, - /* sizes */ - ValueRange{block1Row, block1Col}, - /* strides */ - ValueRange{one, one}); - - auto block2Dst = - rewriter.create(loc, dst, - /* offsets */ - ValueRange{zero, block1Col}, - /* sizes */ - ValueRange{block2Row, block2Col}, - /* strides */ - ValueRange{one, one}); - - rewriter.create(loc, block1, block1Dst); - rewriter.create(loc, block2, block2Dst); +Value createCastOps(tts::MakeTensorPtrOp op, + ConversionPatternRewriter &rewriter, Value start, + SmallVector sizesValues, + SmallVector strideVals) { + + Type elemType = + cast(op.getBase().getType()).getPointeeType(); + + auto unrankedMemrefType = UnrankedMemRefType::get(elemType, 0); + // WARNING: TypeConverter cannot automatically insert + // `UnrealizedConversionCastOp` through the materialization mechanism. + auto unrankedMemref = rewriter + .create( + op->getLoc(), unrankedMemrefType, op.getBase()) + ->getResults()[0]; + + auto layout = StridedLayoutAttr::get( + op.getContext(), ShapedType::kDynamic, + SmallVector(sizesValues.size(), ShapedType::kDynamic)); + MemRefType resultType = MemRefType::get( + SmallVector(sizesValues.size(), ShapedType::kDynamic), elemType, + layout); + auto block = rewriter.create( + op->getLoc(), resultType, unrankedMemref, start, sizesValues, strideVals); + return block; +} + +std::pair +getMemSubviews(SmallVector &dims, Value block, Location loc, + int64_t splitDim, ConversionPatternRewriter &rewriter) { + + auto rank = dims.size(); + OpFoldResult maskSize = + rewriter.create(loc, block, splitDim).getResult(); + + OpFoldResult subviewDimFull = dims[splitDim]; + OpFoldResult subviewDim = minOFRs(maskSize, subviewDimFull, loc, rewriter); + + SmallVector offsets(rank, rewriter.getIndexAttr(0)); + SmallVector strides(rank, rewriter.getIndexAttr(1)); + + SmallVector sizes(dims.begin(), dims.end()); + sizes[splitDim] = subviewDim; + + auto sv = createSubview(block, offsets, sizes, strides, loc, rewriter); + auto remainMask = rewriter.create( + loc, ofrToIndexValue(subviewDimFull, loc, rewriter), + ofrToIndexValue(subviewDim, loc, rewriter)); + + return {sv, remainMask}; +} + +void createMemCopies(Value block, Value dst, Location loc, + ConversionPatternRewriter &rewriter, Value &dstOffset, + int64_t splitDim, bool isLoadToDst) { + auto zero = rewriter.create(loc, rewriter.getIndexAttr(0)); + + auto one = rewriter.create(loc, rewriter.getIndexAttr(1)); + + auto rank = cast(dst.getType()).getRank(); + SmallVector blockShape; + for (int i = 0; i < rank; i++) { + blockShape.push_back(rewriter.create(loc, block, i)); } - void createStackedCopies(Value block1, Value block2, Value dst, Location loc, - ConversionPatternRewriter &rewriter) const { + SmallVector dstOffsets(rank, zero); + dstOffsets[splitDim] = dstOffset; + + auto blockDst = + rewriter.create(loc, dst, + /* offsets */ + dstOffsets, + /* sizes */ + blockShape, + /* strides */ + SmallVector(rank, one)); + dstOffset = + rewriter.create(loc, dstOffset, blockShape[splitDim]); + + if (isLoadToDst) { + rewriter.create(loc, block, blockDst); + } else { + rewriter.create(loc, blockDst, block); + } +} - auto zero = - rewriter.create(loc, rewriter.getIndexAttr(0)); - auto one = - rewriter.create(loc, rewriter.getIndexAttr(1)); - - Value block1Row = rewriter.create(loc, block1, 0); - Value block1Col = rewriter.create(loc, block1, 1); - - Value block2Row = rewriter.create(loc, block2, 0); - Value block2Col = rewriter.create(loc, block2, 1); - - auto block1Dst = - rewriter.create(loc, dst, /* offsets */ - ValueRange{zero, zero}, - /* sizes */ - ValueRange{block1Row, block1Col}, - /* strides */ - ValueRange{one, one}); - - auto block2Dst = - rewriter.create(loc, dst, - /* offsets */ - ValueRange{block1Row, zero}, - /* sizes */ - ValueRange{block2Row, block2Col}, - /* strides */ - ValueRange{one, one}); - - rewriter.create(loc, block1, block1Dst); - rewriter.create(loc, block2, block2Dst); +Value processMemSubviewCopies(Location loc, ConversionPatternRewriter &rewriter, + Value alloc, Value block, + SmallVector mixedDims, + Value &allocOffset, int64_t splitDim, + bool isLoadToDst) { + + Value subview; + Value remainMask; + if (mixedDims.empty()) { + subview = block; + } else { + auto res = getMemSubviews(mixedDims, block, loc, splitDim, rewriter); + subview = res.first; + remainMask = res.second; } - memref::SubViewOp createSubview(Value src, ArrayRef offsets, - ArrayRef sizes, - ArrayRef strides, Location loc, - ConversionPatternRewriter &rewriter) const { - auto srcType = cast(src.getType()); - auto dstType = - memref::SubViewOp::inferResultType(srcType, offsets, sizes, strides); - return rewriter.create(loc, cast(dstType), - src, offsets, sizes, strides); + createMemCopies(subview, alloc, loc, rewriter, allocOffset, splitDim, + isLoadToDst); + return remainMask; +} + +void rewriteSideBySideMemAccess(tts::MakeTensorPtrOp makeTensorPtrOp, + ConversionPatternRewriter &rewriter, + Value alloc, + SmallVector mixedDims, + bool isLoadToDst) { + assert(makeTensorPtrOp.getStaticShape().size() == 1 || + makeTensorPtrOp.getStaticShape()[0] == 0); + auto loc = makeTensorPtrOp->getLoc(); + auto targetOffset = ofrToIndexValue( + accumulateTargetOffset(makeTensorPtrOp, rewriter), loc, rewriter); + + //////////////////////////////////////////////////////////////////////////// + // + // Handling side-by-side wraparound + // + // Same limitations apply to the stacked wraparound case. + // + //////////////////////////////////////////////////////////////////////////// + // + // nextOffset - targetOffset = colSize + // d1 + d2 = colSize + // N + // x clampedOffset + // --------------------------*----------------*-----* + // | | nextOffset (might + // | targetOffset | overflow) + // y *----- *----------------| + // | | | | + // M |----- -----------------| + // | d2 d1 | + // -------------------------------------------- + // + // x = targetOffset % N + // offset_dim_0 = scaled_offset_0 + // offset_dim_1 = scaled_offset_1 + // col_start = scaled_offset_1 % N + // remainSize = colSize + // size = N - col_start + // while (remainSize > size): + // reinterpret (col_start, size, stride ) + // remainSize = remainSize - size + // col_start = (scaled_offset_0 + size) %N + // size = N + // + // reinterpret (col_start, remainSize, stride ) + // + //////////////////////////////////////////////////////////////////////////// + auto rank = cast(makeTensorPtrOp.getType()).getRank(); + SmallVector scaledOffset(rank); + auto offsets = makeTensorPtrOp.getMixedOffsets(); + std::transform(offsets.begin(), offsets.end(), scaledOffset.begin(), + [&](auto val) { return ofrToIndexValue(val, loc, rewriter); }); + auto lastDim = rank - 1; + + assert(rank == makeTensorPtrOp.getSizes().size()); + // Data block shape to be read + auto sizesInt = makeTensorPtrOp.getSizes(); + SmallVector sizesValues(rank); + std::transform(sizesInt.begin(), sizesInt.end(), sizesValues.begin(), + [&](auto val) { + return rewriter.create( + loc, rewriter.getIndexAttr(val)); + }); + // Total side by side size + Value totalSize = sizesValues.back(); + + // NOTE: We use `scaledOffset[lastdim]` for modulo because the loop and the + // modulo dimension cannot be in the same dimension (otherwise ptrAnalysis + // cannot analyze `make_tptr` of `splitMemory`). + Value N = + ofrToIndexValue(makeTensorPtrOp.getMixedShape()[lastDim], loc, rewriter); + Value x = rewriter.create(loc, scaledOffset[lastDim], N); + Value y = + rewriter.create(loc, targetOffset, scaledOffset[lastDim]); + SmallVector strideVals = + ofrsToIndexValues(makeTensorPtrOp.getMixedStrides(), loc, rewriter); + + Value remainSize = totalSize; + Value size = rewriter.create(loc, N, x); + Value colStart = rewriter.create(loc, y, x); + Value allocOffset = rewriter.create(loc, 0); + SmallVector typeR{remainSize.getType(), size.getType(), + colStart.getType(), allocOffset.getType()}; + SmallVector valueR{remainSize, size, colStart, allocOffset}; + if (!mixedDims.empty()) { + Value mixedDim = ofrsToIndexValues(mixedDims[lastDim], loc, rewriter)[0]; + typeR.push_back(mixedDim.getType()); + valueR.push_back(mixedDim); + } + auto whileOp = rewriter.create( + loc, typeR, valueR, + /*beforeBuilder=*/ + [&](OpBuilder &b, Location loc, ValueRange args) { + Value cond = b.create(loc, arith::CmpIPredicate::sgt, + args[0], args[1]); + b.create(loc, cond, args); + }, + /*afterBuilder=*/ + [&](OpBuilder &b, Location loc, ValueRange args) { + Value remainSize = args[0]; + Value size = args[1]; + Value colStart = args[2]; + Value allocOffset = args[3]; + sizesValues[lastDim] = size; + if (!mixedDims.empty()) { + Value mixedDim = args[4]; + mixedDims[lastDim] = mixedDim; + } + Value block = createCastOps(makeTensorPtrOp, rewriter, colStart, + sizesValues, strideVals); + Value mixedDim = + processMemSubviewCopies(loc, rewriter, alloc, block, mixedDims, + allocOffset, lastDim, isLoadToDst); + remainSize = b.create(loc, remainSize, size); + colStart = b.create(loc, colStart, size); + colStart = b.create(loc, colStart, N); + colStart = b.create(loc, colStart, y); + size = N; + SmallVector newArgs{remainSize, size, colStart, allocOffset}; + if (!mixedDims.empty()) { + newArgs.push_back(mixedDim); + } + b.create(loc, newArgs); + }); + + remainSize = whileOp->getResult(0); + colStart = whileOp->getResult(2); + sizesValues[lastDim] = remainSize; + allocOffset = whileOp->getResult(3); + if (!mixedDims.empty()) { + mixedDims[lastDim] = whileOp->getResult(4); } - std::pair - getSideBySideSubviews(ArrayRef dims, Value block1, Value block2, - Location loc, - ConversionPatternRewriter &rewriter) const { - OpFoldResult subviewRowFull = dims[0]; - OpFoldResult subviewColFull = dims[1]; - OpFoldResult col1 = - rewriter.create(loc, block1, 1).getResult(); - OpFoldResult subviewCol1 = minOFRs(col1, subviewColFull, loc, rewriter); - OpFoldResult subviewCol2 = - subOFRs(subviewColFull, subviewCol1, loc, rewriter); - - SmallVector offsets(dims.size(), rewriter.getIndexAttr(0)); - SmallVector strides(dims.size(), rewriter.getIndexAttr(1)); - auto sv1 = createSubview(block1, offsets, {subviewRowFull, subviewCol1}, - strides, loc, rewriter); - auto sv2 = createSubview(block2, offsets, {subviewRowFull, subviewCol2}, - strides, loc, rewriter); - - return {sv1, sv2}; + Value block = createCastOps(makeTensorPtrOp, rewriter, colStart, sizesValues, + strideVals); + processMemSubviewCopies(loc, rewriter, alloc, block, mixedDims, allocOffset, + lastDim, isLoadToDst); +} + +void rewriteStackedMemAccess(tts::MakeTensorPtrOp makeTensorPtrOp, + ConversionPatternRewriter &rewriter, Value alloc, + SmallVector mixedDims, + bool isLoadToDst) { + assert(makeTensorPtrOp.getStaticShape()[1] == 0); + + auto loc = makeTensorPtrOp->getLoc(); + auto resultShape = + cast(makeTensorPtrOp.getType()).getShape(); + + assert(resultShape.size() == 2); + auto rank = cast(makeTensorPtrOp.getType()).getRank(); + auto targetOffset = ofrToIndexValue( + accumulateTargetOffset(makeTensorPtrOp, rewriter), loc, rewriter); + + //////////////////////////////////////////////////////////////////////////// + // + // Handling stacked wraparound + // See side-by-side wraparound for details. + // + //////////////////////////////////////////////////////////////////////////// + // We're loading a tensor of dim (rowSize, colSize) + // d1 + d2 = rowSize + // d2 is the number of rows that overflow + // + // cols + // + // wrappedAroundOff + // --------------*------------*-------- + // | d2 | | | + // | |------------| | + // rows| | + // | | + // | targetOffset | + // | *------------| | + // | | | | + // | d1 | | | + // | | clampedOff | | + // --------------*--------------------- + // | overflow | + // *------------- + // nextOff + // + // wrappedAroundOff = targetOffset % cols + // clampedOff = (rows * strideRows) + wrappedAroundOff + // ~~~~~~~~~~~~~~~~~ + // ^ + // | + // We have already computed + // rows * strideRows = modRow = shape[1] + // in TritonToStructured + // + // clampedOff - targetOffset + // d1 = -------------------- + // strideRows + // + // N = stride[0] + // M = shape[0] % N + // row_start = targetOffset / N % M + // remainSize = rowsize + // size = M - row_start + // start = row_start * N + targetOffset % N + // while (remainSize > size): + // reinterpret (start, size, stride ) + // remainSize = remainSize - size + // start = scaled_offset_1 + // size = M + // reinterpret (start, remainSize, stride ) + //////////////////////////////////////////////////////////////////////////// + // cols + // + // wrappedAroundOff + // --------------*------------*-------- + // | | + // | targetOffset | + // | *------------| | + // | | | | + // | | | | + // rows| rowSize | | | + // | | | | + // | | | | + // | *------------| | + // | nextOff | + // | | + // | clampedOff | + // --------------*--------------------- + // + // d1 = rowSize + // + // d2 = 0 + Value modM = makeTensorPtrOp.getShape()[0]; + Value N = + ofrToIndexValue(makeTensorPtrOp.getMixedStrides()[0], loc, rewriter); + + auto sizesInt = makeTensorPtrOp.getSizes(); + SmallVector sizesValues(rank); + std::transform(sizesInt.begin(), sizesInt.end(), sizesValues.begin(), + [&](auto val) { + return rewriter.create( + loc, rewriter.getIndexAttr(val)); + }); + SmallVector strideVals = + ofrsToIndexValues(makeTensorPtrOp.getMixedStrides(), loc, rewriter); + + // NOTE: Here, we need to use `targetOffset` for integer division, instead of + // `scaledOffset[1]` as in `sidebyside`. This is because the offset of the + // column loop analyzed by ptrAnalysis is added to `scaledOffset[0]` (row), + // and we assume that the offset in the 1-dimensional dimension must be less + // than N. + Value M = rewriter.create(loc, modM, N); + Value rowStart = rewriter.create(loc, targetOffset, N); + rowStart = rewriter.create(loc, rowStart, M); + Value colStart = rewriter.create(loc, targetOffset, N); + + Value remainSize = rewriter.create( + loc, rewriter.getIndexAttr(makeTensorPtrOp.getSizes()[0])); + Value size = rewriter.create(loc, M, rowStart); + Value start = rewriter.create(loc, rowStart, N); + start = rewriter.create(loc, start, colStart); + Value allocOffset = rewriter.create(loc, 0); + SmallVector typeR{remainSize.getType(), size.getType(), start.getType(), + allocOffset.getType()}; + SmallVector valueR{remainSize, size, start, allocOffset}; + if (!mixedDims.empty()) { + Value mixedDim = ofrsToIndexValues(mixedDims[0], loc, rewriter)[0]; + typeR.push_back(mixedDim.getType()); + valueR.push_back(mixedDim); + } + auto whileOp = rewriter.create( + loc, typeR, valueR, + /*beforeBuilder=*/ + [&](OpBuilder &b, Location loc, ValueRange args) { + Value cond = b.create(loc, arith::CmpIPredicate::sgt, + args[0], args[1]); + b.create(loc, cond, args); + }, + /*afterBuilder=*/ + [&](OpBuilder &b, Location loc, ValueRange args) { + Value remainSize = args[0]; + Value size = args[1]; + Value start = args[2]; + Value allocOffset = args[3]; + sizesValues[0] = size; + if (!mixedDims.empty()) { + Value mixedDim = args[4]; + mixedDims[0] = mixedDim; + } + Value block = createCastOps(makeTensorPtrOp, rewriter, start, + sizesValues, strideVals); + Value mixedDim = processMemSubviewCopies( + loc, rewriter, alloc, block, mixedDims, allocOffset, + 0 /*dim of mod row*/, isLoadToDst); + remainSize = b.create(loc, remainSize, size); + Value addOffsets = b.create(loc, size, N); + start = b.create(loc, start, addOffsets); + start = b.create(loc, start, modM); + size = M; + SmallVector newArgs{remainSize, size, start, allocOffset}; + if (!mixedDims.empty()) { + newArgs.push_back(mixedDim); + } + b.create(loc, newArgs); + }); + remainSize = whileOp->getResult(0); + start = whileOp->getResult(2); + sizesValues[0] = remainSize; + allocOffset = whileOp->getResult(3); + if (!mixedDims.empty()) { + mixedDims[0] = whileOp->getResult(4); + } + Value block = + createCastOps(makeTensorPtrOp, rewriter, start, sizesValues, strideVals); + processMemSubviewCopies(loc, rewriter, alloc, block, mixedDims, allocOffset, + 0 /*dim of mod row*/, isLoadToDst); +} + +void rewriteMakeTensorPtrAndMemAccess(tts::MakeTensorPtrOp makeTensorPtrOp, + ConversionPatternRewriter &rewriter, + Value alloc, + SmallVector mixedDims, + bool isLoadToDst) { + auto parentShape = makeTensorPtrOp.getStaticShape(); + if (parentShape.size() > 1 && parentShape[0] == ShapedType::kDynamic) { + rewriteStackedMemAccess(makeTensorPtrOp, rewriter, alloc, mixedDims, + isLoadToDst); + } else { + rewriteSideBySideMemAccess(makeTensorPtrOp, rewriter, alloc, mixedDims, + isLoadToDst); } +} - std::pair - getStackedSubviews(ArrayRef dims, Value block1, Value block2, - const Location loc, - ConversionPatternRewriter &rewriter) const { - OpFoldResult subviewRowFull = dims[0]; - OpFoldResult subviewColFull = dims[1]; - OpFoldResult row1 = - rewriter.create(loc, block1, 0).getResult(); - OpFoldResult subviewRow1 = minOFRs(row1, subviewRowFull, loc, rewriter); - OpFoldResult subviewRow2 = - subOFRs(subviewRowFull, subviewRow1, loc, rewriter); - - SmallVector offsets(dims.size(), rewriter.getIndexAttr(0)); - SmallVector strides(dims.size(), rewriter.getIndexAttr(1)); - auto sv1 = createSubview(block1, offsets, {subviewRow1, subviewColFull}, - strides, loc, rewriter); - auto sv2 = createSubview(block2, offsets, {subviewRow2, subviewColFull}, - strides, loc, rewriter); - return {sv1, sv2}; +tts::MakeTensorPtrOp isSplitMemoryAccess(Operation *op) { + auto makeTensorPtrOp = + op->getOperand(0).getDefiningOp(); + if (makeTensorPtrOp && makeTensorPtrOp.isSplitPtr()) { + return makeTensorPtrOp; } + return nullptr; +} + +struct LoadConverter : public OpConversionPattern { +private: + using OpConversionPattern::OpConversionPattern; LogicalResult rewriteStructuredLoad(tts::LoadOp op, OpAdaptor adaptor, @@ -605,19 +682,10 @@ struct LoadConverter : public OpConversionPattern { // No mask assert(!other && "other value used in non-masked load"); - if (auto unrealizedCast = ptr.getDefiningOp()) { + if (auto makeTensorPtrOp = isSplitMemoryAccess(op)) { alloc = rewriter.create(loc, memrefType); - auto memrefs = unrealizedCast.getOperands(); - auto block1 = memrefs[0]; - auto block2 = memrefs[1]; - - if (unrealizedCast->hasAttr(WRAP_SIDE_BY_SIDE)) { - createSideBySideCopies(block1, block2, alloc, loc, rewriter); - } else if (unrealizedCast->hasAttr(WRAP_STACKED)) { - createStackedCopies(block1, block2, alloc, loc, rewriter); - } else { - llvm_unreachable("unexpected wraparound type"); - } + rewriteMakeTensorPtrAndMemAccess(makeTensorPtrOp, rewriter, alloc, + SmallVector{}, true); } else { alloc = rewriter.create(loc, memrefType, ptr); } @@ -648,8 +716,8 @@ struct LoadConverter : public OpConversionPattern { auto other = op.getOther(); if (!other) { - op->emitRemark( - "Masked load without other value, using zero padding instead\n"); + LLVM_DEBUG(op->emitRemark( + "Masked load without other value, using zero padding instead\n")); // FIXME: Different reduction op need different reduce base value other = rewriter.create( loc, elemType, rewriter.getZeroAttr(elemType)); @@ -683,26 +751,9 @@ struct LoadConverter : public OpConversionPattern { b.create(loc); }); - if (auto unrealizedCast = ptr.getDefiningOp()) { - - auto memrefs = unrealizedCast.getOperands(); - auto block1 = memrefs[0]; - auto block2 = memrefs[1]; - - if (unrealizedCast->hasAttr(WRAP_SIDE_BY_SIDE)) { - auto [subview1, subview2] = - getSideBySideSubviews(mixedDims, block1, block2, loc, rewriter); - createSideBySideCopies(subview1, subview2, alloc, loc, rewriter); - } else if (unrealizedCast->hasAttr(WRAP_STACKED)) { - auto [subview1, subview2] = - getStackedSubviews(mixedDims, block1, block2, loc, rewriter); - createStackedCopies(subview1, subview2, alloc, loc, rewriter); - } else { - llvm_unreachable("unexpected wraparound type"); - } - - rewriter.eraseOp(unrealizedCast); - + if (auto makeTensorPtrOp = isSplitMemoryAccess(op)) { + rewriteMakeTensorPtrAndMemAccess(makeTensorPtrOp, rewriter, alloc, + mixedDims, true); } else { memref::SubViewOp srcSubview = getSubview(tensorType.getRank(), mixedDims, ptr, loc, rewriter); @@ -748,18 +799,28 @@ struct StoreConverter : public OpConversionPattern { strides); } -public: - LogicalResult - matchAndRewrite(tts::StoreOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { + LogicalResult rewriteMaskedStore(tts::StoreOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + assert(op.hasMask()); + auto loc = op.getLoc(); auto ptr = adaptor.getPtr(); auto storeValue = op.getValue(); + auto tensorType = cast(storeValue.getType()); + auto elemType = tensorType.getElementType(); auto rank = cast(storeValue.getType()).getRank(); - if (op.hasMask()) { - auto mixedDims = op.getMixedMaskDims(); - + auto mixedDims = op.getMixedMaskDims(); + if (auto makeTensorPtrOp = isSplitMemoryAccess(op)) { + auto srcSlice = + getExtractSlice(rank, mixedDims, storeValue, loc, rewriter); + auto srcType = cast(srcSlice.getType()); + auto srcSliceMemRef = rewriter.create( + loc, MemRefType::get(srcType.getShape(), srcType.getElementType()), + srcSlice); + rewriteMakeTensorPtrAndMemAccess(makeTensorPtrOp, rewriter, + srcSliceMemRef, mixedDims, false); + } else { auto srcSlice = getExtractSlice(rank, mixedDims, storeValue, loc, rewriter); auto dstSubview = getSubview(rank, mixedDims, ptr, loc, rewriter); @@ -767,6 +828,30 @@ struct StoreConverter : public OpConversionPattern { auto storeOp = rewriter.create( loc, srcSlice, dstSubview); storeOp.setWritable(true); + } + rewriter.eraseOp(op); + return success(); + } + + LogicalResult + rewriteStructuredStore(tts::StoreOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + assert(!op.hasMask()); + + auto loc = op.getLoc(); + auto ptr = adaptor.getPtr(); + auto storeValue = op.getValue(); + + auto tensorType = cast(storeValue.getType()); + auto elemType = tensorType.getElementType(); + auto rank = cast(storeValue.getType()).getRank(); + MemRefType memrefType = MemRefType::get(tensorType.getShape(), elemType); + + if (auto makeTensorPtrOp = isSplitMemoryAccess(op)) { + auto srcMemRef = rewriter.create( + loc, memrefType, storeValue); + rewriteMakeTensorPtrAndMemAccess(makeTensorPtrOp, rewriter, srcMemRef, + SmallVector{}, false); } else { auto storeOp = rewriter.create( loc, storeValue, ptr); @@ -776,6 +861,17 @@ struct StoreConverter : public OpConversionPattern { rewriter.eraseOp(op); return success(); } + +public: + LogicalResult + matchAndRewrite(tts::StoreOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (op.hasMask()) { + return rewriteMaskedStore(op, adaptor, rewriter); + } else { + return rewriteStructuredStore(op, adaptor, rewriter); + } + } }; struct AtomicRMWOpConverter : public OpConversionPattern { diff --git a/third_party/tsingmicro/lib/Conversion/TritonToStructured/TritonToStructuredPass.cpp b/third_party/tsingmicro/lib/Conversion/TritonToStructured/TritonToStructuredPass.cpp index eb9cb698cc..65973b5f86 100644 --- a/third_party/tsingmicro/lib/Conversion/TritonToStructured/TritonToStructuredPass.cpp +++ b/third_party/tsingmicro/lib/Conversion/TritonToStructured/TritonToStructuredPass.cpp @@ -322,14 +322,14 @@ class TritonToStructuredPass ptrAnalysis.initializeMaybeStructuredArgs(moduleOp); if (failed(ptrAnalysis.rewriteOp(moduleOp, useUnsafeMask))) { - moduleOp->emitWarning("PtrAnalysis failed"); + LLVM_DEBUG(moduleOp->emitWarning("PtrAnalysis failed")); } // Now that all the PtrStates have been populated, we can wire up the states // with the tts.get_structured_state ops inserted in the prepass. moduleOp.walk([&ptrAnalysis](tts::GetStructuredStateOp op) { if (failed(ptrAnalysis.rewriteGetStructuredStateOp(op))) { - op.emitWarning("Rewriting GetStructuredStateOp failed."); + LLVM_DEBUG(op.emitWarning("Rewriting GetStructuredStateOp failed.")); } }); } diff --git a/third_party/tsingmicro/lib/Conversion/TritonToUnstructured/TritonToUnstructuredPass.cpp b/third_party/tsingmicro/lib/Conversion/TritonToUnstructured/TritonToUnstructuredPass.cpp index c95b3d4dd3..65ca27d556 100644 --- a/third_party/tsingmicro/lib/Conversion/TritonToUnstructured/TritonToUnstructuredPass.cpp +++ b/third_party/tsingmicro/lib/Conversion/TritonToUnstructured/TritonToUnstructuredPass.cpp @@ -320,16 +320,40 @@ class TritonToUnstructuredPass return success(); }) .Case([&](triton::BitcastOp bitcast) { + auto resPtrType = bitcast.getType(); OpBuilder b{bitcast}; auto loc = bitcast->getLoc(); auto offsetInfo = offsetMap.at(bitcast.getOperand()); - - auto newBitcast = b.create( - loc, bitcast.getType(), offsetInfo.ptr); + auto srcPtrType = offsetInfo.ptrType; + assert((triton::isPtrTypeLike(resPtrType) && + triton::isPtrTypeLike(srcPtrType)) && + "unexpected bitcast type"); + Type resType = + isa(resPtrType) + ? cast(resPtrType).getElementType() + : resPtrType; + Type srcType = + isa(srcPtrType) + ? cast(srcPtrType).getElementType() + : srcPtrType; + assert(((cast(srcType) + .getPointeeType() + .isInteger(1) && + cast(resType) + .getPointeeType() + .isInteger(8)) || + (cast(srcType) + .getPointeeType() + .isInteger(8) && + cast(resType) + .getPointeeType() + .isInteger(1))) && + "only bitcast between i1 and i8 pointer is supported"); + auto newBitcast = + b.create(loc, resType, offsetInfo.ptr); bitcast->replaceAllUsesWith(newBitcast); - - PtrOffset newOffsetInfo{newBitcast, offsetInfo.ptrType, + PtrOffset newOffsetInfo{newBitcast, resPtrType, offsetInfo.bitWidth, offsetInfo.offset}; offsetMap.insert({newBitcast, newOffsetInfo}); @@ -388,11 +412,11 @@ class TritonToUnstructuredPass auto offsetInfo = offsetMap.at(ptr); OpBuilder b{op}; - auto clone = - b.create(op->getLoc(), op->getName().getIdentifier(), - ValueRange{offsetInfo.offset}, - TypeRange{getPtrOffsetType( - resType, offsetInfo.bitWidth)}); + auto clone = b.create( + op->getLoc(), op->getName().getIdentifier(), + ValueRange{offsetInfo.offset}, + TypeRange{getPtrOffsetType(resType, offsetInfo.bitWidth)}, + op->getAttrs()); PtrOffset newOffsetInfo{offsetInfo.ptr, resType, offsetInfo.bitWidth, @@ -441,6 +465,13 @@ class TritonToUnstructuredPass auto offsetType = getPtrOffsetType(offsetInfo.ptrType, offsetInfo.bitWidth); + // In order to keep types of the operands consistent, we need + // to replace the base pointer if it is directly from the + // kernel arguments. + if (ptrArgs.contains(init)) { + forOp->setOperand(argIndex, offsetInfo.offset); + } + // We're setting both the types of the iter-arg and the // corresponding result directly to the offset type. // At this point, the IR is in an invalid state because the @@ -474,13 +505,62 @@ class TritonToUnstructuredPass return success(); }) + .Case([&](scf::WhileOp whileOp) { + auto argIndex = use.getOperandNumber(); + auto init = whileOp.getInits()[argIndex]; + auto offsetInfo = offsetMap.at(init); + auto offsetType = + getPtrOffsetType(offsetInfo.ptrType, offsetInfo.bitWidth); + + // In order to keep types of the operands consistent, we need + // to replace the base pointer if it is directly from the + // kernel arguments. + if (ptrArgs.contains(init)) { + whileOp->setOperand(argIndex, offsetInfo.offset); + } + auto beforeArg = whileOp.getBeforeArguments()[argIndex]; + beforeArg.setType(offsetType); + + auto afterArg = whileOp.getAfterArguments()[argIndex]; + afterArg.setType(offsetType); + + auto res = whileOp->getOpResult(argIndex); + res.setType(offsetType); + + PtrOffset beforeArgOffset{offsetInfo.ptr, offsetInfo.ptrType, + offsetInfo.bitWidth, beforeArg}; + offsetMap.insert({ + beforeArg, + beforeArgOffset, + }); + + PtrOffset afterArgOffset{offsetInfo.ptr, offsetInfo.ptrType, + offsetInfo.bitWidth, afterArg}; + + offsetMap.insert({ + afterArg, + afterArgOffset, + }); + + PtrOffset resOffset{offsetInfo.ptr, offsetInfo.ptrType, + offsetInfo.bitWidth, res}; + offsetMap.insert({ + res, + resOffset, + }); + workList.push(beforeArg); + workList.push(afterArg); + workList.push(res); + + return success(); + }) .Case( [&](Operation *op) { ptrUsers.push_back(op); return success(); }) - - .Case([](auto) { return success(); }) + .Case( + [](auto) { return success(); }) .Case([](triton::CatOp op) { op->emitError("Do not support gather / scatter with multiple " "bases yet"); @@ -560,8 +640,8 @@ class TritonToUnstructuredPass } else { // MakeTensorPtrOp only takes i32 offsets, so we need // to truncate if the offsets were already in i64 - makeTensorPtr.emitWarning( - "truncating offsets which may result in data loss"); + LLVM_DEBUG(makeTensorPtr.emitWarning( + "truncating offsets which may result in data loss")); baseOffset = b.create(loc, currOffType, baseOffset); } @@ -619,9 +699,9 @@ class TritonToUnstructuredPass void runOnOperation() override { if (failed(processUnstructuredPtrs(offsetBitWidth))) { - getOperation()->emitWarning( + LLVM_DEBUG(getOperation()->emitWarning( "Cannot transform tensor of pointers into a single base pointer " - "with tensor of offsets"); + "with tensor of offsets")); return; } diff --git a/third_party/tsingmicro/scripts/READMD_DEV.md b/third_party/tsingmicro/scripts/READMD_DEV.md index c7310aaea9..bdfd51868c 100644 --- a/third_party/tsingmicro/scripts/READMD_DEV.md +++ b/third_party/tsingmicro/scripts/READMD_DEV.md @@ -57,7 +57,7 @@ apt install openmpi-bin openmpi-doc libopenmpi-dev # 仅开启host侧profile,准确的获取launch的时间,需要重新编译 export USE_HOST_PROFILE=1 # 开启全部profile,包括device的,立即生效,注意因为插桩和打印的影响,此时host侧的launch时间已经不准了 -export USE_PROFILE=1 +export ENABLE_PROFILING=1 #自己定义需要对什么指令做profile,必填 export TRACE_POINTS="__Rdma,__Wdma" diff --git a/third_party/tsingmicro/scripts/base/base_run.sh b/third_party/tsingmicro/scripts/base/base_run.sh index a3fe6020e2..5ea5615aac 100755 --- a/third_party/tsingmicro/scripts/base/base_run.sh +++ b/third_party/tsingmicro/scripts/base/base_run.sh @@ -62,6 +62,7 @@ export TXDA_FALLBACK_CPU_OPS=$txda_fallback_cpu_ops # 非必须的 调试相关 export TRITON_DUMP_PATH=$TRITON/dump export TRITON_ALWAYS_COMPILE=1 +export TRITON_PRINT_AUTOTUNING=1 # dump launch调用的所有参数,包括kernel func调用的参数 export DUMP_KERNEL_ARGS=1 @@ -71,11 +72,13 @@ export PRECISION_PRIORITY=1 export TRITON_ALLOW_NON_CONSTEXPR_GLOBALS=1 # export DEBUG=ON -# export USE_PROFILE=1 +# export ENABLE_PROFILING=1 # export USE_HOST_PROFILE=1 +# export TX_LOG_LEVEL=debug # export CUSTOMIZED_IR=test_0.mlir,test_1.mlir # export TRACE_POINTS="__Rdma,__Wdma" +echo "export TX_LOG_LEVEL=$TX_LOG_LEVEL" echo "export TX8_DEPS_ROOT=$TX8_DEPS_ROOT" echo "export LLVM_SYSPATH=$LLVM_SYSPATH" echo "export LLVM_BINARY_DIR=$LLVM_BINARY_DIR" @@ -89,7 +92,7 @@ echo "export TRITON_ALWAYS_COMPILE=$TRITON_ALWAYS_COMPILE" echo "export TXDA_SKIP_OPS=$TXDA_SKIP_OPS" echo "export TXDA_FALLBACK_CPU_OPS=$TXDA_FALLBACK_CPU_OPS" -echo "export USE_PROFILE=$USE_PROFILE" +echo "export ENABLE_PROFILING=$ENABLE_PROFILING" echo "export USE_HOST_PROFILE=$USE_HOST_PROFILE" echo "export CUSTOMIZED_IR=$CUSTOMIZED_IR" echo "export TRACE_POINTS=$TRACE_POINTS" diff --git a/third_party/tsingmicro/scripts/build_tsingmicro.sh b/third_party/tsingmicro/scripts/build_tsingmicro.sh index 0b08149ac5..911320335d 100755 --- a/third_party/tsingmicro/scripts/build_tsingmicro.sh +++ b/third_party/tsingmicro/scripts/build_tsingmicro.sh @@ -9,7 +9,7 @@ project_dir=$(realpath "$script_dir/../../..") if [ -z "${WORKSPACE+x}" ]; then WORKSPACE=$(realpath "$project_dir/..") fi -echo "${realpath}" + TX8_DEPS_ROOT=$WORKSPACE/tx8_deps LLVM=$WORKSPACE/llvm-a66376b0-ubuntu-x64 TRITON=$project_dir @@ -97,9 +97,8 @@ build_triton() { export TRITON_BUILD_WITH_CLANG_LLD=true export TRITON_BUILD_WITH_CCACHE=true - export TRITON_OFFLINE_BUILD=OFF + export TRITON_OFFLINE_BUILD=ON export TRITON_BUILD_PROTON=OFF - export CXXFLAGS="-Wno-dangling-assignment-gsl" echo "export TRITON_OFFLINE_BUILD=$TRITON_OFFLINE_BUILD" echo "export TRITON_BUILD_WITH_CLANG_LLD=$TRITON_BUILD_WITH_CLANG_LLD" @@ -122,7 +121,6 @@ fi export LLVM_SYSPATH=$LLVM export TX8_DEPS_ROOT=$TX8_DEPS_ROOT -export FLAGTREE_BACKEND=tsingmicro # debug # export USE_HOST_PROFILE=1 @@ -131,4 +129,8 @@ export FLAGTREE_BACKEND=tsingmicro echo "export TX8_DEPS_ROOT=$TX8_DEPS_ROOT" echo "export LLVM_SYSPATH=$LLVM_SYSPATH" +# synchronous temporary solution: add waitfinish after every cintrinsic exec +export ENABLE_SYNCHRONOUS_INTRINSIC=1 +echo "export ENABLE_SYNCHRONOUS_INTRINSIC=$ENABLE_SYNCHRONOUS_INTRINSIC" + build_triton diff --git a/third_party/tsingmicro/scripts/build_tx8_deps.sh b/third_party/tsingmicro/scripts/build_tx8_deps.sh index 20f1c16993..c4fd780f80 100755 --- a/third_party/tsingmicro/scripts/build_tx8_deps.sh +++ b/third_party/tsingmicro/scripts/build_tx8_deps.sh @@ -162,11 +162,16 @@ show_help() { echo " default: http://172.50.1.66:8082/artifactory/tx8-generic-dev/tx81fw/tx81fw_202512041135_b731cf.tar.gz" echo " ..." echo "" + echo "tx_profiler url:" + echo " eg: http://172.50.1.66:8082/artifactory/tx8-generic-dev/tx81-profiling/master/profiling_tool_v5.5.0_release_2025-1124_.tar.gz" + echo " default: http://172.50.1.66:8082/artifactory/tx8-generic-dev/tx81-profiling/master/profiling_tool_v5.5.0_release_2025-1124_.tar.gz" + echo " ..." + echo "" echo "example: $0 build_tx8_deps" } # 检查参数数量 -if [ $# -le 1 ]; then +if [ $# -lt 1 ]; then show_help exit 1 fi @@ -253,6 +258,12 @@ if [ $# -ge 2 ]; then tx81fw=$2 fi echo "tx81fw: "$tx81fw + +profile_url=http://172.50.1.66:8082/artifactory/tx8-generic-dev/tx81-profiling/master/profiling_tool_v5.5.0_release_2025-1124_.tar.gz +if [ $# -ge 3 ]; then + profile_url=$3 +fi +echo "profile_url: "$profile_url if [ "x$MODE" == "xbuild_flagtree_tx8_deps" ] || [ "x$MODE" == "xbuild_tx8_deps" ] || [ "x$MODE" == "xbuild_dev" ]; then download_dir=$WORKSPACE/download ######################################################################################## @@ -283,16 +294,10 @@ if [ "x$MODE" == "xbuild_flagtree_tx8_deps" ] || [ "x$MODE" == "xbuild_tx8_deps" fi ######################################################################################## - triton_profiler_dir=$WORKSPACE/triton_profiler - # 因为要编译,不能放到download 里面,否则找到一些一些tx8_deps的路径 - clone_and_checkout "git@gitlab.tsingmicro.com:triton-based-projects/triton_profiler.git" \ - "$WORKSPACE" "branch" "master" "triton_profiler" - - profiler_head=$triton_profiler_dir/profiler/include/profiler.h - if [ ! -f $profiler_head ]; then - echo "error can't find:$profiler_head" - exit - fi + + tx_profiler_dir=$download_dir/tx_profiler + download_and_extract $profile_url \ + "$tx_profiler_dir" "$download_dir" "tx_profiler" ######################################################################################## tx8_depends_dir=$WORKSPACE/tx8_deps @@ -304,13 +309,6 @@ if [ "x$MODE" == "xbuild_flagtree_tx8_deps" ] || [ "x$MODE" == "xbuild_tx8_deps" load_copy_cfg echo -e $PROJECT_INFO > $tx8_depends_dir/version.txt - if [ "x$MODE" == "xbuild_dev" ]; then - copy_files profile_need $tx8_depends_dir - triton_profiler_pkg=$triton_profiler_dir/profile - if [ ! -f $triton_profiler_dir/profile/profiler_commit_id.txt ]; then - bash $triton_profiler_dir/build.sh rebuild - fi - fi copy_files $MODE $tx8_depends_dir pushd $WORKSPACE diff --git a/third_party/tsingmicro/scripts/ci/run_triton_flaggems_ci_test.sh b/third_party/tsingmicro/scripts/ci/run_triton_flaggems_ci_test.sh new file mode 100755 index 0000000000..5fc7272b53 --- /dev/null +++ b/third_party/tsingmicro/scripts/ci/run_triton_flaggems_ci_test.sh @@ -0,0 +1,324 @@ +#!/bin/bash +set -e +##1.下载triton、flaggems代码到/login_home/jenkins_tc/triton目录(CI负责,必须是这个目录,在容器外以root账号) +#sudo -s +#mkdir triton +#cd triton +#git clone "http://192.168.100.107/triton-based-projects/triton" && (cd "triton" && mkdir -p .git/hooks && curl -Lo `git rev-parse --git-dir`/hooks/commit-msg http://192.168.100.107/tools/hooks/commit-msg; chmod +x `git rev-parse --git-dir`/hooks/commit-msg) +#git clone "http://gitlab.tsingmicro.com/triton-based-projects/flaggems.git" -b board-test-base + +##2.创建容器(CI负责) +#docker run -d --name jenkins_tc_triton_ci --network=host --ipc=host --privileged -v /dev:/dev -v /tmp:/tmp -v /lib/modules:/lib/modules -v /sys:/sys -v /login_home/:/login_home/ -w /login_home/jenkins_tc/ hub.tsingmicro.com/tx8/ubuntu/v5.5.0.1030:tsingmicro_release + +##3.进入容器(CI负责) +#docker exec -it jenkins_tc_triton_ci /bin/bash + +##4.在/login_home/jenkins_tc/triton执行业务的ci运行脚本 +#cd triton +#bash triton/third_party/tsingmicro/scripts/ci/run_triton_flaggems_ci_test.sh 0 0 0 ci_ops 1 + + +########################################################################################################################## +## ## +## 业务的ci运行脚本 ## +## param1: skip_install, set 1-skip depends install,set 0-install depends, default 0. ## +## param2: skip_build, set 1-skip triton build, set 0-build triton, default 0. ## +## param3: skip_run, set 1-skip run ci test, set 0-run ci test, default 0. ## +## param4: test_set, set test set name, default 'ci_ops'. ## +## param5: device_count, set device count number, default 1. ## +## ##param: precision_priority, set 1-triton compiler use high precision mode for special ops, default 1. ## +## param6: quick_mode, set 1-quick mode to run flaggems, set 0-normal mode, default 0. ## +## param7: skip_device, set devices that need to be skipped, when they are unavailable, default []. ## +## ## +########################################################################################################################## + +script_path=$(realpath "$0") +echo $script_path +script_dir=$(dirname "$script_path") +echo $script_dir +project_dir=$(realpath "$script_dir/../../../../../") +echo $project_dir +export TRITON_WORKSPACE=$project_dir +skip_install=0 +skip_build=0 +skip_run=0 +test_set=ci_ops +device_count=1 +quick_mode=0 +skip_device= + +precision_priority=1 +tx8_depends_name=tx8_depends_dev_20260112_201902 +torch_txda_name=torch_txda+txops-20251230-03541ed8+71a1e5a +txda_skip_ops="repeat_interleave.self_int,pad,to.dtype,uniform_,sort.values_stable,contiguous,resolve_conj" +txda_fallback_cpu_ops="random_,quantile,_local_scalar_dense,arange,unfold,index,le,all,ge,pad,to,gather_backward,zero_,view_as_real,resolve_neg,embedding_backward,sort,repeat_interleave,rsub,hstack,vstack,min,uniform_,abs,ne,eq,mul,bitwise_and,masked_select,max,ceil,div,gt,lt,sum,scatter,where,resolve_conj,isclose,isfinite,tile,equal,gather,contiguous" + +if [ $# -ge 1 ]; then + skip_install=$1 +fi +if [ $# -ge 2 ]; then + skip_build=$2 +fi +if [ $# -ge 3 ]; then + skip_run=$3 +fi +if [ $# -ge 4 ]; then + test_set=$4 +fi +if [ $# -ge 5 ]; then + device_count=$5 +fi +if [ $# -ge 6 ]; then + quick_mode=$6 +fi +if [ $# -ge 7 ]; then + skip_device=$(echo $7 | tr ',' ' ') +fi +echo "param count:"$# +echo "skip_install:"$skip_install +echo "skip_build:"$skip_build +echo "skip_run:"$skip_run +echo "test_set:"$test_set +echo "device_count:"$device_count +echo "quick_mode:"$quick_mode +echo "skip_device:"$skip_device +echo "precision_priority:"$precision_priority +echo "tx8_depends_name:"$tx8_depends_name +echo "torch_txda_name:"$torch_txda_name +echo "txda_skip_ops:"$txda_skip_ops +echo "txda_fallback_cpu_ops:"$txda_fallback_cpu_ops +##1.下载依赖(triton业务负责) +cd $project_dir +#为了加快ci速度,从提前下载好的位置cp. src位置变后此处要更新 +TRITON_DEPENDS_SRC=/login_home/jenkins_tc/triton +###download llvm(很少变化) +if [ ! -d "./llvm-a66376b0-ubuntu-x64" ]; then + if [ -d $TRITON_DEPENDS_SRC/llvm-a66376b0-ubuntu-x64 ]; then + cp -r $TRITON_DEPENDS_SRC/llvm-a66376b0-ubuntu-x64/ ./ + echo "cp $TRITON_DEPENDS_SRC/llvm-a66376b0-ubuntu-x64 complete!" + else + echo "warning:$TRITON_DEPENDS_SRC/llvm-a66376b0-ubuntu-x64 not exist, use wget to download, maybe very slowly!" + fi +fi + +if [ ! -d "./llvm-a66376b0-ubuntu-x64" ]; then + wget https://toolchain-jfrog.tsingmicro.xyz:443/artifactory/tx8-generic-dev/triton/tools/llvm-a66376b0-ubuntu-x64.tar.gz + if [ $? -eq 0 ]; then + echo "Download llvm complete!" + else + echo "Download llvm fail!!!" + exit -1 + fi + tar -xzvf llvm-a66376b0-ubuntu-x64.tar.gz + rm llvm-a66376b0-ubuntu-x64.tar.gz +fi + +if [ ! -d "./llvm-a66376b0-ubuntu-x64" ]; then + echo "fail: not find llvm!!!" + exit -1 +fi + +###download torch2.7 wheels for offline install(很少变化) +if [ ! -d "./offline_pkgs" ]; then + if [ -d $TRITON_DEPENDS_SRC/offline_pkgs ]; then + cp -r $TRITON_DEPENDS_SRC/offline_pkgs/ ./ + echo "cp $TRITON_DEPENDS_SRC/offline_pkgs complete!" + else + echo "warning:$TRITON_DEPENDS_SRC/offline_pkgs not exist, use wget to download, maybe very slowly!" + fi +fi + +if [ ! -d "./offline_pkgs" ]; then + wget https://toolchain-jfrog.tsingmicro.xyz:443/artifactory/tx8-generic-dev/triton/offline_pkgs/offline_pkgs_v5.3.0.tar.gz + if [ $? -eq 0 ]; then + echo "Download offline package complete!" + else + echo "Download offline package fail!!!" + exit -1 + fi + tar -xzvf offline_pkgs_v5.3.0.tar.gz + rm offline_pkgs_v5.3.0.tar.gz +fi + +if [ ! -d "./offline_pkgs" ]; then + echo "fail: not find offline_pkgs!!!" + exit -1 +fi + +###download tx8_deps(变化频率较高) +if [ ! -e $tx8_depends_name.tar.gz ]; then + if [ -e $TRITON_DEPENDS_SRC/$tx8_depends_name.tar.gz ]; then + cp $TRITON_DEPENDS_SRC/$tx8_depends_name.tar.gz ./ + if [ -d "./tx8_deps" ]; then + rm -rf tx8_deps + fi + tar -xzvf $tx8_depends_name.tar.gz + echo "cp $TRITON_DEPENDS_SRC/$tx8_depends_name.tar.gz complete!" + else + echo "warning:$TRITON_DEPENDS_SRC/$tx8_depends_name.tar.gz not exist, use wget to download, maybe very slowly!" + fi +fi + +if [ ! -e $tx8_depends_name.tar.gz ]; then + if [ -d "./tx8_deps" ]; then + rm -rf tx8_deps + fi + + wget https://toolchain-jfrog.tsingmicro.xyz:443/artifactory/tx8-generic-dev/triton/tx8_depends/$tx8_depends_name.tar.gz + if [ $? -eq 0 ]; then + echo "Download tx8_deps complete!" + else + echo "Download tx8_dpes fail!!!" + exit -1 + fi + tar -xzvf $tx8_depends_name.tar.gz +fi + +if [ ! -d "./tx8_deps" ]; then + echo "fail: not find tx8_deps!!!" + exit -1 +fi + +###download torch_txda(变化频率较高) +if [ ! -e $torch_txda_name.tar.gz ]; then + if [ -e $TRITON_DEPENDS_SRC/$torch_txda_name.tar.gz ]; then + cp $TRITON_DEPENDS_SRC/$torch_txda_name.tar.gz ./ + if [ -d "./pack" ]; then + rm -rf pack + fi + tar -xzvf $torch_txda_name.tar.gz + echo "cp $TRITON_DEPENDS_SRC/$torch_txda_name.tar.gz complete!" + else + echo "warning:$TRITON_DEPENDS_SRC/$torch_txda_name.tar.gz not exist, use wget to download, maybe very slowly!" + fi +fi + +if [ ! -e $torch_txda_name.tar.gz ]; then + if [ -d "./pack" ]; then + rm -rf pack + fi + + wget https://toolchain-jfrog.tsingmicro.xyz:443/artifactory/tx8-generic-dev/torch_txda/$torch_txda_name.tar.gz + if [ $? -eq 0 ]; then + echo "Download torch_txda complete!" + else + echo "Download torch_txda fail!!!" + exit -1 + fi + tar -xzvf $torch_txda_name.tar.gz +fi + +if [ ! -d "./pack" ]; then + echo "fail: not find torch_txda pack!!!" + exit -1 +fi + +##2.安装依赖(triton业务负责) +cd triton +if [ $skip_install -ne 1 ]; then + if [ -d "./.venv" ]; then + rm -rf .venv + fi + python3 -m venv .venv --prompt triton + source .venv/bin/activate + #check python version + python3 --version + bash third_party/tsingmicro/scripts/tools/offline_python_deps.sh -i -r python/requirements.txt -d ../offline_pkgs + if [ $? -eq 0 ]; then + echo "Install compile tool package complete!" + else + echo "Install compile tool package fail!!!" + exit -1 + fi + + bash third_party/tsingmicro/scripts/tools/offline_python_deps.sh -i -r third_party/tsingmicro/scripts/requirements_ts.txt -d ../offline_pkgs + if [ $? -eq 0 ]; then + echo "Install torch package complete!" + else + echo "Install torch package fail!!!" + exit -1 + fi + #check torch version + python3 -c "import torch; print(torch.__version__)" + + PROXY=http://192.168.100.225:8889 + export https_proxy=$PROXY http_proxy=$PROXY all_proxy=$PROXY + apt update + apt install -y lld + apt install ccache + pip install loguru + pip install scipy + unset https_proxy + unset http_proxy + unset all_proxy + + ###install torch_txda(变化频率较高,须随着上述下载名字变化而变化) + txops_wheel=$(find ../pack/ -maxdepth 1 -name "txops*.whl" -print -quit) + torch_txda_wheel=$(find ../pack/ -maxdepth 1 -name "torch_txda*.whl" -print -quit) + pip install $txops_wheel + pip install $torch_txda_wheel +fi + +##3.编译triton(triton业务负责) +if [ $skip_build -ne 1 ]; then + bash ./third_party/tsingmicro/scripts/build_tsingmicro.sh + if [ $? -eq 0 ]; then + echo "Build triton complete!" + else + echo "Build triton fail!!!" + exit -1 + fi +fi +##4.运行测试(triton业务负责) +#triton系统相关环境变量 +TX8_DEPS_ROOT=$TRITON_WORKSPACE/tx8_deps +LLVM=$TRITON_WORKSPACE/llvm-a66376b0-ubuntu-x64 +export TX8_DEPS_ROOT=$TX8_DEPS_ROOT +export LLVM_SYSPATH=$LLVM +export LLVM_BINARY_DIR=$LLVM/bin +export PYTHONPATH=$LLVM/python_packages/mlir_core:$PYTHONPATH +export LD_LIBRARY_PATH=$TX8_DEPS_ROOT/lib:$LD_LIBRARY_PATH +#export TRITON_DUMP_PATH=$TRITON_WORKSPACE/dump +export TRITON_ALWAYS_COMPILE=1 +export TRITON_PRINT_AUTOTUNING=1 +#测试任务相关环境变量 +export JSON_FILE_PATH=$project_dir/flaggems/tests +#export TX8_DEVICES_COUNT=$device_count +export PRECISION_PRIORITY=$precision_priority +export TRITON_ALLOW_NON_CONSTEXPR_GLOBALS=1 +export TXDA_SKIP_OPS=$txda_skip_ops +export TXDA_FALLBACK_CPU_OPS=$txda_fallback_cpu_ops + +echo "TX8_DEPS_ROOT="$TX8_DEPS_ROOT +echo "LLVM_SYSPATH="$LLVM_SYSPATH +echo "LLVM_BINARY_DIR="$LLVM_BINARY_DIR +echo "PYTHONPATH="$PYTHONPATH +echo "LD_LIBRARY_PATH="$LD_LIBRARY_PATH +#echo "TRITON_DUMP_PATH="$TRITON_DUMP_PATH +echo "TRITON_ALWAYS_COMPILE="$TRITON_ALWAYS_COMPILE +echo "JSON_FILE_PATH="$JSON_FILE_PATH +#echo "TX8_DEVICES_COUNT="$TX8_DEVICES_COUNT +echo "PRECISION_PRIORITY="$PRECISION_PRIORITY +echo "TRITON_ALLOW_NON_CONSTEXPR_GLOBALS="$TRITON_ALLOW_NON_CONSTEXPR_GLOBALS +echo "TXDA_SKIP_OPS="$TXDA_SKIP_OPS +echo "TXDA_FALLBACK_CPU_OPS="$TXDA_FALLBACK_CPU_OPS + +if [ $skip_run -ne 1 ]; then + if [ $skip_install -eq 1 ]; then + source .venv/bin/activate + fi + + cd .. + if [ $quick_mode -eq 1 ]; then + python ./flaggems/tests/test_flag_gems_ci.py --test_set $test_set --device_count $device_count --skip_device $skip_device --quick + else + python ./flaggems/tests/test_flag_gems_ci.py --test_set $test_set --device_count $device_count --skip_device $skip_device + fi + + if [ $? -eq 0 ]; then + echo "Run test complete!" + else + echo "Run test fail!!!" + exit -1 + fi +fi diff --git a/third_party/tsingmicro/scripts/copy_config.conf b/third_party/tsingmicro/scripts/copy_config.conf index 4b9f7281ca..8d7c1d53ba 100644 --- a/third_party/tsingmicro/scripts/copy_config.conf +++ b/third_party/tsingmicro/scripts/copy_config.conf @@ -5,14 +5,12 @@ file:download/triton-tx8fw/**/lib/libinstr_tx81.a,lib file:download/triton-tx8fw/**/lib/liblibc_stub.a,lib file:download/triton-tx8fw/**/include/components/oplib_tx81/riscv/riscv/include/lib_log.h,include dir:download/triton-tx8fw/**/tx81-intrisic/instr_tx81/include -file:triton_profiler/profiler/include/profiler.h,include dir:download/tx8fw-xuantie-sdk/Xuantie-900-gcc-elf-newlib-x86_64-V2.10.2 file:download/*.pdf [block3:build_dev] dir:triton_profiler/profile/lib -file:triton_profiler/profile/profiler_commit_id.txt -file:triton_profiler/analyer/analyer.py,scripts +dir:download/tx_profiler/profiling_tool*/*,profiling_tool [block4:profile_need] dir:download/triton-tx8fw/**/include/*,include diff --git a/third_party/tsingmicro/scripts/publish/README.md b/third_party/tsingmicro/scripts/publish/README.md index 603a9a3dc1..30c5d226d9 100644 --- a/third_party/tsingmicro/scripts/publish/README.md +++ b/third_party/tsingmicro/scripts/publish/README.md @@ -25,7 +25,8 @@ - torch_txda 20251212: [制品库下载链接](http://172.50.1.66:8082/artifactory/tx8-generic-dev/torch_txda/torch_txda%2Btxops-20251212-eb33daf9%2B8d7d6e7.tar.gz) - docker 5.6.0daily: hub.tsingmicro.com/tx8/ubuntu/daily:tsingmicro_release_patch.251219130722 - tx8_deps 20251218: [制品库下载链接](http://172.50.1.66:8082/artifactory/tx8-generic-dev/triton/tx8_depends/tx8_depends_dev_20251218_164108.tar.gz) - - torch_txda 20251219: [制品库下载链接](http://172.50.1.66:8082/artifactory/tx8-generic-dev/torch_txda/torch_txda%2Btxops-20251219-c028d111%2Be0de74c.tar.gz) + - torch_txda 20251230: [制品库下载链接](http://172.50.1.66:8082/artifactory/tx8-generic-dev/torch_txda/torch_txda%2Btxops-20251230-03541ed8%2B71a1e5a.tar.gz) +注意: tx8_deps、torch_txda迭代较快, 每个版本最终配套的以Triton项目发布为准. ## 版本运行说明 diff --git a/third_party/tsingmicro/scripts/publish/build_wheel.sh b/third_party/tsingmicro/scripts/publish/build_wheel.sh index 9e70f18a35..e519e5e801 100755 --- a/third_party/tsingmicro/scripts/publish/build_wheel.sh +++ b/third_party/tsingmicro/scripts/publish/build_wheel.sh @@ -92,6 +92,7 @@ PROXY=http://192.168.100.225:8889 export https_proxy=$PROXY http_proxy=$PROXY all_proxy=$PROXY apt install -y lld apt install ccache +apt install git pip3 install scikit_build_core #flaggems需要,也需要torch #unset https_proxy unset http_proxy @@ -115,6 +116,8 @@ export TRITON_OFFLINE_BUILD=ON export TRITON_BUILD_PROTON=OFF export LLVM_SYSPATH=$LLVM export TX8_DEPS_ROOT=$TX8_DEPS_ROOT +# synchronous temporary solution: add waitfinish after every cintrinsic exec +export ENABLE_SYNCHRONOUS_INTRINSIC=1 cd python python3 -m pip wheel . --no-build-isolation -v --verbos diff --git a/third_party/tsingmicro/scripts/publish/publish.sh b/third_party/tsingmicro/scripts/publish/publish.sh index 417caa1754..0234003167 100755 --- a/third_party/tsingmicro/scripts/publish/publish.sh +++ b/third_party/tsingmicro/scripts/publish/publish.sh @@ -4,7 +4,7 @@ set -e ########################################################################################################################## ## ## ## Triton版本包制作脚本 ## -## 在docker容器外triton代码目录上一级目录下执行, 制作triton完整版本包. ## +## 在docker容器内triton代码目录上一级目录下执行, 制作triton完整版本包. ## ## 执行该脚本前先执行build_wheel.sh生成triton、flaggems wheel包. ## ## ## ########################################################################################################################## @@ -88,6 +88,7 @@ fi cp $script_dir/README.md $publish_dir cp $script_dir/install.sh $publish_dir/scripts cp $script_dir/run_tsingmicro.sh $publish_dir/scripts +cp $script_dir/run_flaggems_on_multicards.sh $publish_dir/scripts cp $script_dir/../base/base_run.sh $publish_dir/scripts cp $script_dir/../requirements_ts.txt $publish_dir/scripts cp $script_dir/../tools/offline_python_deps.sh $publish_dir/scripts diff --git a/third_party/tsingmicro/scripts/publish/run_flaggems_on_multicards.sh b/third_party/tsingmicro/scripts/publish/run_flaggems_on_multicards.sh new file mode 100755 index 0000000000..a2f7b62d81 --- /dev/null +++ b/third_party/tsingmicro/scripts/publish/run_flaggems_on_multicards.sh @@ -0,0 +1,94 @@ +#!/bin/bash +set -e +##.在docker容器内版本包路径下执行 +#bash scripts/run_flaggems_on_multicards.sh ci_ops 1 + + +########################################################################################################################## +## ## +## 在多卡上并行运行Triton算子测试脚本 ## +## param1: test_set, set test set name, default 'ci_ops'. ## +## param2: device_count, set device count number, default 1. ## +## ##param: precision_priority, set 1-triton compiler use high precision mode for special ops, default 1. ## +## param3: quick_mode, set 1-quick mode to run flaggems, set 0-normal mode, default 0. ## +## param4: skip_device, set devices that need to be skipped, when they are unavailable, default []. ## +## ## +########################################################################################################################## + +script_path=$(realpath "$0") +echo $script_path +script_dir=$(dirname "$script_path") +echo $script_dir +project_dir=$(realpath "$script_dir/../") +echo $project_dir +export TRITON_WORKSPACE=$project_dir +test_set=ci_ops +device_count=1 +quick_mode=0 +skip_device= +precision_priority=1 +txda_skip_ops="repeat_interleave.self_int,pad,to.dtype,uniform_,sort.values_stable,contiguous,resolve_conj" +txda_fallback_cpu_ops="random_,quantile,_local_scalar_dense,arange,unfold,index,le,all,ge,pad,to,gather_backward,zero_,view_as_real,resolve_neg,embedding_backward,sort,repeat_interleave,rsub,hstack,vstack,min,uniform_,abs,ne,eq,mul,bitwise_and,masked_select,max,ceil,div,gt,lt,sum,scatter,where,resolve_conj,isclose,isfinite,tile,equal,gather,contiguous" + +if [ $# -ge 1 ]; then + test_set=$1 +fi +if [ $# -ge 2 ]; then + device_count=$2 +fi +if [ $# -ge 3 ]; then + quick_mode=$3 +fi +if [ $# -ge 4 ]; then + skip_device=$(echo $4 | tr ',' ' ') +fi +echo "param count:"$# +echo "test_set:"$test_set +echo "device_count:"$device_count +echo "quick_mode:"$quick_mode +echo "skip_device:"$skip_device +echo "precision_priority:"$precision_priority +echo "txda_skip_ops:"$txda_skip_ops +echo "txda_fallback_cpu_ops:"$txda_fallback_cpu_ops + +#triton系统相关环境变量 +TX8_DEPS_ROOT=$project_dir/tx8_deps +LLVM=$project_dir/llvm-a66376b0-ubuntu-x64 +export TX8_DEPS_ROOT=$TX8_DEPS_ROOT +export LLVM_SYSPATH=$LLVM +export LLVM_BINARY_DIR=$LLVM/bin +export PYTHONPATH=$LLVM/python_packages/mlir_core:$PYTHONPATH +export LD_LIBRARY_PATH=$TX8_DEPS_ROOT/lib:$LD_LIBRARY_PATH +export TRITON_ALWAYS_COMPILE=1 +#测试任务相关环境变量 +export JSON_FILE_PATH=$project_dir/flaggems_tests +export PRECISION_PRIORITY=$precision_priority +export TRITON_ALLOW_NON_CONSTEXPR_GLOBALS=1 +export TXDA_SKIP_OPS=$txda_skip_ops +export TXDA_FALLBACK_CPU_OPS=$txda_fallback_cpu_ops + +echo "TX8_DEPS_ROOT="$TX8_DEPS_ROOT +echo "LLVM_SYSPATH="$LLVM_SYSPATH +echo "LLVM_BINARY_DIR="$LLVM_BINARY_DIR +echo "PYTHONPATH="$PYTHONPATH +echo "LD_LIBRARY_PATH="$LD_LIBRARY_PATH +echo "TRITON_ALWAYS_COMPILE="$TRITON_ALWAYS_COMPILE +echo "JSON_FILE_PATH="$JSON_FILE_PATH +echo "PRECISION_PRIORITY="$PRECISION_PRIORITY +echo "TRITON_ALLOW_NON_CONSTEXPR_GLOBALS="$TRITON_ALLOW_NON_CONSTEXPR_GLOBALS +echo "TXDA_SKIP_OPS="$TXDA_SKIP_OPS +echo "TXDA_FALLBACK_CPU_OPS="$TXDA_FALLBACK_CPU_OPS + +source $project_dir/triton/.venv/bin/activate +if [ $quick_mode -eq 1 ]; then + python3 $project_dir/flaggems_tests/test_flag_gems_ci.py --test_set $test_set --device_count $device_count --skip_device $skip_device --quick +else + python3 $project_dir/flaggems_tests/test_flag_gems_ci.py --test_set $test_set --device_count $device_count --skip_device $skip_device +fi + +if [ $? -eq 0 ]; then + echo "Run test complete!" +else + echo "Run test fail!!!" + exit -1 +fi