Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions build2cmake/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ use config::{Backend, Build, BuildCompat};
mod fileset;
use fileset::FileSet;

mod metadata;

mod version;

#[derive(Parser, Debug)]
Expand Down
15 changes: 15 additions & 0 deletions build2cmake/src/metadata.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
use serde::{Deserialize, Serialize};

#[derive(Debug, Deserialize, Serialize)]
#[serde(deny_unknown_fields, rename_all = "kebab-case")]
pub struct Metadata {
python_depends: Vec<String>,
}

impl Metadata {
pub fn new(python_depends: impl Into<Vec<String>>) -> Self {
Self {
python_depends: python_depends.into(),
}
}
}
5 changes: 5 additions & 0 deletions build2cmake/src/templates/windows.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,11 @@ function(add_local_install_target TARGET_NAME PACKAGE_NAME BUILD_VARIANT_NAME)
${PYTHON_FILES}
${LOCAL_INSTALL_DIR}/

# Copy metadata.json if it exists
COMMAND ${CMAKE_COMMAND} -E copy_if_different
${CMAKE_SOURCE_DIR}/metadata.json
${LOCAL_INSTALL_DIR}/

COMMENT "Copying shared library and Python files to ${LOCAL_INSTALL_DIR}"
COMMAND_EXPAND_LISTS
)
Expand Down
16 changes: 16 additions & 0 deletions build2cmake/src/torch/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use itertools::Itertools;
use minijinja::{context, Environment};

use crate::config::{Backend, General};
use crate::metadata::Metadata;
use crate::FileSet;

pub fn write_pyproject_toml(
Expand Down Expand Up @@ -32,3 +33,18 @@ pub fn write_pyproject_toml(

Ok(())
}

pub fn write_metadata(backend: Backend, general: &General, file_set: &mut FileSet) -> Result<()> {
let writer = file_set.entry("metadata.json");

let python_depends = general
.python_depends()
.chain(general.backend_python_depends(backend))
.collect::<Result<Vec<_>>>()?;

let metadata = Metadata::new(python_depends);

serde_json::to_writer(writer, &metadata)?;

Ok(())
}
3 changes: 3 additions & 0 deletions build2cmake/src/torch/cpu.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use super::{common::write_pyproject_toml, kernel_ops_identifier};
use crate::{
config::{Backend, Build, Kernel, Torch},
fileset::FileSet,
torch::common::write_metadata,
version::Version,
};

Expand Down Expand Up @@ -52,6 +53,8 @@ pub fn write_torch_ext_cpu(

write_torch_registration_macros(&mut file_set)?;

write_metadata(Backend::Cpu, &build.general, &mut file_set)?;

Ok(file_set)
}

Expand Down
3 changes: 3 additions & 0 deletions build2cmake/src/torch/cuda.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ use minijinja::{context, Environment};
use super::common::write_pyproject_toml;
use super::kernel_ops_identifier;
use crate::config::{Backend, Build, Dependency, Kernel, Torch};
use crate::torch::common::write_metadata;
use crate::version::Version;
use crate::FileSet;

Expand Down Expand Up @@ -65,6 +66,8 @@ pub fn write_torch_ext_cuda(

write_torch_registration_macros(&mut file_set)?;

write_metadata(backend, &build.general, &mut file_set)?;

Ok(file_set)
}

Expand Down
3 changes: 3 additions & 0 deletions build2cmake/src/torch/metal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use super::{common::write_pyproject_toml, kernel_ops_identifier};
use crate::{
config::{Backend, Build, Kernel, Torch},
fileset::FileSet,
torch::common::write_metadata,
version::Version,
};

Expand Down Expand Up @@ -54,6 +55,8 @@ pub fn write_torch_ext_metal(

write_torch_registration_macros(&mut file_set)?;

write_metadata(Backend::Metal, &build.general, &mut file_set)?;

Ok(file_set)
}

Expand Down
4 changes: 3 additions & 1 deletion build2cmake/src/torch/noarch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use minijinja::{context, Environment};
use crate::{
config::{Backend, Build, General, Torch},
fileset::FileSet,
torch::kernel_ops_identifier,
torch::{common::write_metadata, kernel_ops_identifier},
};

pub fn write_torch_ext_noarch(
Expand All @@ -30,6 +30,8 @@ pub fn write_torch_ext_noarch(
&mut file_set,
)?;

write_metadata(backend, &build.general, &mut file_set)?;

Ok(file_set)
}

Expand Down
3 changes: 3 additions & 0 deletions build2cmake/src/torch/xpu.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use minijinja::{context, Environment};
use super::common::write_pyproject_toml;
use super::kernel_ops_identifier;
use crate::config::{Backend, Build, Dependency, Kernel, Torch};
use crate::torch::common::write_metadata;
use crate::version::Version;
use crate::FileSet;

Expand Down Expand Up @@ -53,6 +54,8 @@ pub fn write_torch_ext_xpu(

write_torch_registration_macros(&mut file_set)?;

write_metadata(Backend::Xpu, &build.general, &mut file_set)?;

Ok(file_set)
}

Expand Down
8 changes: 1 addition & 7 deletions lib/torch-extension/arch.nix
Original file line number Diff line number Diff line change
Expand Up @@ -83,12 +83,6 @@ let

moduleName = builtins.replaceStrings [ "-" ] [ "_" ] kernelName;

metadata = builtins.toJSON {
python-depends = pythonDeps;
};

metadataFile = writeText "metadata.json" metadata;

# On Darwin, we need the host's xcrun for `xcrun metal` to compile Metal shaders.
# It's not supported by the nixpkgs shim.
xcrunHost = writeScriptBin "xcrunHost" ''
Expand Down Expand Up @@ -255,7 +249,7 @@ stdenv.mkDerivation (prevAttrs: {
mkdir $out/${moduleName}
cp ${./compat.py} $out/${moduleName}/__init__.py

cp ${metadataFile} $out/metadata.json
cp ../metadata.json $out/
''
+ (lib.optionalString (stripRPath && stdenv.hostPlatform.isLinux)) ''
find $out/ -name '*.so' \
Expand Down
6 changes: 1 addition & 5 deletions lib/torch-extension/no-arch.nix
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,6 @@ let
++ resolveBackendPythonDeps buildConfig.backend backendPythonDeps
++ [ torch ];
moduleName = builtins.replaceStrings [ "-" ] [ "_" ] kernelName;
metadata = builtins.toJSON {
python-depends = pythonDeps;
};
metadataFile = writeText "metadata.json" metadata;
metalSupport = buildConfig.metal or false;
in

Expand Down Expand Up @@ -94,7 +90,7 @@ stdenv.mkDerivation (prevAttrs: {
cp -r torch-ext/${moduleName}/* $out/
mkdir $out/${moduleName}
cp ${./compat.py} $out/${moduleName}/__init__.py
cp ${metadataFile} $out/metadata.json
cp metadata.json $out/
'';

doInstallCheck = true;
Expand Down
Loading