Skip to content

Commit a645e38

Browse files
committed
Rewrite DNNFriction using Lux
1 parent 3f065de commit a645e38

File tree

7 files changed

+6618
-183
lines changed

7 files changed

+6618
-183
lines changed

Project.toml

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,30 +7,38 @@ version = "0.1.1"
77
CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0"
88
ColorSchemes = "35d6a980-a343-548e-a6ea-1d62b119f2f4"
99
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
10-
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
1110
GLMakie = "e9467ef8-e4e7-5192-8a1a-b1aee30e663a"
1211
JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
1312
JSOSolvers = "10dff2fc-5484-5881-a0e0-c90441020f8a"
13+
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
1414
MAT = "23992714-dd62-5051-b70f-ba57cb901cac"
15+
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
1516
MadNLP = "2621e9c9-9eb4-46b1-8089-e8c72242dfb6"
1617
Makie = "ee78f7c6-11fb-53f2-987a-cfe4a2b5a57a"
1718
ManualNLPModels = "30dfa513-9b2f-4fb3-9796-781eabac1617"
19+
OnlineStats = "a15396b6-48d5-5d58-9928-6d29437db91e"
20+
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
1821
Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
1922
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
23+
Reactant = "3c362404-f566-11ee-1572-e11a4b42c853"
2024
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
2125
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
2226
Triangulate = "f7e6ffb2-c36d-4f8f-a77e-16e897189344"
2327

2428
[compat]
2529
ColorSchemes = "3.25"
2630
Enzyme = "0.13"
27-
Flux = "0.16"
28-
JLD2 = "0.6.2"
31+
JLD2 = "0.6"
2932
JSOSolvers = "0.14.1"
33+
Lux = "1.21"
3034
MAT = "0.10"
35+
MLUtils = "0.4.8"
3136
MadNLP = "0.8"
3237
Makie = "0.24"
3338
ManualNLPModels = "0.2.0"
39+
OnlineStats = "1.7.2"
40+
Optimisers = "0.4.6"
41+
Reactant = "0.2.170"
3442
SparseArrays = "1"
3543
StatsBase = "0.34"
3644
Triangulate = "3.0"

