Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions MODULE.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ use_repo(
"cuda_nvtx",
"cuda_nvvm",
"cuda_profiler_api",
"cuda_nvptxcompiler",
)

##############################################################
Expand Down
40 changes: 40 additions & 0 deletions common/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,46 @@ config_setting(
flag_values = {":enable_cuda": "False"},
)

# Boolean build setting that opts in to linking the static CUDA libraries.
# It defaults to False, so the config_setting below does not match unless
# the flag is explicitly set to True.
bool_flag(
name = "link_cuda_static_libs",
build_setting_default = False,
)

# Matches only when :link_cuda_static_libs is True. Intended to be used as a
# select() key (e.g. "@rules_ml_toolchain//common:is_cuda_static_linking_enabled").
config_setting(
name = "is_cuda_static_linking_enabled",
flag_values = {
":link_cuda_static_libs": "True",
},
)


# Boolean build setting that opts in to linking the static NVRTC libraries.
# It defaults to False, so the config_setting below does not match unless
# the flag is explicitly set to True.
bool_flag(
name = "link_nvrtc_static_libs",
build_setting_default = False,
)

# Matches only when :link_nvrtc_static_libs is True. Intended to be used as a
# select() key (e.g. "@rules_ml_toolchain//common:is_nvrtc_static_linking_enabled").
config_setting(
name = "is_nvrtc_static_linking_enabled",
flag_values = {
":link_nvrtc_static_libs": "True",
},
)

# Boolean build setting that opts in to linking the static cuDNN libraries.
# It defaults to False, so the config_setting below does not match unless
# the flag is explicitly set to True.
bool_flag(
name = "link_cudnn_static_libs",
build_setting_default = False,
)

# Matches only when :link_cudnn_static_libs is True. Intended to be used as a
# select() key (e.g. "@rules_ml_toolchain//common:is_cudnn_static_linking_enabled").
config_setting(
name = "is_cudnn_static_linking_enabled",
flag_values = {
":link_cudnn_static_libs": "True",
},
)

#######################################################
# Enable SYCL support flags

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -224,17 +224,17 @@ def InvokeNvcc(argv, log=False):
# Unfortunately, there are other options that have -c prefix too.
# So allowing only those look like C/C++ files.
src_files = [f for f in src_files if
re.search(r'\.cpp$|\.cc$|\.c$|\.cxx$|\.C$|\.cu$', f)]
re.search(r'\.cpp$|\.cc$|\.c$|\.cxx$|\.C$|\.cu$', f)]
srcs = ' '.join(src_files)
out = ' -o ' + out_file[0]

