Skip to content
Closed
Show file tree
Hide file tree
Changes from 16 commits
Commits
Show all changes
101 commits
Select commit Hold shift + click to select a range
f70b1b7
features: rename `OrbitalFieldMatrix` to `OrbitalFeatureDescriptor`
thazhemadam Jul 4, 2021
3163da9
codec: rename OneHotOneCold.jl to onehotonecold.jl
thazhemadam Jul 7, 2021
69bd611
codec: create a simple codec - `SimpleCodec`
thazhemadam Jul 8, 2021
f6737fb
elementfeature: add constructor for building `ElementFD` from name,codec
thazhemadam Jul 9, 2021
2378190
features: simplify encoding-decoding logic, and minimize repetitive code
thazhemadam Jul 9, 2021
9161923
Merge branch 'move_bov' of github.com:Chemellia/ChemistryFeaturizatio…
thazhemadam Jul 10, 2021
f1b4588
data:tabulate_pmg_data.py: update script to include Electronic Structure
thazhemadam Jul 10, 2021
a2f34b5
data: pymatgen_atom_data.csv: update to include Electronic Structure
thazhemadam Jul 10, 2021
21f2d11
test: ElementFeature_tests: reorganize tests into testsets
thazhemadam Jul 15, 2021
23393a2
Merge branch 'main' of github.com:Chemellia/ChemistryFeaturization.jl…
thazhemadam Jul 15, 2021
7e6ad96
atoms: add new generic `elements()` function
thazhemadam Jul 15, 2021
b03a159
test: add tests and improve coverage
thazhemadam Jul 16, 2021
23b1ecf
features: make `output_shape` generic
thazhemadam Jul 16, 2021
0004aed
orbitalfeaturedescriptor: basic sketch
thazhemadam Jul 16, 2021
fd42c7e
orbitalfeature: annotate, fix order of shells in lookup in defaultdecode
thazhemadam Jul 18, 2021
120c4c1
orbitalfeature: trim down `lookup_table`
thazhemadam Jul 18, 2021
d21c568
data: pymatgen_atom_data: standardize `Lawrencium`'s configuration value
thazhemadam Jul 18, 2021
d799766
atoms: export `elements` at the main module level
thazhemadam Jul 18, 2021
a32599d
data.jl: create new module - `Data`
thazhemadam Jul 19, 2021
374904b
data.jl: remove const-ness for `atom_data_df` and `feature_info`
thazhemadam Jul 19, 2021
5a732f8
orbitalfeature: remove `lookup_table` field
thazhemadam Jul 19, 2021
f02bb76
orbitalfeature: add custom `Base.show` methods
thazhemadam Jul 19, 2021
ccb8ba1
add Zygote to deps
DhairyaLGandhi Aug 5, 2021
678e1de
refactor weight_cutoff
DhairyaLGandhi Aug 5, 2021
17d93f0
basic adjoints - TODO - move them to Zygote
DhairyaLGandhi Aug 5, 2021
1c551e9
fixes to cutoff adjoint
DhairyaLGandhi Aug 10, 2021
bd6fe29
add backwards lengths in adjoint
DhairyaLGandhi Aug 10, 2021
cbb55ea
make fn amenable to FiniteDifferences
DhairyaLGandhi Aug 10, 2021
65d9bc0
add tests for AD
DhairyaLGandhi Aug 10, 2021
faa6648
add FiniteDiff to test deps
DhairyaLGandhi Aug 10, 2021
68b90fb
fix tests
DhairyaLGandhi Aug 11, 2021
d2ff77c
fix equality
DhairyaLGandhi Aug 11, 2021
0ac1c19
Add Xtals dep, sketch out neighbor list function
rkurchin Aug 11, 2021
e011a6b
update sketch to use NearestNeighbors.jl
rkurchin Aug 11, 2021
eb806cb
simplify cutoff adjoint
DhairyaLGandhi Aug 12, 2021
f262cf1
finitedifferences fix
DhairyaLGandhi Aug 12, 2021
46cadf8
better tests
DhairyaLGandhi Aug 12, 2021
bb12d1c
fixes
DhairyaLGandhi Aug 12, 2021
85d898f
fix typo
DhairyaLGandhi Aug 12, 2021
872e5fc
another typo
DhairyaLGandhi Aug 12, 2021
73268d9
refactor tests
DhairyaLGandhi Aug 12, 2021
46498dd
cleanup
DhairyaLGandhi Aug 12, 2021
53b4e85
give it another shot
DhairyaLGandhi Aug 17, 2021
01d5c33
Update README.md
ViralBShah Aug 19, 2021
7b14382
Update README.md
ViralBShah Aug 19, 2021
81c61ef
basic neighbor list implementtion is there
rkurchin Aug 23, 2021
b0d6b1a
formatting
rkurchin Aug 23, 2021
f5dc410
Should be ready to test autodiff on this!
rkurchin Aug 25, 2021
909a41c
formatting
rkurchin Aug 25, 2021
4f42fe0
compat and docs build to julia 1.6 to match Xtals
rkurchin Aug 25, 2021
74375f4
add spaces between badges
DhairyaLGandhi Aug 26, 2021
3b5f92a
Update src/utils/graph_building.jl
rkurchin Aug 28, 2021
6e45306
don't change rc (this means we need absolute paths to cifs), add chec…
rkurchin Aug 30, 2021
c66ee74
Merge branch 'graph_ad' of https://github.com/Chemellia/ChemistryFeat…
rkurchin Aug 30, 2021
64edcd4
formatting
rkurchin Aug 30, 2021
1b8cb56
minor version bump
rkurchin Aug 30, 2021
324b828
update changelog
rkurchin Aug 30, 2021
b548f60
construct AtomGraph from Crystal
rkurchin Aug 30, 2021
79c6287
docs: remove explicit `versions` so it can possibly publish `dev` too?
thazhemadam Sep 1, 2021
94fbbd0
Merge pull request #115 from Chemellia/at/doc-dev-publish
thazhemadam Sep 1, 2021
9887a06
docs: add overview section for types, tweak `terminology and philosophy`
thazhemadam Sep 1, 2021
8e55598
docs: add more info into docs for types, minor restructuring and links
thazhemadam Sep 1, 2021
02a9b4f
docs: add examples for what a feature could be
thazhemadam Sep 2, 2021
f6f5ab5
orbitalfeature:clean up default_ofd_decode & rename valence_shell_config
thazhemadam Aug 21, 2021
fb2c4ff
orbitalfeaturedescriptor: add docstrings
thazhemadam Sep 2, 2021
7d155a9
tests: add `OrbitalFeatureUtils` tests
thazhemadam Sep 2, 2021
3f5dc40
orbitalfeatureutils: make `df::DataFrame` an optional argument
thazhemadam Sep 2, 2021
9c581c1
features: fix bug in `output_shape` by renaming `fd` to `efd`
thazhemadam Sep 2, 2021
ba4f8c4
Merge pull request #116 from Chemellia/at/docs
thazhemadam Sep 2, 2021
4565893
Merge pull request #100 from Chemellia/at/orbital-fd
thazhemadam Sep 2, 2021
82b04a9
changelog: include changes made in #100
thazhemadam Sep 2, 2021
68fdf54
Merge branch 'main' into graph_ad
rkurchin Sep 2, 2021
03d96d2
fix Atoms namespace conflict in AtomGraph tests...this probably needs…
rkurchin Sep 2, 2021
35a4c13
formatting
rkurchin Sep 2, 2021
8b6d05e
add other AtomGraph constructor to docs
rkurchin Sep 3, 2021
d04e46c
Merge pull request #112 from Chemellia/graph_ad
rkurchin Sep 7, 2021
aa7d310
update compats and changelog
rkurchin Sep 7, 2021
991f878
add Zygote to deps
DhairyaLGandhi Aug 5, 2021
84b1834
refactor weight_cutoff
DhairyaLGandhi Aug 5, 2021
c683a55
basic adjoints - TODO - move them to Zygote
DhairyaLGandhi Aug 5, 2021
ec29db1
fixes to cutoff adjoint
DhairyaLGandhi Aug 10, 2021
bb0e2c6
add backwards lengths in adjoint
DhairyaLGandhi Aug 10, 2021
7b17f24
make fn amenable to FiniteDifferences
DhairyaLGandhi Aug 10, 2021
4261c2d
add tests for AD
DhairyaLGandhi Aug 10, 2021
504dd2f
add FiniteDiff to test deps
DhairyaLGandhi Aug 10, 2021
b65b853
fix tests
DhairyaLGandhi Aug 11, 2021
d53f9af
fix equality
DhairyaLGandhi Aug 11, 2021
8011d10
simplify cutoff adjoint
DhairyaLGandhi Aug 12, 2021
e97a375
finitedifferences fix
DhairyaLGandhi Aug 12, 2021
3bd7996
better tests
DhairyaLGandhi Aug 12, 2021
77db19a
fixes
DhairyaLGandhi Aug 12, 2021
1662228
fix typo
DhairyaLGandhi Aug 12, 2021
6e78367
another typo
DhairyaLGandhi Aug 12, 2021
c17ce49
refactor tests
DhairyaLGandhi Aug 12, 2021
496475f
cleanup
DhairyaLGandhi Aug 12, 2021
2f5ab3d
give it another shot
DhairyaLGandhi Aug 17, 2021
b1194d2
rebase
DhairyaLGandhi Sep 14, 2021
8f5606f
add code but dont include
DhairyaLGandhi Sep 16, 2021
239a2b6
cleanups for unneeded adjoints
DhairyaLGandhi Sep 17, 2021
1f9529f
fixes
DhairyaLGandhi Sep 17, 2021
72b4542
refactor + array distance
DhairyaLGandhi Sep 17, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0"
Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
SimpleWeightedGraphs = "47aef6b3-ad0c-573a-a1e2-d07658019622"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
CSV = "0.7, 0.8"
Expand All @@ -35,7 +36,9 @@ SimpleWeightedGraphs = "1"
julia = "1.4, 1.5, 1.6"

