diff --git a/.github/workflows/build_kernel_macos.yaml b/.github/workflows/build_kernel_macos.yaml index 4f95c107..a245dc56 100644 --- a/.github/workflows/build_kernel_macos.yaml +++ b/.github/workflows/build_kernel_macos.yaml @@ -9,23 +9,44 @@ on: jobs: build: - name: Build kernel - runs-on: macos-26 + name: Build and test kernel (${{ matrix.os }}) + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + include: + - os: macos-14-xlarge + xcode: "/Applications/Xcode_15.4.app" + # macOS 14 is best-effort: builds work but MPS tests may OOM + # on runners with limited unified memory. + allow-failure: true + - os: macos-15-xlarge + xcode: "/Applications/Xcode_16.2.app" + - os: macos-26-xlarge + xcode: "/Applications/Xcode_26.0.app" + continue-on-error: ${{ matrix.allow-failure || false }} steps: - name: "Select Xcode" - run: sudo xcrun xcode-select -s /Applications/Xcode_26.0.app + run: sudo xcrun xcode-select -s ${{ matrix.xcode }} - name: "Install Metal Toolchain" + if: matrix.os == 'macos-26-xlarge' run: xcodebuild -downloadComponent metalToolchain - uses: actions/checkout@v6 - uses: cachix/install-nix-action@v31 + with: + extra_nix_config: | + sandbox = relaxed - uses: cachix/cachix-action@v16 with: name: huggingface #authToken: "${{ secrets.CACHIX_AUTH_TOKEN }}" - # For now we only test that there are no regressions in building macOS - # kernels. Also run tests once we have a macOS runner. 
+ - name: Build relu kernel run: ( cd builder/examples/relu && nix build .\#redistributable.torch29-metal-aarch64-darwin -L ) + - name: Test relu kernel + run: ( cd builder/examples/relu && nix develop .\#test --command pytest tests/ -v ) - name: Build relu metal cpp kernel run: ( cd builder/examples/relu-metal-cpp && nix build .\#redistributable.torch29-metal-aarch64-darwin -L ) + - name: Test relu metal cpp kernel + run: ( cd builder/examples/relu-metal-cpp && nix develop .\#test --command pytest tests/ -v ) diff --git a/build2cmake/Cargo.lock b/build2cmake/Cargo.lock index ef1cf404..a6ffa3ea 100644 --- a/build2cmake/Cargo.lock +++ b/build2cmake/Cargo.lock @@ -80,7 +80,7 @@ checksum = "5c8214115b7bf84099f1309324e63141d4c5d7cc26862f97a0a857dbefe165bd" [[package]] name = "build2cmake" -version = "0.12.2-dev0" +version = "0.13.0-dev0" dependencies = [ "base32", "clap", diff --git a/build2cmake/Cargo.toml b/build2cmake/Cargo.toml index d5ff742d..d19c9637 100644 --- a/build2cmake/Cargo.toml +++ b/build2cmake/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "build2cmake" -version = "0.12.2-dev0" +version = "0.13.0-dev0" edition = "2021" description = "Generate CMake files for kernel-builder projects" homepage = "https://github.com/huggingface/kernel-builder" diff --git a/build2cmake/src/templates/metal/compile-metal.cmake b/build2cmake/src/templates/metal/compile-metal.cmake index 50d44a2d..32699b25 100644 --- a/build2cmake/src/templates/metal/compile-metal.cmake +++ b/build2cmake/src/templates/metal/compile-metal.cmake @@ -1,24 +1,37 @@ # Metal shader compilation function function(compile_metal_shaders TARGET_NAME METAL_SOURCES EXTRA_INCLUDE_DIRS) if(NOT DEFINED METAL_TOOLCHAIN) + # Try the separate Metal toolchain first (macOS 26+ with downloadable component) execute_process( COMMAND "xcodebuild" "-showComponent" "MetalToolchain" OUTPUT_VARIABLE FIND_METAL_OUT RESULT_VARIABLE FIND_METAL_ERROR_CODE - ERROR_VARIABLE FIND_METAL_STDERR OUTPUT_STRIP_TRAILING_WHITESPACE) - 
if(NOT FIND_METAL_ERROR_CODE EQUAL 0) - message(FATAL_ERROR "${ERR_MSG}: ${FIND_METAL_STDERR}") + if(FIND_METAL_ERROR_CODE EQUAL 0) + string(REGEX MATCH "Toolchain Search Path: ([^\n]+)" MATCH_RESULT "${FIND_METAL_OUT}") + set(METAL_TOOLCHAIN "${CMAKE_MATCH_1}/Metal.xctoolchain") + else() + # Fall back to the default Xcode toolchain (macOS 14/15 bundle metal in Xcode) + execute_process( + COMMAND "xcode-select" "-p" + OUTPUT_VARIABLE XCODE_DEV_DIR + RESULT_VARIABLE XCODE_SELECT_ERROR + OUTPUT_STRIP_TRAILING_WHITESPACE) + + if(XCODE_SELECT_ERROR EQUAL 0) + set(METAL_TOOLCHAIN "${XCODE_DEV_DIR}/Toolchains/XcodeDefault.xctoolchain") + else() + message(FATAL_ERROR "Cannot find Metal toolchain. On macOS 26+, use: xcodebuild -downloadComponent metalToolchain") + endif() endif() - - # Extract the Toolchain Search Path value and append Metal.xctoolchain - string(REGEX MATCH "Toolchain Search Path: ([^\n]+)" MATCH_RESULT "${FIND_METAL_OUT}") - set(METAL_TOOLCHAIN "${CMAKE_MATCH_1}/Metal.xctoolchain") endif() - # Set Metal compiler flags - set(METAL_FLAGS "-std=metal4.0" "-O2") + # Set Metal compiler flags. 
+ # metal3.1 → air64_v26, macOS 14+ + # metal3.2 → air64_v27, macOS 15+ + # metal4.0 → air64_v28, macOS 26+ + set(METAL_FLAGS "-std=metal3.1" "-O2") # Output directory for compiled metallib set(METALLIB_OUTPUT_DIR "${CMAKE_BINARY_DIR}/metallib") diff --git a/builder/examples/extra-data/relu_metal/relu.mm b/builder/examples/extra-data/relu_metal/relu.mm index 7636737b..1c195c70 100644 --- a/builder/examples/extra-data/relu_metal/relu.mm +++ b/builder/examples/extra-data/relu_metal/relu.mm @@ -1,3 +1,4 @@ +#include #include #import @@ -18,8 +19,10 @@ torch::Tensor &dispatchReluKernel(torch::Tensor const &input, torch::Tensor &output) { @autoreleasepool { - id device = MTLCreateSystemDefaultDevice(); + at::mps::MPSStream *stream = at::mps::getCurrentMPSStream(); + TORCH_CHECK(stream, "Failed to get MPS stream"); + id device = stream->device(); int numThreads = input.numel(); // Load the embedded Metal library from memory @@ -44,14 +47,12 @@ error:&error]; TORCH_CHECK(reluPSO, error.localizedDescription.UTF8String); - id commandBuffer = torch::mps::get_command_buffer(); - TORCH_CHECK(commandBuffer, "Failed to retrieve command buffer reference"); - - dispatch_queue_t serialQueue = torch::mps::get_dispatch_queue(); - - dispatch_sync(serialQueue, ^() { - id computeEncoder = - [commandBuffer computeCommandEncoder]; + // Use stream->commandEncoder() to properly integrate with PyTorch's + // MPS encoder lifecycle (kernel coalescing). Creating encoders directly + // via [commandBuffer computeCommandEncoder] bypasses this and crashes + // when the kernel is called twice in sequence. 
+ dispatch_sync(stream->queue(), ^() { + id computeEncoder = stream->commandEncoder(); TORCH_CHECK(computeEncoder, "Failed to create compute command encoder"); [computeEncoder setComputePipelineState:reluPSO]; @@ -72,11 +73,9 @@ [computeEncoder dispatchThreads:gridSize threadsPerThreadgroup:threadgroupSize]; - - [computeEncoder endEncoding]; - - torch::mps::commit(); }); + + stream->synchronize(at::mps::SyncType::COMMIT_AND_CONTINUE); } return output; diff --git a/builder/examples/relu/relu_metal/relu.mm b/builder/examples/relu/relu_metal/relu.mm index 7636737b..1c195c70 100644 --- a/builder/examples/relu/relu_metal/relu.mm +++ b/builder/examples/relu/relu_metal/relu.mm @@ -1,3 +1,4 @@ +#include #include #import @@ -18,8 +19,10 @@ torch::Tensor &dispatchReluKernel(torch::Tensor const &input, torch::Tensor &output) { @autoreleasepool { - id device = MTLCreateSystemDefaultDevice(); + at::mps::MPSStream *stream = at::mps::getCurrentMPSStream(); + TORCH_CHECK(stream, "Failed to get MPS stream"); + id device = stream->device(); int numThreads = input.numel(); // Load the embedded Metal library from memory @@ -44,14 +47,12 @@ error:&error]; TORCH_CHECK(reluPSO, error.localizedDescription.UTF8String); - id commandBuffer = torch::mps::get_command_buffer(); - TORCH_CHECK(commandBuffer, "Failed to retrieve command buffer reference"); - - dispatch_queue_t serialQueue = torch::mps::get_dispatch_queue(); - - dispatch_sync(serialQueue, ^() { - id computeEncoder = - [commandBuffer computeCommandEncoder]; + // Use stream->commandEncoder() to properly integrate with PyTorch's + // MPS encoder lifecycle (kernel coalescing). Creating encoders directly + // via [commandBuffer computeCommandEncoder] bypasses this and crashes + // when the kernel is called twice in sequence. 
+ dispatch_sync(stream->queue(), ^() { + id computeEncoder = stream->commandEncoder(); TORCH_CHECK(computeEncoder, "Failed to create compute command encoder"); [computeEncoder setComputePipelineState:reluPSO]; @@ -72,11 +73,9 @@ [computeEncoder dispatchThreads:gridSize threadsPerThreadgroup:threadgroupSize]; - - [computeEncoder endEncoding]; - - torch::mps::commit(); }); + + stream->synchronize(at::mps::SyncType::COMMIT_AND_CONTINUE); } return output; diff --git a/builder/lib/cache.nix b/builder/lib/cache.nix index 3211a62f..105113b4 100644 --- a/builder/lib/cache.nix +++ b/builder/lib/cache.nix @@ -22,6 +22,7 @@ ++ allOutputs build2cmake ++ allOutputs kernel-abi-check ++ allOutputs python3Packages.kernels + ++ allOutputs python3Packages.tvm-ffi ++ lib.optionals stdenv.hostPlatform.isLinux (allOutputs stdenvGlibc_2_27) ); buildSetLinkFarm = buildSet: pkgs.linkFarm buildSet.torch.variant (buildSetOutputs buildSet); diff --git a/builder/lib/deps.nix b/builder/lib/deps.nix index 9163cb80..db2b91a9 100644 --- a/builder/lib/deps.nix +++ b/builder/lib/deps.nix @@ -35,7 +35,7 @@ let pythonDeps = let - depsJson = builtins.fromJSON (builtins.readFile ../build2cmake/src/python_dependencies.json); + depsJson = builtins.fromJSON (builtins.readFile ../../build2cmake/src/python_dependencies.json); # Map the Nix package names to actual Nix packages. updatePackage = _name: dep: dep // { nix = map (pkg: pkgs.python3.pkgs.${pkg}) dep.nix; }; updateBackend = _backend: backendDeps: lib.mapAttrs updatePackage backendDeps; diff --git a/builder/lib/torch-extension/arch.nix b/builder/lib/torch-extension/arch.nix index 0bf02efc..124e5931 100644 --- a/builder/lib/torch-extension/arch.nix +++ b/builder/lib/torch-extension/arch.nix @@ -86,9 +86,11 @@ let # On Darwin, we need the host's xcrun for `xcrun metal` to compile Metal shaders. # It's not supported by the nixpkgs shim. xcrunHost = writeScriptBin "xcrunHost" '' - # Use system SDK for Metal files. + # Use system SDK for Metal files. 
Clear Nix-set variables that + # interfere with xcrun/xcodebuild's SDK and toolchain resolution. unset DEVELOPER_DIR - /usr/bin/xcrun $@ + unset SDKROOT + /usr/bin/xcrun "$@" ''; metalSupport = buildConfig.metal or false; @@ -123,13 +125,47 @@ stdenv.mkDerivation (prevAttrs: { # instead, we'll use showComponent (which will emit a lot of warnings due # to the above) to grab the path of the Metal toolchain. lib.optionalString metalSupport '' - METAL_PATH=$(${xcrunHost}/bin/xcrunHost xcodebuild -showComponent MetalToolchain 2> /dev/null | sed -rn "s/Toolchain Search Path: (.*)/\1/p") - if [ ! -d "$METAL_PATH" ]; then - >&2 echo "Cannot find Metal toolchain, use: xcodebuild -downloadComponent MetalToolchain" - exit 1 + # Try the separate Metal toolchain first (macOS 26+ with xcodebuild -downloadComponent). + # Use || true to prevent set -o pipefail from aborting on older macOS where + # -showComponent is unsupported. + METAL_PATH=$(${xcrunHost}/bin/xcrunHost xcodebuild -showComponent MetalToolchain 2> /dev/null | sed -rn "s/Toolchain Search Path: (.*)/\1/p" || true) + + if [ -d "$METAL_PATH/Metal.xctoolchain" ]; then + cmakeFlagsArray+=("-DMETAL_TOOLCHAIN=$METAL_PATH/Metal.xctoolchain") + else + # On macOS 14/15, xcrun and xcode-select may not work inside the Nix + # build environment (sandbox restrictions). Try them, then fall back + # to scanning /Applications for Xcode installations. 
+ XCODE_DEV=$(${xcrunHost}/bin/xcrunHost xcode-select -p 2>/dev/null || true) + XCODE_TOOLCHAIN="$XCODE_DEV/Toolchains/XcodeDefault.xctoolchain" + + XCRUN_METAL=$(${xcrunHost}/bin/xcrunHost xcrun -find metal 2>/dev/null || true) + + if [ -d "$XCODE_TOOLCHAIN/usr/bin" ] && [ -f "$XCODE_TOOLCHAIN/usr/bin/metal" ]; then + cmakeFlagsArray+=("-DMETAL_TOOLCHAIN=$XCODE_TOOLCHAIN") + elif [ -n "$XCRUN_METAL" ] && [ -f "$XCRUN_METAL" ]; then + # Derive toolchain path from xcrun result + METAL_BIN_DIR=$(dirname "$XCRUN_METAL") + METAL_TC_DIR=$(dirname $(dirname "$METAL_BIN_DIR")) + cmakeFlagsArray+=("-DMETAL_TOOLCHAIN=$METAL_TC_DIR") + else + # Last resort: scan /Applications/Xcode*.app for metal compiler + FOUND_TC="" + for xcode_app in /Applications/Xcode*.app; do + TC="$xcode_app/Contents/Developer/Toolchains/XcodeDefault.xctoolchain" + if [ -f "$TC/usr/bin/metal" ]; then + FOUND_TC="$TC" + break + fi + done + if [ -n "$FOUND_TC" ]; then + cmakeFlagsArray+=("-DMETAL_TOOLCHAIN=$FOUND_TC") + else + >&2 echo "Cannot find Metal toolchain. On macOS 26+, use: xcodebuild -downloadComponent metalToolchain" + exit 1 + fi + fi fi - - cmakeFlagsArray+=("-DMETAL_TOOLCHAIN=$METAL_PATH/Metal.xctoolchain") ''; # hipify copies files, but its target is run in the CMake build and install diff --git a/docs/source/builder/metal.md b/docs/source/builder/metal.md index 6095f376..5d464e24 100644 --- a/docs/source/builder/metal.md +++ b/docs/source/builder/metal.md @@ -5,24 +5,29 @@ Instructions on this page assume that you installed Nix with the ## Targeted macOS versions -Since new macOS versions get [adopted quickly](https://telemetrydeck.com/survey/apple/macOS/versions/), -we only support the latest major macOS version except for the first weeks -after a release, when we also support the previous major version. +Metal kernels are compiled with `-std=metal3.1` (AIR v26), which requires +macOS 14 or later on ARM64 (Apple Silicon). 
-We currently support macOS 26.0 and later on ARM64 (Apple silicon). +| macOS version | Support level | +|---------------|---------------| +| macOS 26+ | Fully supported and tested in CI | +| macOS 15 | Fully supported and tested in CI | +| macOS 14 | Best-effort (builds work, some tests may fail due to MPS memory limits) | ## Requirements To build a Metal kernel, the following requirements must be met: -- Xcode 26.x must be available on the build machine. -- `xcode-select -p` must point to the Xcode 26 installation, typically +- An Xcode installation with the Metal compiler must be available. The build + system automatically detects the Metal toolchain from available Xcode + installations. +- On macOS 26+, the Metal Toolchain is a separate download from Xcode: + `xcodebuild -downloadComponent MetalToolchain` +- On macOS 14/15, Metal ships bundled with Xcode (no separate install needed). +- `xcode-select -p` must point to your Xcode installation, typically `/Applications/Xcode.app/Contents/Developer`. If this is not the case, you can set the path with: `sudo xcode-select -s /path/to/Xcode.app/Contents/Developer` -- The Metal Toolchain must be installed. Starting with macOS 26, this is - a separate download from Xcode. You can install it with: - `xcodebuild -downloadComponent MetalToolchain` - The Nix sandbox should be set to `relaxed`, because the Nix derivation that builds the kernel must have access to Xcode and the Metal Toolchain. You can verify this by checking that `/etc/nix/nix.custom.conf` contains @@ -47,8 +52,7 @@ Xcode 26.1 Build version 17B55 ``` -The reported version must be 26.0 or newer. 
Then you can validate that the -Metal Toolchain is installed with: +On macOS 26+, you can validate that the Metal Toolchain is installed with: ```bash $ xcodebuild -showComponent metalToolchain diff --git a/docs/source/builder/writing-kernels.md b/docs/source/builder/writing-kernels.md index c5a32d50..5db6005d 100644 --- a/docs/source/builder/writing-kernels.md +++ b/docs/source/builder/writing-kernels.md @@ -35,6 +35,16 @@ as the running example. After reading this page, you may also want to have a look at the more realistic [ReLU kernel with backprop and `torch.compile`](https://github.com/huggingface/kernels/tree/main/builder/examples/relu-backprop-compile) support. +## Setting up environment + +In the [`terraform`](../../../terraform/) directory, we provide an +example of programmatically spinning up an EC2 instance that is ready +with everything needed for you to start developing and building +kernels. + +If you use a different provider, the Terraform configuration should be +similar and straightforward to modify. 
+ ## Kernel project layout Kernel projects follow this general directory layout: diff --git a/flake.nix b/flake.nix index c5d6ef0e..507dcdfa 100644 --- a/flake.nix +++ b/flake.nix @@ -169,6 +169,7 @@ tabulate tomlkit torch + tvm-ffi types-pyyaml types-requests types-tabulate diff --git a/kernel-abi-check/bindings/python/Cargo.lock b/kernel-abi-check/bindings/python/Cargo.lock index 0c6ccda3..3a90f34c 100644 --- a/kernel-abi-check/bindings/python/Cargo.lock +++ b/kernel-abi-check/bindings/python/Cargo.lock @@ -271,7 +271,7 @@ checksum = "4a5f13b858c8d314ee3e8f639011f7ccefe71f97f96e50151fb991f267928e2c" [[package]] name = "kernel-abi-check" -version = "0.12.2-dev0" +version = "0.13.0-dev0" dependencies = [ "clap", "color-eyre", @@ -286,7 +286,7 @@ dependencies = [ [[package]] name = "kernel-abi-check-python" -version = "0.12.2-dev0" +version = "0.13.0-dev0" dependencies = [ "kernel-abi-check", "object", diff --git a/kernel-abi-check/bindings/python/Cargo.toml b/kernel-abi-check/bindings/python/Cargo.toml index 03777a8d..64b314f3 100644 --- a/kernel-abi-check/bindings/python/Cargo.toml +++ b/kernel-abi-check/bindings/python/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "kernel-abi-check-python" -version = "0.12.2-dev0" +version = "0.13.0-dev0" edition = "2024" description = "Check the ABI of Hub Kernels" homepage = "https://github.com/huggingface/kernel-builder" diff --git a/kernel-abi-check/kernel-abi-check/Cargo.lock b/kernel-abi-check/kernel-abi-check/Cargo.lock index 5e25476f..b4783b6b 100644 --- a/kernel-abi-check/kernel-abi-check/Cargo.lock +++ b/kernel-abi-check/kernel-abi-check/Cargo.lock @@ -274,7 +274,7 @@ checksum = "4a5f13b858c8d314ee3e8f639011f7ccefe71f97f96e50151fb991f267928e2c" [[package]] name = "kernel-abi-check" -version = "0.12.2-dev0" +version = "0.13.0-dev0" dependencies = [ "clap", "color-eyre", diff --git a/kernel-abi-check/kernel-abi-check/Cargo.toml b/kernel-abi-check/kernel-abi-check/Cargo.toml index 610e1044..a06ba57f 100644 --- 
a/kernel-abi-check/kernel-abi-check/Cargo.toml +++ b/kernel-abi-check/kernel-abi-check/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "kernel-abi-check" -version = "0.12.2-dev0" +version = "0.13.0-dev0" edition = "2021" description = "Check the ABI of Hub Kernels" homepage = "https://github.com/huggingface/kernel-builder" diff --git a/kernels/pyproject.toml b/kernels/pyproject.toml index 5b9900b0..055747a0 100644 --- a/kernels/pyproject.toml +++ b/kernels/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "kernels" -version = "0.12.2.dev0" +version = "0.13.0.dev0" description = "Download compute kernels" authors = [ { name = "Daniel de Kok", email = "daniel@huggingface.co" }, diff --git a/kernels/uv.lock b/kernels/uv.lock index 581b37dc..a36b3087 100644 --- a/kernels/uv.lock +++ b/kernels/uv.lock @@ -565,7 +565,7 @@ wheels = [ [[package]] name = "kernels" -version = "0.12.2.dev0" +version = "0.13.0.dev0" source = { editable = "." } dependencies = [ { name = "huggingface-hub" }, diff --git a/nix/overlay.nix b/nix/overlay.nix index 4fa05c3e..8beb2971 100644 --- a/nix/overlay.nix +++ b/nix/overlay.nix @@ -109,6 +109,9 @@ in torchVersion = "2.9"; xpuPackages = final.xpuPackages_2025_2; }; + + tvm-ffi = callPackage ./pkgs/python-modules/tvm-ffi { + }; } ) (import ./pkgs/python-modules/hooks) diff --git a/nix/pkgs/python-modules/tvm-ffi/default.nix b/nix/pkgs/python-modules/tvm-ffi/default.nix new file mode 100644 index 00000000..0e17df0f --- /dev/null +++ b/nix/pkgs/python-modules/tvm-ffi/default.nix @@ -0,0 +1,44 @@ +{ + callPackage, + buildPythonPackage, + fetchFromGitHub, + + cmake, + cython, + ninja, + python, + scikit-build-core, + setuptools-scm, +}: + +buildPythonPackage rec { + pname = "tvm-ffi"; + version = "0.1.9"; + pyproject = true; + + src = fetchFromGitHub { + owner = "apache"; + repo = "tvm-ffi"; + rev = "v${version}"; + hash = "sha256-XnlM//WW2TbjbmzYBq6itJQ7R3J646UMVQUVhV5Afwc="; + fetchSubmodules = true; + }; + + build-system = [ + cmake + cython + ninja + 
scikit-build-core + setuptools-scm + ]; + + dontUseCmakeConfigure = true; + + postInstall = '' + ln -s $out/${python.sitePackages}/tvm_ffi/share $out/share + ''; + + passthru = callPackage ./variant.nix { + tvmFfiVersion = version; + }; +} diff --git a/nix/pkgs/python-modules/tvm-ffi/variant.nix b/nix/pkgs/python-modules/tvm-ffi/variant.nix new file mode 100644 index 00000000..6536ff40 --- /dev/null +++ b/nix/pkgs/python-modules/tvm-ffi/variant.nix @@ -0,0 +1,46 @@ +{ + config, + cudaSupport ? config.cudaSupport, + metalSupport ? config.metalSupport or false, + rocmSupport ? config.rocmSupport, + xpuSupport ? config.xpuSupport or false, + + cudaPackages, + rocmPackages, + xpuPackages, + + lib, + stdenv, + + tvmFfiVersion, +}: + +let + flattenVersion = + version: lib.replaceStrings [ "." ] [ "" ] (lib.versions.majorMinor (lib.versions.pad 2 version)); + backend = + if cudaSupport then + "cuda" + else if metalSupport then + "metal" + else if rocmSupport then + "rocm" + else if xpuSupport then + "xpu" + else + "cpu"; + computeString = + if cudaSupport then + "cu${flattenVersion cudaPackages.cudaMajorMinorVersion}" + else if metalSupport then + "metal" + else if rocmSupport then + "rocm${flattenVersion (lib.versions.majorMinor rocmPackages.rocm.version)}" + else if xpuSupport then + "xpu${flattenVersion (lib.versions.majorMinor xpuPackages.oneapi-torch-dev.version)}" + else + "cpu"; +in +{ + variant = "tvm-ffi${flattenVersion (lib.versions.majorMinor tvmFfiVersion)}-${computeString}-${stdenv.hostPlatform.system}"; +} diff --git a/template/__KERNEL_NAME_NORMALIZED___metal/__KERNEL_NAME_NORMALIZED__.mm b/template/__KERNEL_NAME_NORMALIZED___metal/__KERNEL_NAME_NORMALIZED__.mm index 060f5ec8..fe2d190c 100644 --- a/template/__KERNEL_NAME_NORMALIZED___metal/__KERNEL_NAME_NORMALIZED__.mm +++ b/template/__KERNEL_NAME_NORMALIZED___metal/__KERNEL_NAME_NORMALIZED__.mm @@ -1,3 +1,4 @@ +#include #include #import @@ -25,7 +26,10 @@ void __KERNEL_NAME_NORMALIZED__(torch::Tensor &out, 
torch::Tensor const &input) "Tensors must be on same device"); @autoreleasepool { - id device = MTLCreateSystemDefaultDevice(); + at::mps::MPSStream *stream = at::mps::getCurrentMPSStream(); + TORCH_CHECK(stream, "Failed to get MPS stream"); + + id device = stream->device(); int numThreads = input.numel(); NSError *error = nil; @@ -42,9 +46,12 @@ void __KERNEL_NAME_NORMALIZED__(torch::Tensor &out, torch::Tensor const &input) [device newComputePipelineStateWithFunction:func error:&error]; TORCH_CHECK(pso, error.localizedDescription.UTF8String); - id cmdBuf = torch::mps::get_command_buffer(); - dispatch_sync(torch::mps::get_dispatch_queue(), ^() { - id encoder = [cmdBuf computeCommandEncoder]; + // Use stream->commandEncoder() to properly integrate with PyTorch's + // MPS encoder lifecycle (kernel coalescing). Creating encoders directly + // via [commandBuffer computeCommandEncoder] bypasses this and crashes + // when the kernel is called twice in sequence. + dispatch_sync(stream->queue(), ^() { + id encoder = stream->commandEncoder(); [encoder setComputePipelineState:pso]; [encoder setBuffer:getMTLBufferStorage(input) offset:input.storage_offset() * input.element_size() @@ -57,8 +64,8 @@ void __KERNEL_NAME_NORMALIZED__(torch::Tensor &out, torch::Tensor const &input) MIN(pso.maxTotalThreadsPerThreadgroup, (NSUInteger)numThreads); [encoder dispatchThreads:MTLSizeMake(numThreads, 1, 1) threadsPerThreadgroup:MTLSizeMake(tgSize, 1, 1)]; - [encoder endEncoding]; - torch::mps::commit(); }); + + stream->synchronize(at::mps::SyncType::COMMIT_AND_CONTINUE); } } diff --git a/terraform/.gitignore b/terraform/.gitignore new file mode 100644 index 00000000..2f1ff28b --- /dev/null +++ b/terraform/.gitignore @@ -0,0 +1,8 @@ +.terraform/ +*.tfstate +*.tfstate.backup +*.tfvars +!terraform.tfvars.example +.terraform.lock.hcl +crash.log +*.pem \ No newline at end of file diff --git a/terraform/README.md b/terraform/README.md new file mode 100644 index 00000000..4432ad20 --- /dev/null 
+++ b/terraform/README.md @@ -0,0 +1,106 @@ +# Terraform + +Terraform scripts to spin up an EC2 instance running NixOS, +mimicking the infra we use to develop and build the kernels +ourselves. + +An `m7i.8xlarge` NixOS instance (32 vCPUs, 128 GiB RAM) is provisioned +with a 200 GiB encrypted root volume and a 1 TiB encrypted gp3 data +volume mounted at `/data`. On first boot, the bundled NixOS configuration +is applied via `nixos-rebuild switch` and the Hugging Face Cachix binary +cache is registered. + +## Prerequisites + +- [Terraform](https://developer.hashicorp.com/terraform/install) ≥ 1.5 +- AWS credentials in the environment (`AWS_PROFILE`, `AWS_ACCESS_KEY_ID`, etc.) + +## Usage + +```bash +# 1. Configure +cp terraform.tfvars.example terraform.tfvars +$EDITOR terraform.tfvars # uncomment and adjust any overrides + +# 2. Deploy +terraform init +terraform apply + +# 3. Connect (SSH command is printed in outputs) +terraform output ssh_command +``` + +To push built Nix paths to the Cachix cache, set `cachix_auth_token` in `terraform.tfvars`. + +### Connecting + +After `terraform apply`, get the ready-to-use SSH command from the outputs: + +```bash +terraform output -raw ssh_command +``` + +This prints something like: + +```bash +ssh -i ~/.ssh/my-key.pem root@10.90.0.x +``` + +The instance is reachable via its **private IP** (the subnet does not auto-assign public IPs). + +### Waiting for first-boot setup + +After SSH-ing in, the NixOS configuration is applied in the background by the `amazon-init` service (downloading packages and running `nixos-rebuild switch`). This takes +**5-10 minutes**. Run this **inside the VM** to follow progress: + +```bash +journalctl -u amazon-init -f +``` + +The setup is complete when the service reaches `Finished` state, which you can confirm with: + +```bash +systemctl status amazon-init +``` + +> **Note:** `amazon-init` re-runs on every reboot, so the configuration is re-applied each time the instance restarts. 
+ +Once done, reload the shell to pick up the aliases and settings: + +```bash +exec bash +``` + +### Inside the VM + +Once the setup is complete, a few useful aliases are available: + +```bash +ws # cd /data/workspace (1 TiB data volume) +ndc # nix develop -c $SHELL (enter the Nix dev shell) +nbd # nix build -L (build with full logs) +``` + +Typical workflow for working on a kernel: + +```bash +ws +git clone <repo-url> && cd <repo> +nix develop # or: ndc +``` + +`nix develop` must be run from inside a repo that has a `flake.nix` — running it from `/root` or any other directory without one will error. + +If you need a Nix dev shell **before** `amazon-init` finishes (i.e. before `nix-command` and `flakes` are enabled by the rebuild), pass the features explicitly from within the repo: + +```bash +nix --extra-experimental-features 'nix-command flakes' develop +``` + +`direnv` is also configured, so if the repo has a `.envrc` the dev shell activates automatically on `cd`. + +## Teardown + +```bash +terraform destroy +``` \ No newline at end of file diff --git a/terraform/main.tf b/terraform/main.tf new file mode 100644 index 00000000..bbe89da8 --- /dev/null +++ b/terraform/main.tf @@ -0,0 +1,134 @@ +terraform { + required_version = ">= 1.5" + + required_providers { + aws = { + source = "hashicorp/aws" + version = "~> 5.0" + } + } +} + +provider "aws" { + region = var.aws_region +} + +# --------------------------------------------------------------------------- +# Official NixOS AMI lookup +# AMIs are published weekly by the NixOS project under AWS account 427812963091. 
+# See: https://nixos.github.io/amis/ +# --------------------------------------------------------------------------- +data "aws_ami" "nixos" { + most_recent = true + owners = ["427812963091"] + + filter { + name = "name" + values = ["nixos/${var.nixos_channel}*"] + } + + filter { + name = "architecture" + values = ["x86_64"] + } +} + +locals { + common_tags = merge(var.tags, { + Project = "hf-kernels-dev" + ManagedBy = "terraform" + }) + + # Encode the NixOS configuration as base64 so it can be safely embedded in + # the user-data script without escaping issues (Nix files contain ${ ... }). + user_data = base64encode(join("", [ + "#!/bin/sh\n", + "set -e\n", + # Wait for the EBS data volume to be attached. + # Terraform attaches it after instance creation, so it may not be present + # immediately at boot. Poll for up to 5 minutes (30 x 10 s). + "echo 'Waiting for data volume /dev/nvme1n1...'\n", + "for i in $(seq 1 30); do\n", + " [ -b /dev/nvme1n1 ] && break\n", + " sleep 10\n", + "done\n", + # Format (first boot only) and mount the data volume, then create the + # directories that NixOS will later bind-mount over /nix/store. + "if [ -b /dev/nvme1n1 ]; then\n", + " if ! blkid /dev/nvme1n1 | grep -q ext4; then\n", + " mkfs.ext4 -L kernels-data /dev/nvme1n1\n", + " fi\n", + " mkdir -p /data\n", + " mount /dev/nvme1n1 /data\n", + " mkdir -p /data/nix-store /data/workspace\n", + "fi\n", + # Decode and write the NixOS configuration. + "base64 -d > /etc/nixos/configuration.nix << 'B64EOF'\n", + filebase64("${path.module}/nixos-configuration.nix"), + "\nB64EOF\n", + # Write the Cachix auth token if one was provided. + var.cachix_auth_token != "" ? join("", [ + "mkdir -p /root/.config/cachix\n", + "printf '{\\n authToken = \"${var.cachix_auth_token}\";\\n}\\n'", + " > /root/.config/cachix/cachix.dhall\n", + "chmod 600 /root/.config/cachix/cachix.dhall\n", + ]) : "", + # Apply the configuration (installs all packages including cachix). 
+ "nixos-rebuild switch 2>&1 | tail -20\n", + # Register the huggingface Cachix binary cache — mirrors cachix-action@v16. + "cachix use huggingface\n", + ])) +} + +# --------------------------------------------------------------------------- +# EC2 instance running NixOS +# --------------------------------------------------------------------------- +resource "aws_instance" "kernels_dev" { + ami = data.aws_ami.nixos.id + instance_type = var.instance_type + key_name = var.key_pair_name + subnet_id = var.subnet_id + + associate_public_ip_address = true + + vpc_security_group_ids = [var.security_group_id] + + # NixOS configuration is applied on first boot via user data. + # Changing nixos-configuration.nix will replace the instance. + user_data = local.user_data + user_data_replace_on_change = true + + root_block_device { + volume_size = var.root_volume_size_gb + volume_type = "gp3" + delete_on_termination = true + encrypted = true + } + + metadata_options { + http_tokens = "required" # IMDSv2 + } + + tags = merge(local.common_tags, { Name = var.instance_name }) +} + +# --------------------------------------------------------------------------- +# Extra EBS data volume (Nix store spillover, build artefacts, source trees) +# --------------------------------------------------------------------------- +resource "aws_ebs_volume" "data" { + availability_zone = aws_instance.kernels_dev.availability_zone + size = var.data_volume_size_gb + type = var.data_volume_type + iops = var.data_volume_iops + throughput = var.data_volume_throughput + encrypted = true + + tags = merge(local.common_tags, { Name = "${var.instance_name}-data" }) +} + +resource "aws_volume_attachment" "data" { + device_name = "/dev/xvdf" + volume_id = aws_ebs_volume.data.id + instance_id = aws_instance.kernels_dev.id + force_detach = false +} diff --git a/terraform/nixos-configuration.nix b/terraform/nixos-configuration.nix new file mode 100644 index 00000000..e5e67be7 --- /dev/null +++ 
b/terraform/nixos-configuration.nix @@ -0,0 +1,181 @@ +# References: +# - https://github.com/huggingface/kernels-community/blob/main/.github/workflows/build-pr.yaml +{ + config, + pkgs, + lib, + ... +}: +{ + imports = [ + # Required for EC2 / AWS support (virtio drivers, cloud-init, EBS, etc.) + <nixpkgs/nixos/modules/virtualisation/amazon-image.nix> + ]; + + system.stateVersion = "25.11"; + + # ------------------------------------------------------------------------- + # Nix daemon — mirrors the CI configuration in build-pr.yaml: + # max-jobs = 2 / cores = 12 → here we use the full machine instead + # sandbox-fallback = false + # experimental-features = nix-command flakes + # substituters = huggingface cachix + # ------------------------------------------------------------------------- + nix.settings = { + experimental-features = [ + "nix-command" + "flakes" + ]; + # Sufficient to cater to heavy kernels. + max-jobs = 4; + cores = 16; + sandbox-fallback = false; + + substituters = [ + "https://cache.nixos.org" + "https://huggingface.cachix.org" + ]; + trusted-public-keys = [ + "cache.nixos.org-1:6NCHdD59X431o0gWypbMrAURkbJ16ZPMQFGspcDShjY=" + # From cachix-action@v16 in build-pr.yaml + "huggingface.cachix.org-1:ynTPbLS0W8ofXd9fDjk1KvoFky9K2jhxe6r4nXAkc/o=" + ]; + + # Allow the main user to add extra substituters without sudo. + trusted-users = [ + "root" + "nixos" + ]; + }; + + # Keep build outputs around so incremental rebuilds stay fast. + nix.gc = { + automatic = true; + dates = "weekly"; + options = "--delete-older-than 30d"; + }; + + # ------------------------------------------------------------------------- + # Data volume — format on first boot, then mount at /data. + # The 1 TiB gp3 EBS volume is attached as /dev/nvme1n1 on Nitro instances. + # /data/nix-store is bind-mounted over /nix/store so large builds do not + # fill the root volume. 
+  # -------------------------------------------------------------------------
+  # One-shot unit: put an ext4 filesystem on the data volume the first time
+  # it appears; skipped on later boots once blkid already reports ext4.
+  systemd.services.format-data-volume = {
+    description = "Format the EBS data volume on first boot if needed";
+    wantedBy = [ "multi-user.target" ];
+    before = [ "data.mount" ];
+    # Only run if the device exists (attachment can lag by a few seconds).
+    unitConfig.ConditionPathExists = "/dev/nvme1n1";
+    script = ''
+      if ! ${pkgs.util-linux}/bin/blkid /dev/nvme1n1 | grep -q ext4; then
+        echo "Formatting /dev/nvme1n1 as ext4..."
+        ${pkgs.e2fsprogs}/bin/mkfs.ext4 -L kernels-data /dev/nvme1n1
+      fi
+    '';
+    serviceConfig = {
+      Type = "oneshot";
+      RemainAfterExit = true;
+    };
+  };
+
+  # Mount the (possibly just-formatted) data volume; nofail keeps the system
+  # bootable if the EBS attachment is missing.
+  fileSystems."/data" = {
+    device = "/dev/nvme1n1";
+    fsType = "ext4";
+    options = [
+      "nofail"
+      "x-systemd.requires=format-data-volume.service"
+    ];
+  };
+
+  # Bind /nix/store onto the data volume so builds land on the 1 TiB disk.
+  # NOTE(review): bind-mounting an (initially empty) /data/nix-store over
+  # /nix/store hides the store paths of the running system — confirm the
+  # store is seeded/copied onto the data volume before this mount activates,
+  # or that the AMI tolerates an empty overlay here.
+  fileSystems."/nix/store" = {
+    device = "/data/nix-store";
+    fsType = "none";
+    options = [
+      "bind"
+      "nofail"
+      "x-systemd.requires=data.mount"
+    ];
+  };
+
+  # Ensure the bind-mount target exists before mounting.
+  # NOTE(review): tmpfiles rules (next section) normally run after local
+  # mounts — verify /data/nix-store exists in time for the bind mount on
+  # first boot.
+  systemd.tmpfiles.rules = [
+    "d /data/nix-store 0755 root root -"
+    # NOTE(review): assumes a `nixos` user/group exists (not declared in this
+    # file; presumably provided by the AMI) — confirm, else tmpfiles warns.
+    "d /data/workspace 0755 nixos nixos -"
+  ];
+
+  # -------------------------------------------------------------------------
+  # Packages for kernel development
+  # -------------------------------------------------------------------------
+  environment.systemPackages = with pkgs; [
+    # Version control & productivity
+    git
+    git-lfs
+    curl
+    wget
+    jq
+    ripgrep
+    htop
+    iotop
+    btop
+    tree
+    tmux
+
+    # Nix ecosystem tooling
+    cachix # binary cache management
+    nix-tree # visualise the Nix store graph
+    nix-diff # compare two derivations
+    direnv # per-directory .envrc / nix develop auto-activation
+    nix-direnv # fast direnv integration for Nix (also enabled via programs.direnv below)
+
+    # Compression (used by the CI closure export/import steps)
+    zstd
+    gzip
+    xz
+
+    # Misc build utilities
+    patchelf
+    file
+    binutils
+  ];
+
+  # -------------------------------------------------------------------------
+  # Shell environment
+  # -------------------------------------------------------------------------
+
+  # direnv hooks for bash and zsh so `nix develop` shells activate automatically.
+  programs.direnv = {
+    enable = true;
+    nix-direnv.enable = true;
+  };
+
+  # Useful shell aliases for kernel dev workflow.
+  environment.shellAliases = {
+    nbd = "nix build -L"; # build with logs
+    nbdt = "nix build -L .#ci-test"; # build the CI test output
+    ndc = "nix develop -c $SHELL"; # enter dev shell
+    ws = "cd /data/workspace";
+    dinit = "echo 'use nix' > .envrc && direnv allow"; # init direnv for a flake dir
+  };
+
+  # -------------------------------------------------------------------------
+  # SSH
+  # -------------------------------------------------------------------------
+  services.openssh = {
+    enable = true;
+    settings = {
+      PermitRootLogin = "prohibit-password"; # key-only root login
+      PasswordAuthentication = false;
+      X11Forwarding = false;
+    };
+  };
+
+  # -------------------------------------------------------------------------
+  # Firewall — allow SSH only
+  # -------------------------------------------------------------------------
+  networking.firewall = {
+    enable = true;
+    allowedTCPPorts = [ 22 ];
+  };
+}
diff --git a/terraform/outputs.tf b/terraform/outputs.tf
new file mode 100644
index 00000000..f7902b36
--- /dev/null
+++ b/terraform/outputs.tf
@@ -0,0 +1,29 @@
+output "instance_id" {
+  description = "EC2 instance ID"
+  value       = aws_instance.kernels_dev.id
+}
+
+output "public_ip" {
+  description = "Public IP address of the instance"
+  value       = aws_instance.kernels_dev.public_ip
+}
+
+output "public_dns" {
+  description = "Public DNS name of the instance"
+  value       = aws_instance.kernels_dev.public_dns
+}
+
+output "ami_id" {
+  description = "NixOS AMI used for the instance"
+  value       = data.aws_ami.nixos.id
+}
+
+output "ami_name" {
+  description = "NixOS AMI name (includes channel and git revision)"
+  value       = data.aws_ami.nixos.name
+}
+
+output "ssh_command" {
+  description = "SSH command to connect to the instance"
+  # Fixed: use public_ip — the instance gets a public address
+  # (associate_public_ip_address = true) and private_ip is unreachable from
+  # outside the VPC, which is where this SSH command is meant to be run.
+  value       = "ssh -i ${var.ssh_private_key_path} root@${aws_instance.kernels_dev.public_ip}"
+}
diff --git a/terraform/terraform.tfvars.example b/terraform/terraform.tfvars.example
new file mode 100644
index 00000000..616003e4
--- /dev/null
+++ 
b/terraform/terraform.tfvars.example
@@ -0,0 +1,20 @@
+# Copy this file to terraform.tfvars and fill in your values.
+# terraform.tfvars is git-ignored — never commit it.
+
+# Required:
+# subnet_id            = "subnet-0123456789abcdef0" # subnet used when launching manually
+# security_group_id    = "sg-0123456789abcdef0"     # existing SG that allows SSH (port 22)
+# key_pair_name        = "my-key-pair"              # existing EC2 key pair name
+# ssh_private_key_path = "~/.ssh/id_rsa"            # local path to the private key (".pem")
+#                                                   # that matches key_pair_name above
+
+# Optional overrides (all have sensible defaults):
+# instance_name          = "kernels-dev"
+# aws_region             = "us-east-1"
+# nixos_channel          = "25.11"
+# instance_type          = "m7i.8xlarge"       # 32 vCPU / 128 GiB RAM
+# root_volume_size_gb    = 200
+# data_volume_size_gb    = 1000                # 1 TiB gp3 for Nix store + builds
+# data_volume_iops       = 6000
+# data_volume_throughput = 400
+# cachix_auth_token      = "your-cachix-token" # only needed for pushing to the cache
diff --git a/terraform/variables.tf b/terraform/variables.tf
new file mode 100644
index 00000000..e8905c13
--- /dev/null
+++ b/terraform/variables.tf
@@ -0,0 +1,89 @@
+variable "aws_region" {
+  description = "AWS region to deploy into"
+  type        = string
+  default     = "us-east-1"
+}
+
+variable "nixos_channel" {
+  description = "NixOS channel to use for the AMI (e.g. '25.11' or '24.11')"
+  type        = string
+  default     = "25.11"
+}
+
+variable "instance_name" {
+  description = "Name tag for the EC2 instance (and related resources)"
+  type        = string
+  default     = "kernels-dev"
+}
+
+variable "instance_type" {
+  description = "EC2 instance type — heavy on CPU and RAM"
+  type        = string
+  # 32 vCPUs, 128 GiB RAM
+  default     = "m7i.8xlarge"
+}
+
+variable "root_volume_size_gb" {
+  description = "Size of the root EBS volume in GiB"
+  type        = number
+  default     = 200
+}
+
+variable "data_volume_size_gb" {
+  description = "Size of the extra data EBS volume in GiB (Nix store, builds, source trees)"
+  type        = number
+  default     = 1000
+}
+
+variable "data_volume_type" {
+  description = "EBS volume type for the data volume"
+  type        = string
+  default     = "gp3"
+}
+
+variable "data_volume_iops" {
+  description = "Provisioned IOPS for the data volume (gp3 baseline is 3000)"
+  type        = number
+  default     = 6000
+}
+
+variable "data_volume_throughput" {
+  description = "Provisioned throughput in MiB/s for the data volume (gp3 baseline is 125)"
+  type        = number
+  default     = 400
+}
+
+variable "subnet_id" {
+  description = "ID of the subnet to launch the instance in."
+  type        = string
+}
+
+variable "security_group_id" {
+  description = "ID of an existing security group to attach to the instance."
+  type        = string
+}
+
+variable "key_pair_name" {
+  description = "Name of an existing EC2 key pair to attach to the instance."
+  type        = string
+}
+
+variable "ssh_private_key_path" {
+  description = "Local path to the private key file corresponding to key_pair_name (used in the ssh_command output)."
+  type        = string
+  default     = "~/.ssh/id_rsa"
+}
+
+
+variable "cachix_auth_token" {
+  description = "Cachix auth token for pushing to the huggingface cache (optional)"
+  type        = string
+  default     = ""
+  sensitive   = true
+}
+
+variable "tags" {
+  description = "Extra tags to apply to all resources"
+  type        = map(string)
+  default     = {}
+}