Skip to content

Commit 01bdd55

Browse files
committed
Mooncake feature has released
1 parent b84e56c commit 01bdd55

File tree

5 files changed

+88
-81
lines changed

5 files changed

+88
-81
lines changed

.github/workflows/Test.yml

Lines changed: 75 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -28,29 +28,29 @@ jobs:
2828
fail-fast: true # TODO: toggle
2929
matrix:
3030
version:
31-
# - '1.10'
31+
- '1.10'
3232
- '1.11'
3333
- '1.12'
3434
group:
3535
- Core/Internals
36-
# - Back/DifferentiateWith
37-
# - Core/SimpleFiniteDiff
38-
# - Back/SparsityDetector
39-
# - Core/ZeroBackends
40-
# - Back/ChainRules
36+
- Back/DifferentiateWith
37+
- Core/SimpleFiniteDiff
38+
- Back/SparsityDetector
39+
- Core/ZeroBackends
40+
- Back/ChainRules
4141
# - Back/Diffractor
42-
# - Back/Enzyme
43-
# - Back/FastDifferentiation
44-
# - Back/FiniteDiff
45-
# - Back/FiniteDifferences
46-
# - Back/ForwardDiff
47-
# - Back/GTPSA
42+
- Back/Enzyme
43+
- Back/FastDifferentiation
44+
- Back/FiniteDiff
45+
- Back/FiniteDifferences
46+
- Back/ForwardDiff
47+
- Back/GTPSA
4848
- Back/Mooncake
49-
# - Back/PolyesterForwardDiff
50-
# - Back/ReverseDiff
51-
# - Back/Symbolics
52-
# - Back/Tracker
53-
# - Back/Zygote
49+
- Back/PolyesterForwardDiff
50+
- Back/ReverseDiff
51+
- Back/Symbolics
52+
- Back/Tracker
53+
- Back/Zygote
5454
skip_lts:
5555
- ${{ github.event.pull_request.draft }}
5656
skip_pre:
@@ -104,61 +104,61 @@ jobs:
104104
token: ${{ secrets.CODECOV_TOKEN }}
105105
fail_ci_if_error: false
106106

107-
# test-DIT:
108-
# name: ${{ matrix.version }} - DIT (${{ matrix.group }})
109-
# runs-on: ubuntu-latest
110-
# if: ${{ !contains(github.event.pull_request.labels.*.name, 'skipci') }}
111-
# timeout-minutes: 60
112-
# permissions: # needed to allow julia-actions/cache to proactively delete old caches that it has created
113-
# actions: write
114-
# contents: read
115-
# strategy:
116-
# fail-fast: true
117-
# matrix:
118-
# version:
119-
# - '1.10'
120-
# - '1.11'
121-
# - '1.12'
122-
# group:
123-
# - Formalities
124-
# - Zero
125-
# - Standard
126-
# - Weird
127-
# skip_lts:
128-
# - ${{ github.event.pull_request.draft }}
129-
# skip_pre:
130-
# - ${{ github.event.pull_request.draft }}
131-
# exclude:
132-
# - skip_lts: true
133-
# version: '1.10'
134-
# - skip_pre: true
135-
# version: '1.12'
136-
# env:
137-
# JULIA_DIT_TEST_GROUP: ${{ matrix.group }}
138-
# JULIA_DI_PR_DRAFT: ${{ github.event.pull_request.draft }}
139-
# steps:
140-
# - uses: actions/checkout@v5
141-
# - uses: julia-actions/setup-julia@v2
142-
# with:
143-
# version: ${{ matrix.version }}
144-
# arch: x64
145-
# - uses: julia-actions/cache@v2
146-
# - name: Install dependencies & run tests
147-
# run: julia --project=./DifferentiationInterfaceTest --color=yes -e '
148-
# using Pkg;
149-
# Pkg.Registry.update();
150-
# Pkg.develop(path="./DifferentiationInterface");
151-
# if ENV["JULIA_DI_PR_DRAFT"] == "true";
152-
# Pkg.test("DifferentiationInterfaceTest"; allow_reresolve=false, coverage=true, julia_args=["-O1"]);
153-
# else;
154-
# Pkg.test("DifferentiationInterfaceTest"; allow_reresolve=false, coverage=true);
155-
# end;'
156-
# - uses: julia-actions/julia-processcoverage@v1
157-
# with:
158-
# directories: ./DifferentiationInterfaceTest/src,./DifferentiationInterfaceTest/ext,./DifferentiationInterfaceTest/test
159-
# - uses: codecov/codecov-action@v5
160-
# with:
161-
# files: lcov.info
162-
# flags: DIT
163-
# token: ${{ secrets.CODECOV_TOKEN }}
164-
# fail_ci_if_error: false
107+
test-DIT:
108+
name: ${{ matrix.version }} - DIT (${{ matrix.group }})
109+
runs-on: ubuntu-latest
110+
if: ${{ !contains(github.event.pull_request.labels.*.name, 'skipci') }}
111+
timeout-minutes: 60
112+
permissions: # needed to allow julia-actions/cache to proactively delete old caches that it has created
113+
actions: write
114+
contents: read
115+
strategy:
116+
fail-fast: true
117+
matrix:
118+
version:
119+
- '1.10'
120+
- '1.11'
121+
- '1.12'
122+
group:
123+
- Formalities
124+
- Zero
125+
- Standard
126+
- Weird
127+
skip_lts:
128+
- ${{ github.event.pull_request.draft }}
129+
skip_pre:
130+
- ${{ github.event.pull_request.draft }}
131+
exclude:
132+
- skip_lts: true
133+
version: '1.10'
134+
- skip_pre: true
135+
version: '1.12'
136+
env:
137+
JULIA_DIT_TEST_GROUP: ${{ matrix.group }}
138+
JULIA_DI_PR_DRAFT: ${{ github.event.pull_request.draft }}
139+
steps:
140+
- uses: actions/checkout@v5
141+
- uses: julia-actions/setup-julia@v2
142+
with:
143+
version: ${{ matrix.version }}
144+
arch: x64
145+
- uses: julia-actions/cache@v2
146+
- name: Install dependencies & run tests
147+
run: julia --project=./DifferentiationInterfaceTest --color=yes -e '
148+
using Pkg;
149+
Pkg.Registry.update();
150+
Pkg.develop(path="./DifferentiationInterface");
151+
if ENV["JULIA_DI_PR_DRAFT"] == "true";
152+
Pkg.test("DifferentiationInterfaceTest"; allow_reresolve=false, coverage=true, julia_args=["-O1"]);
153+
else;
154+
Pkg.test("DifferentiationInterfaceTest"; allow_reresolve=false, coverage=true);
155+
end;'
156+
- uses: julia-actions/julia-processcoverage@v1
157+
with:
158+
directories: ./DifferentiationInterfaceTest/src,./DifferentiationInterfaceTest/ext,./DifferentiationInterfaceTest/test
159+
- uses: codecov/codecov-action@v5
160+
with:
161+
files: lcov.info
162+
flags: DIT
163+
token: ${{ secrets.CODECOV_TOKEN }}
164+
fail_ci_if_error: false

