|
8 | 8 | load("@bazel_skylib//lib:paths.bzl", "paths") |
9 | 9 | load("@bazel_tools//tools/build_defs/repo:utils.bzl", "maybe") |
10 | 10 |
|
| 11 | +IREE_INPUT_TORCH_ENV_KEY = "IREE_INPUT_TORCH" |
11 | 12 | CUDA_TOOLKIT_ROOT_ENV_KEY = "IREE_CUDA_TOOLKIT_ROOT" |
12 | 13 |
|
13 | 14 | # Our CI docker images use a stripped down CUDA directory tree in some |
@@ -67,6 +68,78 @@ cuda_auto_configure = repository_rule( |
67 | 68 | }, |
68 | 69 | ) |
69 | 70 |
|
def torch_mlir_auto_configure_impl(repository_ctx):
    """Conditionally configures torch-mlir based on IREE_INPUT_TORCH env var.

    When the variable is truthy, runs torch-mlir's overlay script to populate
    this repository from the submodule sources; otherwise writes a stub
    BUILD.bazel so the repository still resolves.
    """
    setting = repository_ctx.os.environ.get(IREE_INPUT_TORCH_ENV_KEY, "OFF")
    if setting.upper() not in ["ON", "TRUE", "1", "YES"]:
        # Disabled: create a stub repository so loads of @torch-mlir succeed.
        repository_ctx.file(
            "BUILD.bazel",
            content = """# Stub: torch-mlir disabled (IREE_INPUT_TORCH != ON)
package(default_visibility = ["//visibility:public"])

# Provide empty targets that dependent code can reference.
# These will fail at build time if actually used.
""",
        )
        return

    # Locate the torch-mlir submodule by resolving a file known to be at its
    # root, then derive the bazel utility paths from there.
    src_root = repository_ctx.path(
        Label(
            "%s//:third_party/torch-mlir/CMakeLists.txt" % repository_ctx.attr.iree_repo_alias,
        ),
    ).dirname
    bazel_dir = src_root.get_child("utils").get_child("bazel")

    # Prefer python3 but tolerate environments that only expose `python`.
    interpreter = repository_ctx.which("python3") or repository_ctx.which("python")
    if not interpreter:
        fail("Failed to find python3 binary for torch-mlir configuration")

    # Run torch-mlir's overlay script, targeting this repository's root.
    overlay_cmd = [
        interpreter,
        bazel_dir.get_child("overlay_directories.py"),
        "--src",
        src_root,
        "--overlay",
        bazel_dir.get_child("torch-mlir-overlay"),
        "--target",
        ".",
    ]
    result = repository_ctx.execute(overlay_cmd, timeout = 60)

    if result.return_code != 0:
        fail(("Failed to configure torch-mlir: '{cmd}'\n" +
              "Exited with code {return_code}\n" +
              "stdout:\n{stdout}\n" +
              "stderr:\n{stderr}\n").format(
            cmd = " ".join([str(arg) for arg in overlay_cmd]),
            return_code = result.return_code,
            stdout = result.stdout,
            stderr = result.stderr,
        ))
| 126 | + |
# Repository rule materializing @torch-mlir. Declared `local` and keyed on the
# IREE_INPUT_TORCH environment variable (via `environ`), so Bazel re-runs the
# implementation when that variable changes — it decides between a real
# overlay and a stub repository.
torch_mlir_auto_configure = repository_rule(
    environ = [IREE_INPUT_TORCH_ENV_KEY],
    implementation = torch_mlir_auto_configure_impl,
    local = True,
    attrs = {
        # Workspace name under which the IREE sources are available;
        # used to resolve the third_party/torch-mlir submodule path.
        "iree_repo_alias": attr.string(default = "@iree_core"),
    },
)
| 135 | + |
def configure_iree_torch_mlir_deps(iree_repo_alias = None):
    """Declares the @torch-mlir repository unless it is already defined.

    Args:
      iree_repo_alias: workspace alias for the IREE sources; forwarded to
        torch_mlir_auto_configure (rule default applies when None).
    """
    maybe(
        torch_mlir_auto_configure,
        iree_repo_alias = iree_repo_alias,
        name = "torch-mlir",
    )
| 142 | + |
70 | 143 | def configure_iree_cuda_deps(iree_repo_alias = None): |
71 | 144 | maybe( |
72 | 145 | cuda_auto_configure, |
|
0 commit comments