-
Notifications
You must be signed in to change notification settings - Fork 667
Expand file tree
/
Copy pathjax.py
More file actions
121 lines (95 loc) · 3.37 KB
/
jax.py
File metadata and controls
121 lines (95 loc) · 3.37 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""JAX related extensions."""
import os
from pathlib import Path
from packaging import version
import setuptools
from .utils import (
get_cuda_include_dirs,
all_files_in_dir,
debug_build_enabled,
get_cuda_library_dirs,
)
from typing import List
def install_requirements() -> List[str]:
    """Return the runtime packages the TE/JAX extension depends on.

    Returns:
        A newly built list of pip requirement specifiers: JAX itself and
        Flax at version 0.7.1 or newer.
    """
    jax_requirement = "jax"
    flax_requirement = "flax>=0.7.1"
    return [jax_requirement, flax_requirement]
def test_requirements() -> List[str]:
    """Test dependencies for TE/JAX extensions.

    Triton Package Selection:
        The triton package is selected based on the NVTE_USE_PYTORCH_TRITON
        environment variable, which is parsed as an integer flag:

        Default (NVTE_USE_PYTORCH_TRITON unset, empty, or "0"):
            Returns 'triton' - OpenAI's standard package from PyPI.
            Install with: pip install triton

        NVTE_USE_PYTORCH_TRITON=1 (any non-zero integer):
            Returns 'pytorch-triton' - for mixed JAX+PyTorch environments.
            Install with: pip install pytorch-triton --index-url https://download.pytorch.org/whl/cu121
            Note: Do NOT install pytorch-triton from PyPI directly - that's a placeholder.

    Returns:
        Pip requirement names: numpy plus the selected triton package.

    Raises:
        ValueError: If NVTE_USE_PYTORCH_TRITON is set to a non-integer value.
    """
    raw_flag = os.environ.get("NVTE_USE_PYTORCH_TRITON", "0").strip()
    # An exported-but-empty variable (``NVTE_USE_PYTORCH_TRITON=``) means
    # "unset"; int("") would otherwise abort the build with an opaque error.
    try:
        use_pytorch_triton = bool(int(raw_flag or "0"))
    except ValueError as err:
        raise ValueError(
            f"NVTE_USE_PYTORCH_TRITON must be an integer (0 or 1), got {raw_flag!r}"
        ) from err
    triton_package = "pytorch-triton" if use_pytorch_triton else "triton"
    return ["numpy", triton_package]
def xla_path() -> str:
    """XLA root path lookup.

    The XLA FFI include directory reported by the installed JAX is preferred
    (``jax.ffi`` for JAX >= 0.5.0, ``jax.extend.ffi`` before that). If JAX or
    its FFI module cannot be imported, fall back to the ``XLA_HOME``
    environment variable when it is set and non-empty, else ``/opt/xla``.

    Returns:
        The resolved directory, always as a ``str``.

    Raises:
        FileNotFoundError: If the resolved directory does not exist.
    """
    try:
        import jax

        if version.parse(jax.__version__) >= version.parse("0.5.0"):
            from jax import ffi  # pylint: disable=ungrouped-imports
        else:
            from jax.extend import ffi  # pylint: disable=ungrouped-imports
    except ImportError:
        # An empty XLA_HOME counts as unset (truthiness check).
        xla_home = os.getenv("XLA_HOME") or "/opt/xla"
    else:
        xla_home = ffi.include_dir()

    # Normalize to str so every branch honours the ``-> str`` annotation;
    # os.fspath also accepts a Path-like value.
    xla_home = os.fspath(xla_home)
    if not os.path.isdir(xla_home):
        raise FileNotFoundError(f"Could not find xla source in {xla_home}.")
    return xla_home
def setup_jax_extension(
    csrc_source_files,
    csrc_header_files,
    common_header_files,
) -> setuptools.Extension:
    """Setup PyBind11 extension for JAX support.

    Args:
        csrc_source_files: Directory of TE/JAX C++ sources. The files selected
            by ``all_files_in_dir(<dir>/extensions, name_extension="cpp")``
            are compiled.
        csrc_header_files: Directory added to the include path for the
            TE/JAX C++ headers.
        common_header_files: Root of the shared TE headers. It and its
            ``common`` and ``common/include`` subdirectories are added to the
            include path.

        All three accept ``str`` or any ``os.PathLike``.

    Returns:
        A ``Pybind11Extension`` named ``transformer_engine_jax`` that links
        against ``nccl``.
    """
    # Normalize every path argument so plain strings work with the ``/``
    # operator used below.
    csrc_source_files = Path(csrc_source_files)
    csrc_header_files = Path(csrc_header_files)
    common_header_files = Path(common_header_files)

    # Source files
    extensions_dir = csrc_source_files / "extensions"
    sources = all_files_in_dir(extensions_dir, name_extension="cpp")

    # Header files. Copy the helper's result before extending it so a list
    # the helper may hand out again (e.g. if it is cached) is never mutated.
    include_dirs = list(get_cuda_include_dirs())
    include_dirs.extend(
        [
            common_header_files,
            common_header_files / "common",
            common_header_files / "common" / "include",
            csrc_header_files,
            xla_path(),
        ]
    )

    # Library dirs
    library_dirs = get_cuda_library_dirs()

    # Compile flags: debug builds keep symbols and undefine NDEBUG so that
    # assert() stays active; release builds drop debug info.
    cxx_flags = ["-O3"]
    if debug_build_enabled():
        cxx_flags.extend(["-g", "-UNDEBUG"])
    else:
        cxx_flags.append("-g0")

    # Define TE/JAX as a Pybind11Extension. Imported here so pybind11 is only
    # needed when this function is actually called.
    from pybind11.setup_helpers import Pybind11Extension

    return Pybind11Extension(
        "transformer_engine_jax",
        sources=[str(path) for path in sources],
        include_dirs=[str(path) for path in include_dirs],
        extra_compile_args=cxx_flags,
        libraries=["nccl"],
        library_dirs=[str(path) for path in library_dirs],
    )