Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion src/ACEpotentials.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@ import ACEpotentials.Models: algebraic_smoothness_prior,
exp_smoothness_prior,
gaussian_smoothness_prior,
set_parameters!,
fast_evaluator,
@committee,
set_committee!
import JSON
Expand Down
85 changes: 85 additions & 0 deletions src/models/Rnl_learnable_new.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@

import EquivariantTensors as ET
using StaticArrays
using Lux

# In ET we currently store an edge xij as a NamedTuple, e.g,
# xij = (𝐫ij = ..., zi = ..., zj = ...)
# The NTtransform is a wrapper for mapping xij -> y
# (in this case y = transformed distance) adding logic to enable
# differentiation through this operation.
#
# In ET.Atoms edges are of the form xij = (𝐫 = ..., s0 = ..., s1 = ...)


# build a pure Lux Rnl basis 100% compatible with LearnableRnlrzz
#

function _convert_Rnl_learnable(basis; zlist = basis._i2z,
rfun = x -> x.r )

# number of species
NZ = length(zlist)

# species z -> index i mapping
__z2i = let _i2z = (_i2z = zlist,)
z -> _z2i(_i2z, z)
end

# __zz2i maps a `(Zi, Zj)` pair to a single index `a` representing
# (Zi, Zj) in a flattened array
__zz2ii = (zi, zj) -> (__z2i(zi) - 1) * NZ + __z2i(zj)
selector = xij -> __zz2ii(xij.s0, xij.s1)

# construct the transform to be a Lux layer that behaves a bit
# like a WrappedFunction, but with additional support for
# named-tuple inputs
#
et_trans = let transforms = basis.transforms
ET.NTtransform( xij -> begin
trans_ij = transforms[__z2i(xij.s0), __z2i(xij.s1)]
return trans_ij(rfun(xij))
end )
end

# the envelope is always a simple quartic (1 -x^2)^2
# otherwise make this transform fail.
# ( note the transforms is normalized to map to [-1, 1]
# y outside [-1, 1] maps to 1 or -1. )
# this obviously needs to be relaxed if we want compatibility
# with older versions of the code
for env in basis.envelopes
@assert env isa PolyEnvelope2sX
@assert env.p1 == env.p2 == 2
@assert env.x1 == -1
@assert env.x2 == 1
end

et_env = y -> (1 - y^2)^2

# the polynomial basis just stays the same
#
et_polys = basis.polys

# the linear layer transformation
# P(yij) -> W[(Zi, Zj)] * P(yij)
# with W[a] learnable weights
#
et_linl = ET.SelectLinL(length(et_polys), # indim
length(basis.spec), # outdim
NZ^2, # num (Zi,Zj) pairs
selector)

et_rbasis = SkipConnection( # input is (rij, zi, zj)
Chain(y = et_trans, # transforms yij
P = SkipConnection(
et_polys,
WrappedFunction( Py -> et_env.(Py[2]) .* Py[1] )
)
), # r -> y -> P = e(y) * polys(y)
et_linl # P -> W(Zi, Zj) * P
)

return et_rbasis
end