DifferentiationInterface/Project.toml

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,11 @@ DifferentiationInterfaceForwardDiffExt = ["ForwardDiff", "DiffResults"]
4141
DifferentiationInterfaceGPUArraysCoreExt = "GPUArraysCore"
4242
DifferentiationInterfaceGTPSAExt = "GTPSA"
4343
DifferentiationInterfaceMooncakeExt = "Mooncake"
44-
DifferentiationInterfacePolyesterForwardDiffExt = ["PolyesterForwardDiff", "ForwardDiff", "DiffResults"]
44+
DifferentiationInterfacePolyesterForwardDiffExt = [
45+
"PolyesterForwardDiff",
46+
"ForwardDiff",
47+
"DiffResults",
48+
]
4549
DifferentiationInterfaceReverseDiffExt = ["ReverseDiff", "DiffResults"]
4650
DifferentiationInterfaceSparseArraysExt = "SparseArrays"
4751
DifferentiationInterfaceSparseConnectivityTracerExt = "SparseConnectivityTracer"
@@ -65,7 +69,7 @@ ForwardDiff = "0.10.36,1"
6569
GPUArraysCore = "0.2"
6670
GTPSA = "1.4.0"
6771
LinearAlgebra = "1"
68-
Mooncake = "0.4.147"
72+
Mooncake = "0.4.175"
6973
PolyesterForwardDiff = "0.1.2"
7074
ReverseDiff = "1.15.1"
7175
SparseArrays = "1"

DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/onearg.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,11 @@ function DI.prepare_pullback_nokwarg(
1717
)
1818
y = f(x, map(DI.unwrap, contexts)...)
1919
dy_righttype = zero_tangent(y)
20+
contexts_tup_false = map(_ -> false, contexts)
2021
args_to_zero = (
2122
false, # f
2223
true, # x
23-
map(_ -> false, contexts)...,
24+
contexts_tup_false...,
2425
)
2526
prep = MooncakeOneArgPullbackPrep(_sig, cache, dy_righttype, args_to_zero)
2627
return prep
@@ -123,10 +124,11 @@ function DI.prepare_gradient_nokwarg(
123124
cache = prepare_gradient_cache(
124125
f, x, map(DI.unwrap, contexts)...; config.debug_mode, config.silence_debug_messages
125126
)
127+
contexts_tup_false = map(_ -> false, contexts)
126128
args_to_zero = (
127129
false, # f
128130
true, # x
129-
map(_ -> false, contexts)...,
131+
contexts_tup_false...,
130132
)
131133
prep = MooncakeGradientPrep(_sig, cache, args_to_zero)
132134
return prep

DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/twoarg.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,12 +31,13 @@ function DI.prepare_pullback_nokwarg(
3131
silence_debug_messages = config.silence_debug_messages,
3232
)
3333
dy_righttype_after = zero_tangent(y)
34+
contexts_tup_false = map(_ -> false, contexts)
3435
args_to_zero = (
3536
false, # target_function
3637
false, # f!
3738
false, # y
3839
true, # x
39-
map(_ -> false, contexts)...,
40+
contexts_tup_false...,
4041
)
4142
prep = MooncakeTwoArgPullbackPrep(
4243
_sig, cache, dy_righttype_after, target_function, args_to_zero

DifferentiationInterface/test/Back/Mooncake/test.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
using Pkg
2-
Pkg.add(url = "https://github.com/gdalle/Mooncake.jl", rev = "selective_zeroing")
2+
Pkg.add("Mooncake")
33

44
using DifferentiationInterface, DifferentiationInterfaceTest
55
using Mooncake: Mooncake

0 commit comments

Comments
 (0)