Skip to content

Commit 3400776

Browse files
Berkkirikddirren
authored and committed
feat: auto-detect Apple Silicon (MPS) and keep Triton CUDA-only
Extends the --device auto resolution from openai#17 to include Apple Silicon (MPS) so Mac users get GPU acceleration by default instead of falling back to CPU. Four coordinated changes make this safe:

1. opf/_common/device.py — "auto" now picks cuda > mps > cpu. Each fallback emits an info line on stderr so the user always knows which backend was selected.
2. opf/_model/model.py — the Triton-backed MoE kernels are CUDA-only (Triton does not target Metal). Previously the default enabled Triton on any non-CPU device, so trying mps crashed once the MoE layer was hit. Narrow the auto-enable to device.type == "cuda"; mps and cpu both fall back to the torch-ops path unless the user explicitly sets OPF_MOE_TRITON=1.
3. opf/_train/runner.py — mirror the same CUDA-only gate when setting OPF_MOE_TRITON=1 on behalf of the user (previously it was set for any non-CPU device, which would silently enable Triton on mps).
4. opf/_cli/common.py — expand the --device help text to list the full backend order (cuda > mps > cpu).

Verified on macOS (Apple Silicon, Python 3.14, torch 2.11):
- resolve_device("auto") → mps (with stderr info line)
- resolve_device("mps") → mps
- resolve_device("cpu") → cpu
- resolve_device("cuda") → returns a cuda device (still fails loudly at tensor allocation when the user explicitly asks for it — unchanged)

Low-level MPS op sanity check passed for embedding, attention-like matmul/softmax, log_softmax, topk, argsort, bincount — all the ops the inference path relies on.

Fixes openai#21
1 parent f7f00ca commit 3400776

4 files changed

Lines changed: 67 additions & 7 deletions

File tree

opf/_cli/common.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,8 +55,13 @@ def add_device_arg(parser: object) -> None:
5555
parser.add_argument(
5656
"--device",
5757
type=str,
58-
default="cuda",
59-
help="Device to run on",
58+
default="auto",
59+
help=(
60+
"Device to run on. 'auto' (default) picks the best available "
61+
"backend: cuda > mps (Apple Silicon) > cpu. Pass an explicit "
62+
"value like 'cuda', 'mps', or 'cpu' to override or to get a "
63+
"loud error when the requested backend is unavailable."
64+
),
6065
)
6166

6267

opf/_common/device.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
"""Device-name resolution helpers shared by CLI entrypoints."""
2+
3+
from __future__ import annotations
4+
5+
import sys
6+
7+
import torch
8+
9+
AUTO_DEVICE: str = "auto"
10+
11+
12+
def _mps_is_available() -> bool:
    """Report whether this PyTorch build can use the Apple Metal (MPS) backend.

    Returns:
        ``True`` only when ``torch.backends`` exposes an ``mps`` attribute
        that has an ``is_available`` attribute and calling it yields a truthy
        value. A missing attribute, or any exception raised by the probe,
        yields ``False``.
    """
    mps_backend = getattr(torch.backends, "mps", None)
    probe = (
        None
        if mps_backend is None
        else getattr(mps_backend, "is_available", None)
    )
    if probe is None:
        return False
    try:
        return bool(probe())
    except Exception:
        # Deliberately best-effort: a failing probe is treated the same as
        # "MPS not available" instead of aborting device resolution.
        return False
24+
25+
26+
def resolve_device(device_name: str) -> torch.device:
    """Resolve a user-supplied device name into a concrete ``torch.device``.

    ``"auto"`` (matched case-insensitively, ignoring surrounding whitespace)
    selects the best available device in this order: CUDA (NVIDIA GPU) > MPS
    (Apple Silicon GPU) > CPU. When the choice is MPS or CPU, a one-line
    ``info:`` notice is printed to stderr so the selected backend is visible.

    Any other value is passed through to ``torch.device`` unchanged. That
    call only validates the device string; it does not check that the backend
    is usable, so an explicit ``"cuda"`` or ``"mps"`` on a machine without
    that backend is returned as-is and fails later, when it is first used
    (e.g. at tensor allocation).

    Args:
        device_name: ``"auto"`` or any device string accepted by
            ``torch.device``.

    Returns:
        The resolved ``torch.device``.

    Raises:
        Whatever ``torch.device`` raises for a value it does not accept.
    """
    # Normalise only for the "auto" comparison; explicit names are forwarded
    # untouched so torch.device remains the single authority on validity.
    if isinstance(device_name, str):
        requested = device_name.strip().lower()
    else:
        requested = device_name
    if requested != AUTO_DEVICE:
        return torch.device(device_name)

    if torch.cuda.is_available():
        return torch.device("cuda")

    if _mps_is_available():
        print(
            "info: no CUDA device detected; using Apple Metal (MPS).",
            file=sys.stderr,
            flush=True,
        )
        return torch.device("mps")

    print(
        "info: no CUDA or MPS device detected; falling back to CPU "
        "(pass --device cuda or --device mps to override).",
        file=sys.stderr,
        flush=True,
    )
    return torch.device("cpu")

opf/_model/model.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -750,8 +750,11 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
750750
expert_indices = experts.indices
751751
expert_weights = expert_weights / self.experts_per_token
752752
experts_per_token_eff = self.experts_per_token
753-
not_running_on_cpu = t.device.type != "cpu"
754-
use_triton = get_env_bool("OPF_MOE_TRITON", default=not_running_on_cpu)
753+
# Triton kernels are CUDA-only; auto-enable only on CUDA devices. MPS
754+
# and CPU fall back to the torch-ops path unless the user explicitly
755+
# opts in via OPF_MOE_TRITON=1.
756+
is_cuda_device = t.device.type == "cuda"
757+
use_triton = get_env_bool("OPF_MOE_TRITON", default=is_cuda_device)
755758
if use_triton:
756759
_require_triton()
757760

opf/_train/runner.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -590,9 +590,10 @@ def main(argv: Sequence[str] | None = None, *, prog: str | None = None) -> int:
590590
checkpoint = resolve_checkpoint_path(args.checkpoint)
591591
device = torch.device(args.device)
592592

593-
# Default to Triton-backed MoE kernels on non-CPU devices unless callers
594-
# explicitly opt out. CPU uses torch ops by default so Triton stays optional.
595-
if device.type != "cpu":
593+
# Default to Triton-backed MoE kernels on CUDA devices unless callers
594+
# explicitly opt out. CPU and MPS use torch ops by default so Triton
595+
# stays CUDA-only (the kernels don't run on Metal).
596+
if device.type == "cuda":
596597
os.environ.setdefault("OPF_MOE_TRITON", "1")
597598

598599
base_config = _load_checkpoint_config(checkpoint)

0 commit comments

Comments
 (0)