Skip to content

Commit a6dd185

Browse files
authored
feat: various fixups for nicer JETLS interaction (#1630)
* feat: improve 1.12 support via workspaces * fix: drop argcheck (for JETLS) * fix: more JETLS fixups * ci: add UUIDs to skip * fix: incorrect assert usage * fix: strict test * fix: multiple luxlib fixes * test: fixes
1 parent c1a72f6 commit a6dd185

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

55 files changed

+176
-145
lines changed

.JuliaLint.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
testitem-errors = false

.github/workflows/CommonCI.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ jobs:
6969
mode: forcedeps
7070
projects: ${{ inputs.project }}
7171
julia_version: ${{ inputs.julia_version }}
72-
skip: Pkg, TOML, Statistics, LinearAlgebra, Random, Serialization, Markdown, Test, LuxCore, LuxLib, LuxTestUtils, MLDataDevices, WeightInitializers
72+
skip: Pkg, TOML, Statistics, LinearAlgebra, Random, Serialization, Markdown, Test, LuxCore, LuxLib, LuxTestUtils, MLDataDevices, WeightInitializers, UUIDs
7373

7474
# For 1.10 we need to manually develop the packages.
7575
- name: "Develop Dependencies"

Project.toml

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
11
name = "Lux"
22
uuid = "b2108857-7c20-44ae-9111-449ecde12c47"
33
authors = ["Avik Pal <avikpal@mit.edu> and contributors"]
4-
version = "1.29.2"
4+
version = "1.29.3"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
88
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
9-
ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"
109
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
1110
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
1211
ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
@@ -35,6 +34,7 @@ Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
3534
Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3"
3635
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
3736
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
37+
UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
3838
WeightInitializers = "d49dbf32-c5c2-4618-8acc-27bb2598ef2d"
3939

4040
[weakdeps]
@@ -90,7 +90,6 @@ ZygoteExt = "Zygote"
9090
[compat]
9191
ADTypes = "1.15"
9292
Adapt = "4.1"
93-
ArgCheck = "2.3"
9493
ArrayInterface = "7.17.1"
9594
CUDA = "5.8"
9695
ChainRulesCore = "1.25.1"
@@ -109,7 +108,7 @@ GPUArrays = "11"
109108
GPUArraysCore = "0.2"
110109
LinearAlgebra = "1.10"
111110
LossFunctions = "0.11.1, 1"
112-
LuxCore = "1.5"
111+
LuxCore = "1.5.1"
113112
LuxLib = "1.15"
114113
MLDataDevices = "1.17"
115114
MLUtils = "0.4.4"
@@ -134,6 +133,10 @@ Static = "1.1.1"
134133
StaticArraysCore = "1.4.3"
135134
Statistics = "1.10"
136135
Tracker = "0.2.37"
136+
UUIDs = "1.10"
137137
WeightInitializers = "1.3"
138138
Zygote = "0.7"
139139
julia = "1.10"
140+
141+
[workspaces]
142+
projects = ["test", "docs"]

ext/FluxExt.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
module FluxExt
22

3-
using ArgCheck: @argcheck
43
using Flux: Flux
54

65
using Lux: Lux, FluxModelConversionException, LuxOps
@@ -284,7 +283,7 @@ function Lux.convert_flux_model(
284283
l::Flux.GroupNorm; preserve_ps_st::Bool=false, force_preserve::Bool=false
285284
)
286285
if preserve_ps_st
287-
@argcheck !l.track_stats
286+
@assert !l.track_stats
288287
if l.affine
289288
return Lux.GroupNorm(
290289
l.chs,

ext/MPINCCLExt.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
module MPINCCLExt
22

3-
using ArgCheck: @argcheck
43
using MPI: MPI
54
using NCCL: NCCL
65
using Setfield: @set!
@@ -13,7 +12,7 @@ Lux.is_extension_loaded(::Val{:MPINCCL}) = true
1312
function DistributedUtils.force_initialize(
1413
::Type{NCCLBackend}; cuda_devices=nothing, amdgpu_devices=missing
1514
)
16-
@argcheck amdgpu_devices === missing "`AMDGPU` is not supported by `NCCL`."
15+
@assert amdgpu_devices === missing "`AMDGPU` is not supported by `NCCL`."
1716
DistributedUtils.force_initialize(
1817
MPIBackend; cuda_devices, force_cuda=true, caller="NCCLBackend", amdgpu_devices
1918
)

ext/SimpleChainsExt.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
module SimpleChainsExt
22

3-
using ArgCheck: @argcheck
43
using SimpleChains: SimpleChains
54
using Random: AbstractRNG
65

@@ -99,7 +98,7 @@ function NNlib.logsoftmax!(
9998
x::Union{SimpleChains.StrideArray{T2,2},SimpleChains.PtrArray{T2,2}};
10099
dims=1,
101100
) where {T1,T2}
102-
@argcheck dims == 1
101+
@assert dims == 1
103102
m = similar(y, SimpleChains.static_size(y, 2))
104103
SimpleChains.logsoftmax!(y, m, x)
105104
return y

ext/ZygoteExt/ZygoteExt.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
module ZygoteExt
22

3-
using ArgCheck: @argcheck
43
using ADTypes: AutoZygote
54
using ChainRulesCore: ChainRulesCore
65
using ForwardDiff: ForwardDiff

ext/ZygoteExt/batched_autodiff.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ function Lux.AutoDiffInternalImpl.batched_jacobian_impl(f::F, ::AutoZygote, x) w
33
# construct the Jacobian
44
y, pb_f = Zygote.pullback(f, x)
55

6-
@argcheck y isa AbstractArray MethodError
6+
@assert y isa AbstractArray "Expected output to be an AbstractArray, got $(typeof(y))"
77
if ndims(y) 1 || size(y, ndims(y)) != size(x, ndims(x))
88
throw(AssertionError("`batched_jacobian` only supports batched outputs \
99
(ndims(y) > 1) && size(y, ndims(y)) == size(x, ndims(x))."))

lib/LuxCUDA/Project.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,3 +13,6 @@ CUDA = "5.8"
1313
Reexport = "1"
1414
cuDNN = "1.4.3"
1515
julia = "1.10"
16+
17+
[workspaces]
18+
projects = ["test"]

lib/LuxCUDA/test/Project.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
11
[deps]
22
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
3+
LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda"
34
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
45

56
[compat]
67
Aqua = "0.8.4"
78
Test = "1.10"
9+
10+
[sources]
11+
LuxCUDA = {path = ".."}

0 commit comments

Comments
 (0)