nvccopts = '-D_FORCE_INLINES '
capabilities_sm = set(get_option_value(argv, '--cuda-gpu-arch')) - set(
get_option_value(argv, '--no-cuda-gpu-arch')
capabilities_sm = set(GetOptionValue(argv, '--cuda-gpu-arch')) - set(
GetOptionValue(argv, '--no-cuda-gpu-arch')
)
capabilities_compute = set(
get_option_value(argv, '--cuda-include-ptx')
) - set(get_option_value(argv, '--no-cuda-include-ptx'))
GetOptionValue(argv, '--cuda-include-ptx')
) - set(GetOptionValue(argv, '--no-cuda-include-ptx'))
# When both "code=sm_xy" and "code=compute_xy" are requested for a single
# arch, they can be combined using "code=xy,compute_xy" which avoids a
# redundant PTX generation during compilation.
Expand All @@ -257,6 +257,7 @@ def InvokeNvcc(argv, log=False):
nvccopts += std_options
nvccopts += m_options
nvccopts += warning_options
# nvccopts += ' -rdc=true '
# Force C++17 dialect (note, everything in just one string!)
nvccopts += ' --std c++17 '
nvccopts += fatbin_options
Expand Down
39 changes: 38 additions & 1 deletion third_party/gpus/cuda/build_defs.bzl.tpl
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,47 @@ def if_cuda(if_true, if_false = []):
with CUDA enabled. Otherwise, the select statement evaluates to if_false.
"""
return select({
"@local_config_cuda//:is_cuda_enabled": if_true,
"@rules_ml_toolchain//common:is_cuda_enabled": if_true,
"//conditions:default": if_false,
})

# Macros for building CUDA static code.
def if_static_cuda(if_true, if_false = []):
    """Chooses a value depending on whether static CUDA linking is enabled.

    Args:
      if_true: value used when
        @rules_ml_toolchain//common:is_cuda_static_linking_enabled matches.
      if_false: value used in every other configuration (default: empty list).

    Returns:
      A select() that resolves to if_true or if_false.
    """
    static_cuda_setting = "@rules_ml_toolchain//common:is_cuda_static_linking_enabled"
    branches = {
        static_cuda_setting: if_true,
        "//conditions:default": if_false,
    }
    return select(branches)

# Macros for building NVRTC static code.
def if_static_nvrtc(if_true, if_false = []):
    """Chooses a value depending on whether static NVRTC linking is enabled.

    Args:
      if_true: value used when
        @rules_ml_toolchain//common:is_nvrtc_static_linking_enabled matches.
      if_false: value used in every other configuration (default: empty list).

    Returns:
      A select() that resolves to if_true or if_false.
    """
    static_nvrtc_setting = "@rules_ml_toolchain//common:is_nvrtc_static_linking_enabled"
    branches = {
        static_nvrtc_setting: if_true,
        "//conditions:default": if_false,
    }
    return select(branches)

# Macros for building CUDNN static code.
def if_static_cudnn(if_true, if_false = []):
    """Chooses a value depending on whether static cuDNN linking is enabled.

    Args:
      if_true: value used when
        @rules_ml_toolchain//common:is_cudnn_static_linking_enabled matches.
      if_false: value used in every other configuration (default: empty list).

    Returns:
      A select() that resolves to if_true or if_false.
    """
    static_cudnn_setting = "@rules_ml_toolchain//common:is_cudnn_static_linking_enabled"
    branches = {
        static_cudnn_setting: if_true,
        "//conditions:default": if_false,
    }
    return select(branches)


def if_cuda_clang(if_true, if_false = []):
"""Shorthand for select()'ing on wheteher we're building with cuda-clang.

Expand Down
5 changes: 2 additions & 3 deletions third_party/gpus/cuda/hermetic/BUILD.tpl
Original file line number Diff line number Diff line change
Expand Up @@ -326,10 +326,9 @@ selects.config_setting_group(
],
)

