Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion python/mlc_llm/cli/compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,11 @@ def _check_system_lib_prefix(prefix: str) -> str:
default="auto",
help=HELP["host"] + ' (default: "%(default)s")',
)
parser.add_argument(
"--enable-subgroups",
action="store_true",
help=HELP["enable_subgroups"],
)
parser.add_argument(
"--opt",
type=OptimizationFlags.from_str,
Expand Down Expand Up @@ -117,7 +122,11 @@ def _check_system_lib_prefix(prefix: str) -> str:
help=HELP["debug_dump"] + " (default: %(default)s)",
)
parsed = parser.parse_args(argv)
target, build_func = detect_target_and_host(parsed.device, parsed.host)
target, build_func = detect_target_and_host(
parsed.device,
parsed.host,
enable_subgroups=parsed.enable_subgroups,
)
parsed.model_type = detect_model_type(parsed.model_type, parsed.model)
parsed.quantization = detect_quantization(parsed.quantization, parsed.model)
parsed.system_lib_prefix = detect_system_lib_prefix(
Expand Down
4 changes: 4 additions & 0 deletions python/mlc_llm/interface/help.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@
""".strip(),
"device_compile": """
The GPU device to compile the model to. If not set, it is inferred from GPUs available locally.
""".strip(),
"enable_subgroups": """
Enable WebGPU subgroups in codegen. This only applies to WebGPU targets and will set
supports_subgroups accordingly.
""".strip(),
"device_quantize": """
The device used to do quantization such as "cuda" or "cuda:0". Will detect from local available GPUs
Expand Down
22 changes: 21 additions & 1 deletion python/mlc_llm/support/auto_target.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,11 @@
BuildFunc = Callable[[IRModule, "CompileArgs", Pass], None]


def detect_target_and_host(target_hint: str, host_hint: str = "auto") -> Tuple[Target, BuildFunc]:
def detect_target_and_host(
target_hint: str,
host_hint: str = "auto",
enable_subgroups: Optional[bool] = None,
) -> Tuple[Target, BuildFunc]:
"""Detect the configuration for the target device and its host, for example, target GPU and
the host CPU.

Expand All @@ -43,6 +47,7 @@ def detect_target_and_host(target_hint: str, host_hint: str = "auto") -> Tuple[T
target, build_func = _detect_target_gpu(target_hint)
if target.host is None:
target = Target(target, host=_detect_target_host(host_hint))
target = _apply_webgpu_subgroups(target, enable_subgroups)
if target.kind.name == "cuda":
# Enable thrust for CUDA
target_dict = dict(target.export())
Expand All @@ -61,6 +66,21 @@ def detect_target_and_host(target_hint: str, host_hint: str = "auto") -> Tuple[T
return target, build_func


def _apply_webgpu_subgroups(target: Target, enable_subgroups: Optional[bool]) -> Target:
    """Optionally mark a WebGPU target as supporting subgroups.

    Parameters
    ----------
    target : Target
        The compilation target detected so far.

    enable_subgroups : Optional[bool]
        Value of the ``--enable-subgroups`` CLI flag. The flag is declared with
        ``action="store_true"``, so in practice this is ``False`` when the flag
        is absent and ``True`` when given; ``None`` (flag not plumbed through)
        is treated the same as ``False``.

    Returns
    -------
    Target
        A new Target with ``supports_subgroups=True`` when the target is WebGPU
        and the flag is enabled; otherwise the input target, unchanged.
    """
    # Falsy (None or False) is a no-op: rebuilding the target with an explicit
    # supports_subgroups=False would clobber a target that already declares
    # subgroup support, so we only ever *add* the attribute when requested.
    if not enable_subgroups:
        return target
    if target.kind.name != "webgpu":
        # The flag is meaningless off-WebGPU; warn rather than fail so that a
        # shared command line still works across targets.
        logger.warning(
            "--enable-subgroups is only supported for WebGPU targets; ignoring for %s",
            target.kind.name,
        )
        return target
    target_dict = dict(target.export())
    target_dict["supports_subgroups"] = True
    return Target(target_dict)


def _detect_target_gpu(hint: str) -> Tuple[Target, BuildFunc]:
if hint in ["iphone", "android", "webgpu", "mali", "opencl"]:
hint += ":generic"
Expand Down
Loading