Skip to content

Commit 3b803f1

Browse files
committed
Attempt to get proper, recent version of CUDA_full_jll for each CUDA version
1 parent dff04c0 commit 3b803f1

File tree

1 file changed

+16
-2
lines changed

1 file changed

+16
-2
lines changed

T/Torch/TorchCUDA/build_tarballs.jl

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,11 +57,22 @@ products = [
5757
]
5858

5959
dependencies = [
60-
BuildDependency("CUDA_full_jll"),
6160
Dependency("CUDNN_jll")
6261
]
6362

6463
cuda_versions = [v"10.2", v"11.0", v"11.1", v"11.2", v"11.3", v"11.4", v"11.5", v"11.6"]
64+
65+
cuda_full_version = Dict{VersionNumber,VersionNumber}(
66+
v"10.2" => v"10.2.89",
67+
v"11.0" => v"11.0.3",
68+
v"11.1" => v"11.1.1",
69+
v"11.2" => v"11.2.2",
70+
v"11.3" => v"11.3.1",
71+
v"11.4" => v"11.4.2",
72+
v"11.5" => v"11.5.1",
73+
v"11.6" => v"11.6.0"
74+
)
75+
6576
for cuda_version in cuda_versions
6677
cuda_tag = "$(cuda_version.major).$(cuda_version.minor)"
6778
if cuda_version.major == 10
@@ -73,8 +84,11 @@ for cuda_version in cuda_versions
7384
for platform in platforms
7485
augmented_platform = Platform(arch(platform), os(platform); cuda=cuda_tag)
7586
should_build_platform(triplet(augmented_platform)) || continue
87+
platform_dependencies = vcat(dependencies, [
88+
HostBuildDependency(PackageSpec("CUDA_full_jll", Base.UUID("4f82f1eb-248c-5f56-a42e-99106d144614"), cuda_full_version[cuda_version]))
89+
])
7690
build_tarballs(ARGS, name, version, sources, script, [augmented_platform],
77-
products, dependencies; lazy_artifacts=true,
91+
products, platform_dependencies; lazy_artifacts=true,
7892
preferred_gcc_version = v"7.1.0")
7993
end
8094
end

0 commit comments

Comments
 (0)