[extras]
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Test"]
test = ["FiniteDifferences", "Test", "Zygote"]
60 changes: 60 additions & 0 deletions src/utils/adjoints.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
using Zygote # , ChainRulesCore
using Zygote: @adjoint
using LinearAlgebra

@adjoint function Base.Generator(f, iter)
ys, backs = Zygote.unzip([Zygote.pullback(f, x) for x in iter])
Base.Generator(f, iter), Δ -> begin
b(d::Dict) = [back(v)[1] for (back,v) in zip(backs,values(d))]
b(nt::NamedTuple{(:f, :iter)}) = [back(i)[1] for (back,i) in zip(backs,nt.iter)]
(nothing, b(Δ))
end
end

@adjoint function Base.Iterators.Zip(is)
Zip_pullback(Δ) = (Zygote.unzip(Δ),)
return Base.Iterators.Zip(is), Zip_pullback
end

@adjoint function Pair(a, b)
Pair(a, b), Δ -> (Δ, nothing)
end

@adjoint function Dict(g::Base.Generator)
ys, backs = Zygote.unzip([Zygote.pullback(g.f, args) for args in g.iter])
Dict(ys...), Δ -> begin
∂d = first(backs)(Δ)[1]
d = Dict(ys...)
for (k,v) in pairs(d)
d[k] = _zero(v)
end
(merge(d, ∂d), )
end
end

