-
Notifications
You must be signed in to change notification settings - Fork 4
Expand file tree
/
Copy pathsetup.py
More file actions
77 lines (63 loc) · 2.36 KB
/
setup.py
File metadata and controls
77 lines (63 loc) · 2.36 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import importlib.util
import os
import subprocess

from setuptools import find_packages, setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
# Populate __version__ by executing the package's version module directly,
# without importing the package itself. version.py is expected to assign
# __version__ (it is executed in this module's globals).
__version__ = None
# Resolve relative to this setup.py rather than the current working directory,
# and close the file handle deterministically.
_VERSION_FILE = os.path.join(
    os.path.dirname(os.path.abspath(__file__)), "editable_gauss_refl", "version.py"
)
with open(_VERSION_FILE, "r", encoding="utf-8") as _version_fh:
    exec(_version_fh.read())
def get_package_path(package_name):
    """
    Returns the filesystem path to the root of the given package,
    whether it's installed in editable mode or normally.

    For a package (a regular package, or a PEP 420 namespace package without
    an ``__init__.py``), this is the first of its submodule search locations.
    For a plain module, it is the directory containing the module's file.

    :param package_name: The name of the package to locate.
    :return: Absolute path to the package directory.
    :raises ImportError: If the package cannot be found, or if it has no
        directory on the filesystem (e.g. a built-in module).
    """
    spec = importlib.util.find_spec(package_name)
    if spec is None:
        raise ImportError(f"Cannot find package '{package_name}'")

    # Check for a package first. Namespace packages have spec.origin set to
    # None, so testing origin before this would wrongly reject them.
    search_locations = spec.submodule_search_locations
    if search_locations is not None:
        locations = list(search_locations)
        if not locations:
            raise ImportError(f"Cannot find package '{package_name}'")
        return os.path.abspath(locations[0])

    # Plain module: return the directory of the file it was loaded from.
    # Non-file origins such as "built-in" have no directory component.
    directory = os.path.dirname(spec.origin) if spec.origin else ""
    if not directory:
        raise ImportError(f"Cannot find package '{package_name}'")
    return os.path.abspath(directory)
# Custom build extension to build the OptiX tracing kernel
class CustomBuildExtension(BuildExtension):
    """``build_ext`` command that, after the regular extension build, also
    configures and builds the CMake project in ``editable_gauss_refl/cuda``.

    The CMake build directory is ``editable_gauss_refl/cuda/build``, inside the
    source tree next to this setup.py.
    """

    def build_extensions(self):
        """Build the setuptools extensions, then run ``cmake ..`` and ``make``.

        :raises ImportError: If the ``editable_gauss_refl`` package cannot be located.
        :raises subprocess.CalledProcessError: If ``cmake`` or ``make`` exits
            with a non-zero status.
        :raises FileNotFoundError: If ``cmake`` or ``make`` is not on ``PATH``.
        """
        # Run the original build_extensions
        super().build_extensions()

        # Build OptiX library
        pkg_source = os.path.dirname(os.path.abspath(__file__))
        # NOTE(review): pkg_target is only created here, not otherwise used;
        # presumably meant as the destination for the built library — confirm.
        pkg_target = get_package_path("editable_gauss_refl")
        os.makedirs(pkg_target, exist_ok=True)

        build_dir = os.path.join(pkg_source, "editable_gauss_refl", "cuda", "build")
        os.makedirs(build_dir, exist_ok=True)
        # Use argument lists (no shell) so paths with spaces or shell
        # metacharacters work. check=True makes a failed configure/compile
        # abort the install instead of being silently ignored.
        subprocess.run(["cmake", ".."], cwd=build_dir, check=True)
        subprocess.run(["make"], cwd=build_dir, check=True)
# Package metadata and build configuration. The C++/CUDA extension is compiled
# through CustomBuildExtension, which also runs the CMake build afterwards.
setup(
    name="editable_gauss_refl",
    version=__version__,
    # No leading whitespace: this string becomes the package's one-line summary.
    description="Python package for differentiable tracing of gaussians",
    keywords="gaussian, raytracing, cuda",
    python_requires=">=3.10",
    install_requires=[
        "ninja",
        "numpy<2.0.0",
        "torch",
    ],
    extras_require={
        "dev": [
            "clang-format",
            "pytest",
            "ruff",
        ],
    },
    ext_modules=[
        CUDAExtension(
            name="editable_gauss_refl._C",
            sources=["editable_gauss_refl/cuda/ext.cpp"],
            include_dirs=[],
        ),
    ],
    cmdclass={"build_ext": CustomBuildExtension},
    packages=find_packages(),
)