src/core/analyses/stressbalanceanalysis.jl

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -66,11 +66,8 @@ function UpdateElements(analysis::StressbalanceAnalysis,elements::Vector{Tria},
6666
FetchDataToInput(md,inputs,elements,md.friction.C,FrictionCEnum)
6767
FetchDataToInput(md,inputs,elements,md.friction.m,FrictionMEnum)
6868
FetchDataToInput(md,inputs,elements,md.friction.Cmax,FrictionCmaxEnum)
69-
elseif typeof(md.friction) == FluxDNNFriction
70-
FetchDataToInput(md,inputs,elements,md.geometry.ssx,SurfaceSlopeXEnum)
71-
FetchDataToInput(md,inputs,elements,md.geometry.ssy,SurfaceSlopeYEnum)
72-
FetchDataToInput(md,inputs,elements,md.geometry.bsx,BedSlopeXEnum)
73-
FetchDataToInput(md,inputs,elements,md.geometry.bsy,BedSlopeYEnum)
69+
elseif typeof(md.friction) == DNNFriction
70+
FetchDataToInput(md,inputs,elements,md.friction.C,FrictionCEnum)
7471
else
7572
error("Friction ", typeof(md.friction), " not supported yet")
7673
end
@@ -90,11 +87,11 @@ function UpdateParameters(analysis::StressbalanceAnalysis,parameters::Parameters
9087
AddParam(parameters, 2, FrictionLawEnum)
9188
elseif typeof(md.friction)==SchoofFriction
9289
AddParam(parameters, 11, FrictionLawEnum)
93-
elseif typeof(md.friction)==FluxDNNFriction
90+
elseif typeof(md.friction)==DNNFriction
9491
AddParam(parameters, 20, FrictionLawEnum)
95-
AddParam(parameters, md.friction.dnnChain, FrictionDNNChainEnum)
96-
AddParam(parameters, md.friction.dtx, FrictionDNNdtxEnum)
97-
AddParam(parameters, md.friction.dty, FrictionDNNdtyEnum)
92+
AddParam(parameters, md.friction.model, FrictionDNNEnum)
93+
AddParam(parameters, md.friction.ps, FrictionDNNpsEnum)
94+
AddParam(parameters, md.friction.st, FrictionDNNstEnum)
9895
else
9996
error("Friction ", typeof(md.friction), " not supported yet")
10097
end

src/core/friction.jl

Lines changed: 15 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ mutable struct CoreBuddFriction <: CoreFriction #{{{
1313
rho_water::Float64
1414
g::Float64
1515
end# }}}
16-
struct CoreWeertmanFriction <: CoreFriction#{{{
16+
mutable struct CoreWeertmanFriction <: CoreFriction#{{{
1717
c_input::Input
1818
vx_input::Input
1919
vy_input::Input
@@ -31,18 +31,13 @@ mutable struct CoreSchoofFriction <: CoreFriction #{{{
3131
rho_water::Float64
3232
g::Float64
3333
end# }}}
34-
mutable struct CoreFluxDNNFriction <: CoreFriction#{{{
35-
dnnChain::Vector{Flux.Chain}
36-
dtx::Vector{StatsBase.ZScoreTransform}
37-
dty::Vector{StatsBase.ZScoreTransform}
38-
xyz_list::Matrix{Float64}
34+
mutable struct CoreDNNFriction <: CoreFriction#{{{
35+
model::AbstractLuxLayer
36+
ps
37+
st
38+
c_input::Input
3939
vx_input::Input
4040
vy_input::Input
41-
b_input::Input
42-
H_input::Input
43-
rho_ice::Float64
44-
rho_water::Float64
45-
g::Float64
4641
end# }}}
4742

4843
function CoreFriction(element::Tria, ::Val{frictionlaw}) where frictionlaw #{{{
@@ -79,23 +74,12 @@ function CoreFriction(element::Tria, ::Val{frictionlaw}) where frictionlaw #{{{
7974

8075
return CoreSchoofFriction(c_input, vx_input, vy_input, m_input, Cmax_input, H_input, b_input, rho_ice, rho_water, g)
8176
elseif frictionlaw==20
82-
dnnChain = FindParam(Vector{Flux.Chain{}}, element, FrictionDNNChainEnum)
83-
dtx = FindParam(Vector{StatsBase.ZScoreTransform{Float64, Vector{Float64}} }, element, FrictionDNNdtxEnum)
84-
dty = FindParam(Vector{StatsBase.ZScoreTransform{Float64, Vector{Float64}} }, element, FrictionDNNdtyEnum)
85-
H_input = GetInput(element, ThicknessEnum)
86-
b_input = GetInput(element, BaseEnum)
87-
ssx_input = GetInput(element, SurfaceSlopeXEnum)
88-
ssy_input = GetInput(element, SurfaceSlopeYEnum)
89-
bsx_input = GetInput(element, BedSlopeXEnum)
90-
bsy_input = GetInput(element, BedSlopeYEnum)
91-
92-
xyz_list = GetVerticesCoordinates(element.vertices)
93-
94-
rho_ice = FindParam(Float64, element, MaterialsRhoIceEnum)
95-
rho_water = FindParam(Float64, element, MaterialsRhoSeawaterEnum)
96-
g = FindParam(Float64, element, ConstantsGEnum)
77+
c_input = GetInput(element, FrictionCEnum)
78+
model = FindParam(AbstractLuxLayer, element, FrictionDNNEnum)
79+
ps = FindParam(NamedTuple, element, FrictionDNNpsEnum)
80+
st = FindParam(NamedTuple, element, FrictionDNNstEnum)
9781

98-
return CoreFluxDNNFriction(dnnChain,dtx,dty,xyz_list,vx_input,vy_input,b_input,H_input,rho_ice,rho_water,g)
82+
return CoreDNNFriction(model,ps,st,c_input,vx_input,vy_input)
9983
else
10084
error("Friction ",typeof(md.friction)," not supported yet")
10185
end
@@ -161,40 +145,18 @@ end#}}}
161145
end
162146
return alpha2
163147
end #}}}
164-
@inline function Alpha2(friction::CoreFluxDNNFriction, gauss::GaussTria, i::Int64)#{{{
165-
bed = GetInputValue(friction.b_input, gauss, i)
166-
H = GetInputValue(friction.H_input, gauss, i)
148+
@inline function Alpha2(friction::CoreDNNFriction, gauss::GaussTria, i::Int64)#{{{
167149
vx = GetInputValue(friction.vx_input, gauss, i)
168150
vy = GetInputValue(friction.vy_input, gauss, i)
169-
h = bed + H
170151

152+
c = GetInputValue(friction.c_input, gauss, i)
171153
# Get the velocity
172154
vmag = VelMag(friction, gauss, i)
173155

174-
# velocity gradients
175-
dvx = GetInputDerivativeValue(friction.vx_input,friction.xyz_list,gauss,i)
176-
dvy = GetInputDerivativeValue(friction.vy_input,friction.xyz_list,gauss,i)
177-
vxdx = dvx[1]
178-
vxdy = dvx[2]
179-
vydx = dvy[1]
180-
vydy = dvy[2]
181-
182-
# Get effective pressure
183-
Neff = EffectivePressure(friction, gauss, i)
184-
185-
# need to change according to the construction of DNN
186-
alpha2 = 0.0
187-
for i in 1:length(friction.dnnChain)
188-
xin = StatsBase.transform(friction.dtx[i], (reshape(vcat(vx, vy, vxdx, vxdy, vydx, vydy, bed, h), 8, :)))
189-
pred = StatsBase.reconstruct(friction.dty[i], friction.dnnChain[i](xin))
190-
alpha2 += first(pred)
191-
end
192-
# Average
193-
alpha2 = alpha2 / length(friction.dnnChain)
194-
if ( (vmag == 0.0) | (alpha2 < 0.0) )
156+
if (vmag == 0.0 )
195157
alpha2 = 0.0
196158
else
197-
alpha2 = alpha2 ./ vmag
159+
alpha2 = c^2*vmag^2
198160
end
199161
return alpha2
200162
end#}}}

0 commit comments

Comments
 (0)