_zero(x) = zero(x)
_zero(::Nothing) = nothing

@adjoint function _cutoff!(weight_mat, f, ijd,
nb_counts, longest_dists;
max_num_nbr = 12)
y, ld = _cutoff!(weight_mat, f, ijd,
nb_counts, longest_dists;
max_num_nbr = max_num_nbr)
function cutoff_pb((Δ,nt))
s = size(Δ)
Δ = vec(collect(Δ))
for (ix,(_,_,d)) in zip(eachindex(Δ), ijd)
y_, back_ = Zygote.pullback(f, d)
Δ[ix] *= back_(Δ[ix])[1]
end
(reshape(Δ, s), nothing,
collect(zip(fill(nothing, size(Δ,1)),
fill(nothing, size(Δ,1)),
Δ)),
nothing,
nothing)
end

(y,ld), cutoff_pb
end
31 changes: 23 additions & 8 deletions src/utils/graph_building.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ export inverse_square, exp_decay
using PyCall
using ChemistryFeaturization
using Serialization
using Zygote

# options for decay of bond weights with distance...
# user can of course write their own as well
Expand Down Expand Up @@ -91,20 +92,32 @@ function weights_cutoff(is, js, dists; max_num_nbr = 12, dist_decay_func = inver

# iterate over list of tuples to build edge weights...
# note that neighbor list double counts so we only have to increment one counter per pair
weight_mat = zeros(Float32, num_atoms, num_atoms)
weight_mat = zeros(Float64, round(Int,num_atoms), round(Int,num_atoms))
weight_mat, longest_dists = _cutoff!(weight_mat,
dist_decay_func,
ijd,
nb_counts,
longest_dists)
# average across diagonal, just in case
weight_mat = 0.5 .* (weight_mat .+ weight_mat')
end


function _cutoff!(weight_mat, f, ijd,
nb_counts, longest_dists; max_num_nbr = 12)

for (i, j, d) in ijd
# if we're under the max OR if it's at the same distance as the previous one
if nb_counts[i] < max_num_nbr || isapprox(longest_dists[i], d)
weight_mat[i, j] += dist_decay_func(d)
longest_dists[i] = d
nb_counts[i] += 1
if nb_counts[round(Int,i)] < max_num_nbr || isapprox(longest_dists[round(Int,i)], d)
weight_mat[round(Int,i), round(Int,j)] += f(d)
longest_dists[round(Int,i)] = d
nb_counts[round(Int,i)] += 1
end
end

# average across diagonal, just in case
weight_mat = 0.5 .* (weight_mat .+ weight_mat')
weight_mat, longest_dists
end


"""
Build graph using neighbors from faces of Voronoi polyedra and weights from areas. Based on the approach from https://github.com/ulissigroup/uncertainty_benchmarking/blob/aabb407807e35b5fd6ad06b14b440609ae09e6ef/BNN/data_pyro.py#L268
"""
Expand Down Expand Up @@ -135,4 +148,6 @@ end

# TODO: graphs from SMILES via OpenSMILES.jl

include("adjoints.jl")

end
23 changes: 23 additions & 0 deletions test/utils/GraphBuilding_tests.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using Test
using ChemistryFeaturization.Utils.GraphBuilding
using Zygote, FiniteDifferences

@testset "GraphBuilding" begin
adj, els = build_graph(
Expand Down Expand Up @@ -39,3 +40,25 @@ using ChemistryFeaturization.Utils.GraphBuilding
@test all(isapprox.(adj[2:5, 1], 1.0, atol = 1e-4))
@test all(isapprox.(adj[3:2, 2], 0.375, atol = 1e-5))
end

@testset "Graph Building AD tests" begin

function test_fd(i, j, dist)
fd = grad(forward_fdm(2,1),
(i,j,dist) -> sum(GraphBuilding.weights_cutoff(i,j,dist)),
i, j, dist)

gs = gradient(i, j, dist) do i, j, dist
sum(GraphBuilding.weights_cutoff(i, j, dist))
end

@test gs[1] == fill(nothing, length(i))
@test gs[2] == fill(nothing, length(j))
@test gs[3] ≈ fd[3]
end

# test with non-overlapping indices
test_fd(collect(1:10), collect(1:10), Float64.(collect(1:10)))
# test with overlapping indices
test_fd(rand(1:10, 100), rand(1:10, 100), rand(100))
end