cc_library(
# This is not yet fully supported, but we need the rule
# to make bazel query happy.
alias(
name = "nvptxcompiler",
actual = "@cuda_nvcc//:nvptxcompiler",
)

alias(
Expand Down
2 changes: 2 additions & 0 deletions third_party/gpus/cuda/hermetic/cuda_configure.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ load("@cuda_nvml//:version.bzl", _nvml_version = "VERSION")
load("@cuda_nvtx//:version.bzl", _nvtx_version = "VERSION")
load("@cuda_nvvm//:version.bzl", _nvvm_version = "VERSION")
load("@cuda_profiler_api//:version.bzl", _cuda_profiler_api_version = "VERSION")
load("@cuda_nvptxcompiler//:version.bzl", _cuda_nvptxcompiler_version = "VERSION")
load("@llvm_linux_aarch64//:version.bzl", _llvm_aarch64_hermetic_version = "VERSION")
load("@llvm_linux_x86_64//:version.bzl", _llvm_x86_64_hermetic_version = "VERSION")
load(
Expand Down Expand Up @@ -366,6 +367,7 @@ def _get_cuda_config(repository_ctx):
cupti_version = _cupti_version,
cudart_version = _cudart_version,
cuda_profiler_api_version = _cuda_profiler_api_version,
cuda_nvptxcompiler_version = _cuda_nvptxcompiler_version,
cublas_version = _cublas_version,
cusolver_version = _cusolver_version,
curand_version = _curand_version,
Expand Down
24 changes: 21 additions & 3 deletions third_party/gpus/cuda/hermetic/cuda_cublas.BUILD.tpl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ licenses(["restricted"]) # NVIDIA proprietary license
load(
"@local_config_cuda//cuda:build_defs.bzl",
"if_cuda_newer_than",
"if_static_cuda",
)
load(
"@rules_ml_toolchain//third_party/gpus:nvidia_common_rules.bzl",
Expand All @@ -13,19 +14,33 @@ cc_import(
name = "cublas_shared_library",
hdrs = [":headers"],
shared_library = "lib/libcublas.so.%{libcublas_version}",
deps = [":cublasLt"],
)

cc_import(
name = "cublasLt_shared_library",
hdrs = [":headers"],
shared_library = "lib/libcublasLt.so.%{libcublaslt_version}",
)

cc_import(
name = "cublasLt_static_library",
hdrs = [":headers"],
static_library = "lib/libcublasLt_static.a",
)

cc_import(
name = "cublas_static_library",
hdrs = [":headers"],
static_library = "lib/libcublas_static.a",
)
%{multiline_comment}
cc_library(
name = "cublas",
visibility = ["//visibility:public"],
%{comment}deps = [":cublas_shared_library"],
%{comment}deps = if_static_cuda(
%{comment}[":cublas_static_library"],
%{comment}[":cublas_shared_library"],
%{comment}) + [":cublasLt"],
%{comment}linkopts = if_cuda_newer_than(
%{comment}"13_0",
%{comment}if_true = cuda_rpath_flags("nvidia/cu13/lib"),
Expand All @@ -36,7 +51,10 @@ cc_library(
cc_library(
name = "cublasLt",
visibility = ["//visibility:public"],
%{comment}deps = [":cublasLt_shared_library"],
%{comment}deps = if_static_cuda(
%{comment}[":cublasLt_static_library"],
%{comment}[":cublasLt_shared_library"],
%{comment}),
%{comment}linkopts = if_cuda_newer_than(
%{comment}"13_0",
%{comment}if_true = cuda_rpath_flags("nvidia/cu13/lib"),
Expand Down
26 changes: 23 additions & 3 deletions third_party/gpus/cuda/hermetic/cuda_cudart.BUILD.tpl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ licenses(["restricted"]) # NVIDIA proprietary license
load(
"@local_config_cuda//cuda:build_defs.bzl",
"if_cuda_newer_than",
"if_static_cuda",
)
load(
"@rules_ml_toolchain//third_party/gpus:nvidia_common_rules.bzl",
Expand Down Expand Up @@ -29,6 +30,24 @@ cc_import(
hdrs = [":headers"],
shared_library = "lib/libcudart.so.%{libcudart_version}",
)

cc_import(
name = "cudart_static_library",
hdrs = [":headers"],
static_library = "lib/libcudart_static.a",
)

cc_import(
name = "culibos_static_library",
hdrs = [":headers"],
static_library = if_cuda_newer_than("13_0", None, "lib/libculibos.a"),
)

cc_import(
name = "cudadevrt_static_library",
hdrs = [":headers"],
static_library = "lib/libcudadevrt.a",
)
%{multiline_comment}
cc_library(
name = "cuda_driver",
Expand All @@ -44,9 +63,10 @@ cc_library(
%{comment}"@cuda_driver//:nvidia_ptxjitcompiler",
%{comment}],
%{comment}"//conditions:default": [":cuda_driver"],
%{comment}}) + [
%{comment}":cudart_shared_library",
%{comment}],
%{comment}}) + if_static_cuda(
%{comment}[":cudart_static_library", ":cudadevrt_static_library"] + if_cuda_newer_than("13_0", ["@cuda_culibos//:culibos_static_library"], [":culibos_static_library"]),
%{comment}[":cudart_shared_library"],
%{comment}),
%{comment}linkopts = if_cuda_newer_than(
%{comment}"13_0",
%{comment}if_true = cuda_rpath_flags("nvidia/cu13/lib"),
Expand Down
64 changes: 60 additions & 4 deletions third_party/gpus/cuda/hermetic/cuda_cudnn9.BUILD.tpl
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
licenses(["restricted"]) # NVIDIA proprietary license
load(
"@local_config_cuda//cuda:build_defs.bzl",
"if_static_cudnn",
)
load(
"@rules_ml_toolchain//third_party/gpus:nvidia_common_rules.bzl",
"cuda_rpath_flags",
Expand Down Expand Up @@ -52,20 +56,72 @@ cc_import(
hdrs = [":headers"],
shared_library = "lib/libcudnn.so.%{libcudnn_version}",
)

cc_import(
name = "cudnn_graph_static",
hdrs = [":headers"],
static_library = "lib/libcudnn_graph_static_v9.a",
)

cc_import(
name = "cudnn_adv_static",
hdrs = [":headers"],
static_library = "lib/libcudnn_adv_static_v9.a",
)

cc_import(
name = "cudnn_engines_runtime_compiled_static",
hdrs = [":headers"],
static_library = "lib/libcudnn_engines_runtime_compiled_static_v9.a",
)

cc_import(
name = "cudnn_engines_precompiled_static",
hdrs = [":headers"],
static_library = "lib/libcudnn_engines_precompiled_static_v9.a",
)

cc_import(
name = "cudnn_ops_static",
hdrs = [":headers"],
static_library = "lib/libcudnn_ops_static_v9.a",
)

cc_import(
name = "cudnn_heuristic_static",
hdrs = [":headers"],
static_library = "lib/libcudnn_heuristic_static_v9.a",
)

cc_import(
name = "cudnn_cnn_static",
hdrs = [":headers"],
static_library = "lib/libcudnn_cnn_static_v9.a",
)
%{multiline_comment}
cc_library(
name = "cudnn",
%{comment}deps = [
%{comment}":cudnn_engines_precompiled",
%{comment}alwayslink = if_static_cudnn(True, False),
%{comment}srcs = if_static_cudnn(
%{comment}[":lib/libcudnn_engines_precompiled_static_v9.a",
%{comment} ":lib/libcudnn_ops_static_v9.a",
%{comment} ":lib/libcudnn_cnn_static_v9.a",
%{comment} ":lib/libcudnn_adv_static_v9.a",
%{comment} ":lib/libcudnn_heuristic_static_v9.a",
%{comment} ":lib/libcudnn_graph_static_v9.a",
%{comment} ":lib/libcudnn_engines_runtime_compiled_static_v9.a",
%{comment}], []),
%{comment}deps = if_static_cudnn(
%{comment}[],
%{comment}[":cudnn_engines_precompiled",
%{comment}":cudnn_ops",
%{comment}":cudnn_graph",
%{comment}":cudnn_cnn",
%{comment}":cudnn_adv",
%{comment}":cudnn_engines_runtime_compiled",
%{comment}":cudnn_heuristic",
%{comment}"@cuda_nvrtc//:nvrtc",
%{comment}":cudnn_main",
%{comment}],
%{comment}]) + ["@cuda_nvrtc//:nvrtc"],
%{comment}linkopts = cuda_rpath_flags("nvidia/cudnn/lib"),
visibility = ["//visibility:public"],
)
Expand Down
23 changes: 21 additions & 2 deletions third_party/gpus/cuda/hermetic/cuda_cufft.BUILD.tpl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ licenses(["restricted"]) # NVIDIA proprietary license
load(
"@local_config_cuda//cuda:build_defs.bzl",
"if_cuda_newer_than",
"if_static_cuda",
)
load(
"@rules_ml_toolchain//third_party/gpus:nvidia_common_rules.bzl",
Expand All @@ -14,15 +15,33 @@ cc_import(
hdrs = [":headers"],
shared_library = "lib/libcufft.so.%{libcufft_version}",
)

cc_import(
name = "cufft_static_library",
hdrs = [":headers"],
static_library = "lib/libcufft_static.a",
)

cc_import(
name = "cufftw_static_library",
hdrs = [":headers"],
static_library = "lib/libcufftw_static.a",
)

cc_import(
name = "cufft_static_nocallback_library",
hdrs = [":headers"],
static_library = if_cuda_newer_than("13_0", None, "lib/libcufft_static_nocallback.a"),
)
%{multiline_comment}
cc_library(
name = "cufft",
%{comment}deps = [":cufft_shared_library"],
%{comment}deps = if_static_cuda(if_cuda_newer_than("13_0", [":cufft_static_library"], [":cufft_static_nocallback_library"]) + [":cufftw_static_library"], [":cufft_shared_library"]),
%{comment}linkopts = if_cuda_newer_than(
%{comment}"13_0",
%{comment}if_true = cuda_rpath_flags("nvidia/cu13/lib"),
%{comment}if_false = cuda_rpath_flags("nvidia/cufft/lib"),
%{comment}),
%{comment}) + if_static_cuda(["-Wl,--no-relax"]),
visibility = ["//visibility:public"],
)

Expand Down
Loading