1 change: 0 additions & 1 deletion src/models/ace.jl
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,6 @@ function evaluate(model::ACEModel,

# contract with params
val = dot(B, (@view ps.WB[:, i_z0]))


# -------------------
# pair potential
Expand Down
4 changes: 3 additions & 1 deletion src/models/models.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ include("Rnl_basis.jl")
include("Rnl_learnable.jl")
include("Rnl_splines.jl")

include("Rnl_learnable_new.jl")

# sparse.jl removed - now using EquivariantTensors.SparseACEbasis directly

include("ace_heuristics.jl")
Expand All @@ -45,7 +47,7 @@ include("smoothness_priors.jl")

include("utils.jl")

include("fasteval.jl")
# include("fasteval.jl")



Expand Down
5 changes: 0 additions & 5 deletions src/models/radial_envelopes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,9 @@ abstract type AbstractEnvelope end
struct PolyEnvelope1sR{T}
rcut::T
p::Int
# -------
meta::Dict{String, Any}
end


PolyEnvelope1sR(rcut, p) =
PolyEnvelope1sR(rcut, p, Dict{String, Any}())

function evaluate(env::PolyEnvelope1sR, r::T, x::T) where T
if r >= env.rcut
return zero(T)
Expand Down
43 changes: 39 additions & 4 deletions test/models/test_learnable_Rnl.jl
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@

using Pkg; Pkg.activate(joinpath(@__DIR__(), "..", ".."))
using TestEnv; TestEnv.activate();
Pkg.develop("/Users/ortner/gits/EquivariantTensors.jl/")


# using Pkg; Pkg.activate(joinpath(@__DIR__(), "..", ".."))
# using TestEnv; TestEnv.activate();
##

using ACEpotentials
M = ACEpotentials.Models

using Random, LuxCore, Test, LinearAlgebra, ACEbase
using Polynomials4ML.Testing: print_tf
using Polynomials4ML.Testing: print_tf, println_slim
rng = Random.MersenneTwister(1234)

Random.seed!(1234)
Expand Down Expand Up @@ -83,3 +84,37 @@ Rnl_spl, ∇Rnl_spl = M.evaluate_ed(basis_spl, r, Zi, Zj, ps_spl, st_spl)
println_slim(@test norm(Rnl - Rnl_spl, Inf) < 1e-4 )
println_slim(@test norm(∇Rnl - ∇Rnl_spl, Inf) < 1e-2 )

##
#
# test the conversion to a Lux style Rnl basis
#
et_rbasis = M._convert_Rnl_learnable(basis)
et_ps, et_st = LuxCore.setup(Random.default_rng(), et_rbasis)

et_ps.connection.W[:, :, 1] = ps.Wnlq[:, :, 1, 1]
et_ps.connection.W[:, :, 2] = ps.Wnlq[:, :, 1, 2]
et_ps.connection.W[:, :, 3] = ps.Wnlq[:, :, 2, 1]
et_ps.connection.W[:, :, 4] = ps.Wnlq[:, :, 2, 2]

for ntest = 1:50
global ps, st, et_ps, et_st
r = 2.0 + 5 * rand()
Zi = rand(basis._i2z)
Zj = rand(basis._i2z)
xij = ( r = r, s0 = Zi, s1 = Zj )
R1 = basis(r, Zi, Zj, ps, st)
R2 = et_rbasis( xij, et_ps, et_st)[1]
print_tf(@test R1 ≈ R2)
end

# batched test
for ntest = 1:10
z0 = rand(basis._i2z)
xx = [ (r = 2.0 + 2 * rand(), s0 = z0, s1 = rand(basis._i2z)) for _ in 1:30 ]
rr = [ x.r for x in xx ]
Zjs = [ x.s1 for x in xx ]
R1 = M.evaluate_batched(basis, rr, z0, Zjs, ps, st)
R2 = et_rbasis( xx, et_ps, et_st)[1]
print_tf(@test R1 ≈ R2)
end

3 changes: 3 additions & 0 deletions test/models/test_radial_transforms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,3 +57,6 @@ for trans in [trans_2_2, trans_2_4, trans_1_3]
println_slim( @test ACEpotentials.Models.test_normalized_transform(trans_2_2) )
end

##


199 changes: 199 additions & 0 deletions test/new_backend.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,199 @@
using Pkg; Pkg.activate(joinpath(@__DIR__(), ".."))
using TestEnv; TestEnv.activate();
# Pkg.develop("/Users/ortner/gits/EquivariantTensors.jl/")

##

using ACEpotentials
M = ACEpotentials.Models

# build a pure Lux Rnl basis compatible with LearnableRnlrzz
import EquivariantTensors as ET
import Polynomials4ML as P4ML

using StaticArrays, Lux
using AtomsBase, AtomsBuilder, Unitful, AtomsCalculators

using Random, LuxCore, Test, LinearAlgebra, ACEbase
using Polynomials4ML.Testing: print_tf, println_slim
rng = Random.MersenneTwister(1234)

Random.seed!(1234)

##

elements = (:Si, :O)
level = M.TotalDegree()
max_level = 10
order = 3
maxl = 6

# modify rin0cuts to have same cutoff for all elements
# TODO: there is currently a bug with variable cutoffs
# (?is there? The radials seem fine? check again)
rin0cuts = M._default_rin0cuts(elements)
rin0cuts = (x -> (rin = x.rin, r0 = x.r0, rcut = 5.5)).(rin0cuts)


model = M.ace_model(; elements = elements, order = order,
Ytype = :solid, level = level, max_level = max_level,
maxl = maxl, pair_maxn = max_level,
rin0cuts = rin0cuts,
init_WB = :glorot_normal, init_Wpair = :glorot_normal)

ps, st = Lux.setup(rng, model)

# Missing issues:
# Vref = 0 => this will not be tested
# pair potential will also not be tested

# kill the pair basis for now
for s in model.pairbasis.splines
s.itp.itp.coefs[:] *= 0
end

##
# build the Rnl basis
# here we build it from the model.rbasis, so we can exactly match it
# but in the final implementation we will have to create it directly

rbasis = model.rbasis
et_i2z = AtomsBase.ChemicalSpecies.(rbasis._i2z)
et_rbasis = M._convert_Rnl_learnable(rbasis; zlist = et_i2z,
rfun = x -> norm(x.𝐫) )

# TODO: this is cheating, but this set can probably be generated quite
# easily as part of the construction of et_rbasis.
et_rspec = rbasis.spec

##
# build the ybasis

et_ybasis = Chain( 𝐫ij = ET.NTtransform(x -> x.𝐫),
Y = model.ybasis )
et_yspec = P4ML.natural_indices(et_ybasis.layers.Y)

# combining the Rnl and Ylm basis we can build an embedding layer
et_embed = ET.EdgeEmbed( BranchLayer(;
Rnl = et_rbasis,
Ylm = et_ybasis ) )

##
# now build the linear ACE layer

# Convert AA_spec from (n,l,m) format to (n,l) format for mb_spec
AA_spec = model.tensor.meta["𝔸spec"]
et_mb_spec = unique([[(n=b.n, l=b.l) for b in bb] for bb in AA_spec])

et_mb_basis = ET.sparse_equivariant_tensor(
L = 0, # Invariant (scalar) output only
mb_spec = et_mb_spec,
Rnl_spec = et_rspec,
Ylm_spec = et_yspec,
basis = real
)

# et_acel = ET.SparseACElayer(et_mb_basis, (1,))

# ------------------------------------------------
# readout layer : need to select which linear output to
# use based on the center atom species

__zi = let zlist = (_i2z = et_i2z, )
x -> M._z2i(zlist, x.s)
end

et_readout = ET.SelectLinL(
et_mb_basis.lens[1], # input dim
1, # output dim
length(et_i2z), # num species
__zi )

# finally build the full model from the two layers
#
# TODO: there is a huge problem here; the read-out layer needs to know
# about the center species; need to figure out how to pass that information
# through to the ace layer
#

__sz(::Any) = nothing
__sz(A::AbstractArray) = size(A)
__sz(x::Tuple) = __sz.(x)
dbglayer(msg = ""; show=false) = WrappedFunction(x ->
begin
println("$msg : ", typeof(x), ", ", __sz(x))
if show; display(x); end
return x
end )

et_basis = Lux.Chain(;
embed = et_embed, # embedding layer
ace = et_mb_basis, # ACE layer -> basis
unwrp = WrappedFunction(x -> x[1]), # unwrap the tuple
)

et_model = Lux.Chain(
L1 = Lux.BranchLayer(;
basis = et_basis,
nodes = WrappedFunction(G -> G.node_data), # pass node data through
),
Ei = et_readout,
E = WrappedFunction(sum), # sum up to get a total energy
)
et_ps, et_st = LuxCore.setup(MersenneTwister(1234), et_model)

##
# fixup all the parameters to make sure they match
# the basis ordering appears to be identical, but it is not clear it really
# is because meta["mb_spec"] only gives the original ordering before basis
# construction ...
nnll = M.get_nnll_spec(model.tensor)
et_nnll = et_mb_basis.meta["mb_spec"]
@show nnll == et_nnll

# but this is also identical ...
@show model.tensor.A2Bmaps[1] == et_mb_basis.A2Bmaps[1]

# radial basis parameters
et_ps.L1.basis.embed.Rnl.connection.W[:, :, 1] = ps.rbasis.Wnlq[:, :, 1, 1]
et_ps.L1.basis.embed.Rnl.connection.W[:, :, 2] = ps.rbasis.Wnlq[:, :, 1, 2]
et_ps.L1.basis.embed.Rnl.connection.W[:, :, 3] = ps.rbasis.Wnlq[:, :, 2, 1]
et_ps.L1.basis.embed.Rnl.connection.W[:, :, 4] = ps.rbasis.Wnlq[:, :, 2, 2]

# many-body basis parameters; because the readout layer doesn't know about
# species yet we take a single parameter set; this needs to be fixed asap.
# ps.WB[:, 2] .= ps.WB[:, 1]

et_ps.Ei.W[1, :, 1] .= ps.WB[:, 1]
et_ps.Ei.W[1, :, 2] .= ps.WB[:, 2]

##

# wrap the old ACE model into a calculator
calc_model = ACEpotentials.ACEPotential(model, ps, st)

# we will also need to get the cutoff radius which we didn't track
# (Another TODO!!!)
rcut = maximum(a.rcut for a in model.pairbasis.rin0cuts)

function rand_struct()
sys = AtomsBuilder.bulk(:Si) * (2,1,1)
rattle!(sys, 0.2u"Å")
AtomsBuilder.randz!(sys, [:Si => 0.5, :O => 0.5])
return sys
end

function energy_new(sys, et_model)
G = ET.Atoms.interaction_graph(sys, rcut * u"Å")
return et_model(G, et_ps, et_st)[1]
end

##

for ntest = 1:30
sys = rand_struct()
G = ET.Atoms.interaction_graph(sys, rcut * u"Å")
E1 = AtomsCalculators.potential_energy(sys, calc_model)
E2 = energy_new(sys, et_model)
print_tf( @test abs(ustrip(E1) - ustrip(E2)) < 1e-5 )
end
Loading