Skip to content

Commit 584426f

Browse files
authored
Move all requires-based optional code to extension modules (#930)
1 parent d48e5ef commit 584426f

19 files changed

Lines changed: 145 additions & 150 deletions

.bumpversion.cfg

Lines changed: 0 additions & 6 deletions
This file was deleted.

Project.toml

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "DFTK"
22
uuid = "acf6eb54-70d9-11e9-0013-234b7a5f5337"
33
authors = ["Michael F. Herbst <info@michael-herbst.com>", "Antoine Levitt <antoine.levitt@inria.fr>"]
4-
version = "0.6.14"
4+
version = "0.6.15"
55

66
[deps]
77
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
@@ -37,7 +37,6 @@ Primes = "27ebfcd6-29c5-5fa9-bf4b-fb8fc14df3ae"
3737
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
3838
PseudoPotentialIO = "cb339c56-07fa-4cb2-923a-142469552264"
3939
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
40-
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
4140
Roots = "f2b01f46-fcfa-551c-844a-d8ac1e96c665"
4241
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
4342
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
@@ -49,45 +48,59 @@ Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"
4948
UnitfulAtomic = "a7773ee8-282e-5fa2-be4e-bd808c38a91a"
5049

5150
[weakdeps]
51+
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
52+
GenericLinearAlgebra = "14197337-ba66-59df-a3e3-ca00e7dcff7a"
53+
IntervalArithmetic = "d1acc4aa-44c8-5952-acd4-ba5d80a2a253"
54+
JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
5255
JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1"
56+
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
5357
Wannier = "2b19380a-1f7e-4d7d-b1b8-8aa60b3321c9"
5458
wannier90_jll = "c5400fa0-8d08-52c2-913f-1e3f656c1ce9"
5559
WriteVTK = "64499a7a-5c06-52f2-abe2-ccb03c286192"
5660

5761
[extensions]
62+
DFTKCUDAExt = "CUDA"
63+
DFTKGenericLinearAlgebraExt = "GenericLinearAlgebra"
64+
DFTKIntervalArithmeticExt = "IntervalArithmetic"
65+
DFTKJLD2Ext = "JLD2"
5866
DFTKJSON3Ext = "JSON3"
59-
DFTKWannierExt = "Wannier"
67+
DFTKPlotsExt = "Plots"
6068
DFTKWannier90Ext = "wannier90_jll"
61-
DFTKWriteVTK = "WriteVTK"
69+
DFTKWannierExt = "Wannier"
70+
DFTKWriteVTKExt = "WriteVTK"
6271

6372
[compat]
6473
AbstractFFTs = "1"
6574
Artifacts = "1"
6675
AtomsBase = "0.3.1"
6776
Brillouin = "0.5.14"
6877
ChainRulesCore = "1.15"
78+
CUDA = "5"
6979
Dates = "1"
7080
DftFunctionals = "0.2"
7181
DocStringExtensions = "0.9"
7282
FFTW = "1.5"
7383
ForwardDiff = "0.10"
84+
GenericLinearAlgebra = "0.3"
7485
GPUArraysCore = "0.1"
7586
Interpolations = "0.14, 0.15"
7687
IntervalArithmetic = "0.20"
77-
IterTools = "1"
7888
IterativeSolvers = "0.9"
89+
IterTools = "1"
90+
JLD2 = "0.4"
7991
JSON3 = "1"
8092
LazyArtifacts = "1.3"
8193
Libxc = "0.3.17"
82-
LineSearches = "7"
8394
LinearAlgebra = "1"
8495
LinearMaps = "3"
96+
LineSearches = "7"
8597
LoopVectorization = "0.12"
8698
MPI = "0.20.13"
8799
Markdown = "1"
88100
Optim = "1"
89101
OrderedCollections = "1"
90102
PeriodicTable = "1"
103+
Plots = "1"
91104
PkgVersion = "0.3"
92105
Polynomials = "3, 4"
93106
PrecompileTools = "1"
@@ -96,7 +109,6 @@ Primes = "0.5"
96109
Printf = "1"
97110
PseudoPotentialIO = "0.1"
98111
Random = "1"
99-
Requires = "1"
100112
Roots = "2"
101113
SparseArrays = "1"
102114
SpecialFunctions = "2"

examples/arbitrary_floattype.jl

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -52,9 +52,10 @@ eltype(scfres.ρ)
5252
#
5353
# !!! note "Generic linear algebra routines"
5454
# For more unusual floating-point types (like IntervalArithmetic or DoubleFloats),
55-
# which are not directly supported in the standard `LinearAlgebra` library of Julia
56-
# one additional step is required: One needs to explicitly enable the generic versions
57-
# of standard linear-algebra operations like `cholesky` or `qr`, which are needed
58-
# inside DFTK by loading the `GenericLinearAlgebra` package in the user script
55+
# which are not directly supported in the standard `LinearAlgebra` and `FFTW`
56+
# libraries one additional step is required: One needs to explicitly enable the generic
57+
# versions of standard linear-algebra operations like `cholesky` or `qr` or standard
58+
# `fft` operations, which DFTK requires. THis is done by loading the
59+
# `GenericLinearAlgebra` package in the user script
5960
# (i.e. just add ad `using GenericLinearAlgebra` next to your `using DFTK` call).
6061
#
Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,11 @@
1-
synchronize_device(::GPU{<:CUDA.CuArray}) = CUDA.synchronize()
1+
module DFTKCUDAExt
2+
using CUDA
3+
import DFTK: GPU, DispatchFunctional
4+
using DftFunctionals
5+
using DFTK
6+
using Libxc
7+
8+
DFTK.synchronize_device(::GPU{<:CUDA.CuArray}) = CUDA.synchronize()
29

310
for fun in (:potential_terms, :kernel_terms)
411
@eval function DftFunctionals.$fun(fun::DispatchFunctional,
@@ -7,3 +14,5 @@ for fun in (:potential_terms, :kernel_terms)
714
$fun(fun.inner, ρ, args...)
815
end
916
end
17+
18+
end

src/workarounds/fft_generic.jl renamed to ext/DFTKGenericLinearAlgebraExt/DFTKGenericLinearAlgebraExt.jl

Lines changed: 27 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,35 @@
1-
include("FourierTransforms.jl/FourierTransforms.jl")
1+
module DFTKGenericLinearAlgebraExt
2+
using DFTK
3+
using DFTK: DummyInplace
4+
using LinearAlgebra
5+
using AbstractFFTs
6+
import AbstractFFTs: Plan, ScaledPlan,
7+
fft, ifft, bfft, fft!, ifft!, bfft!,
8+
plan_fft, plan_ifft, plan_bfft, plan_fft!, plan_ifft!, plan_bfft!,
9+
rfft, irfft, brfft, plan_rfft, plan_irfft, plan_brfft,
10+
fftshift, ifftshift,
11+
rfft_output_size, brfft_output_size,
12+
plan_inv, normalization
13+
import Base: show, summary, size, ndims, length, eltype, *, inv, \
14+
import LinearAlgebra: mul!
215

3-
# This is needed to flag that the fft_generic.jl file has already been loaded
4-
const GENERIC_FFT_LOADED = true
5-
6-
if !isdefined(Main, :GenericLinearAlgebra)
7-
@warn("Code paths for generic floating-point types activated in DFTK. Remember to " *
8-
"add 'using GenericLinearAlgebra' to your user script. " *
9-
"See https://docs.dftk.org/stable/examples/arbitrary_floattype/ for details.")
10-
end
16+
include("ctfft.jl") # Main file of FourierTransforms.jl
1117

1218
# Utility functions to setup FFTs for DFTK. Most functions in here
1319
# are needed to correct for the fact that FourierTransforms is not
1420
# yet fully compliant with the AbstractFFTs interface and has still
1521
# various bugs we work around.
1622

17-
function next_working_fft_size(::Any, size::Integer)
23+
function DFTK.next_working_fft_size(::Any, size::Integer)
1824
# TODO FourierTransforms has a bug, which is triggered
1925
# only in some factorizations, see
2026
# https://github.com/JuliaComputing/FourierTransforms.jl/issues/10
2127
nextpow(2, size) # We fall back to powers of two to be safe
2228
end
23-
default_primes(::Any) = (2, )
29+
DFTK.default_primes(::Any) = (2, )
2430

2531
# Generic fallback function, Float32 and Float64 specialization in fft.jl
26-
function build_fft_plans!(tmp::AbstractArray{<:Complex})
32+
function DFTK.build_fft_plans!(tmp::AbstractArray{<:Complex})
2733
# Note: FourierTransforms has no support for in-place FFTs at the moment
2834
# ... also it's extension to multi-dimensional arrays is broken and
2935
# the algo only works for some cases
@@ -40,14 +46,12 @@ function build_fft_plans!(tmp::AbstractArray{<:Complex})
4046
ipFFT, opFFT, ipBFFT, opBFFT
4147
end
4248

43-
44-
4549
struct GenericPlan{T}
4650
subplans
4751
factor::T
4852
end
4953

50-
function generic_apply(p::GenericPlan, X::AbstractArray)
54+
function Base.:*(p::GenericPlan, X::AbstractArray)
5155
pl1, pl2, pl3 = p.subplans
5256
ret = similar(X)
5357
for i = 1:size(X, 1), j = 1:size(X, 2)
@@ -65,21 +69,21 @@ end
6569
LinearAlgebra.mul!(Y, p::GenericPlan, X) = Y .= p * X
6670
LinearAlgebra.ldiv!(Y, p::GenericPlan, X) = Y .= p \ X
6771

68-
import Base: *, \, inv, length
6972
length(p::GenericPlan) = prod(length, p.subplans)
70-
*(p::GenericPlan, X::AbstractArray) = generic_apply(p, X)
7173
*(p::GenericPlan{T}, fac::Number) where {T} = GenericPlan{T}(p.subplans, p.factor * T(fac))
7274
*(fac::Number, p::GenericPlan{T}) where {T} = p * fac
7375
\(p::GenericPlan, X) = inv(p) * X
7476
inv(p::GenericPlan{T}) where {T} = GenericPlan{T}(inv.(p.subplans), 1 / p.factor)
7577

7678
function generic_plan_fft(data::AbstractArray{T, 3}) where {T}
77-
GenericPlan{T}([FourierTransforms.plan_fft(data[:, 1, 1]),
78-
FourierTransforms.plan_fft(data[1, :, 1]),
79-
FourierTransforms.plan_fft(data[1, 1, :])], T(1))
79+
GenericPlan{T}([plan_fft(data[:, 1, 1]),
80+
plan_fft(data[1, :, 1]),
81+
plan_fft(data[1, 1, :])], T(1))
8082
end
8183
function generic_plan_bfft(data::AbstractArray{T, 3}) where {T}
82-
GenericPlan{T}([FourierTransforms.plan_bfft(data[:, 1, 1]),
83-
FourierTransforms.plan_bfft(data[1, :, 1]),
84-
FourierTransforms.plan_bfft(data[1, 1, :])], T(1))
84+
GenericPlan{T}([plan_bfft(data[:, 1, 1]),
85+
plan_bfft(data[1, :, 1]),
86+
plan_bfft(data[1, 1, :])], T(1))
87+
end
88+
8589
end

src/workarounds/FourierTransforms.jl/ctfft.jl renamed to ext/DFTKGenericLinearAlgebraExt/ctfft.jl

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,31 @@ using Primes
22

33
# COV_EXCL_START
44

5+
# These files are copied (and very slightly modified) from [FourierTransforms.jl](https://github.com/JuliaComputing/FourierTransforms.jl)
6+
# commit [6a206bcfc8f49a129ca34beaf57d05cdc148dc8a](https://github.com/JuliaComputing/FourierTransforms.jl/tree/6a206bcfc8f49a129ca34beaf57d05cdc148dc8a).
7+
# For their license see [LICENCE.md](LICENCE.md).
8+
9+
# The following licence applies:
10+
# Copyright (c) 2017-2019: Steven G. Johnson, Yingbo Ma, and Julia Computing.
11+
#
12+
# Permission is hereby granted, free of charge, to any person obtaining a copy
13+
# of this software and associated documentation files (the "Software"), to deal
14+
# in the Software without restriction, including without limitation the rights
15+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
16+
# copies of the Software, and to permit persons to whom the Software is
17+
# furnished to do so, subject to the following conditions:
18+
#
19+
# The above copyright notice and this permission notice shall be included in all
20+
# copies or substantial portions of the Software.
21+
#
22+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
23+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
24+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
25+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
26+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
27+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
28+
# SOFTWARE.
29+
530
# 1d Cooley-Tukey FFTs, using an FFTW-like (version 1) approach: automatic
631
# generation of fixed-size FFT kernels (with and without twiddle factors)
732
# which are combined to make arbitrary-size FFTs (plus generic base
@@ -301,7 +326,8 @@ macro nontwiddle(args...)
301326
end
302327
@assert isa(T, Type)
303328
quote
304-
function FourierTransforms.applystep(ns::NontwiddleKernelStep{T,$n,$forward},
329+
function DFTKGenericLinearAlgebraExt.applystep(
330+
ns::NontwiddleKernelStep{T,$n,$forward},
305331
vn::Integer,
306332
X::AbstractArray{T},
307333
x0::Integer, xs::Integer, xvs::Integer,
@@ -340,7 +366,8 @@ macro twiddle(args...)
340366
end
341367
@assert isa(T, Type)
342368
quote
343-
function FourierTransforms.applystep(ts::TwiddleKernelStep{T,$n,$forward},
369+
function DFTKGenericLinearAlgebraExt.applystep(
370+
ts::TwiddleKernelStep{T,$n,$forward},
344371
vn::Integer,
345372
X::AbstractArray{T},
346373
x0::Integer, xs::Integer, xvs::Integer,
Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,26 @@
1+
module DFTKIntervalArithmeticExt
2+
using DFTK
3+
using IntervalArithmetic
4+
using LinearAlgebra
5+
import DFTK: symmetry_operations, _is_well_conditioned, compute_Glims_fast
6+
import DFTK: local_potential_fourier
7+
import IntervalArithmetic: Interval
18
import SpecialFunctions: erfc
2-
const Interval = IntervalArithmetic.Interval
39

410
# Monkey-patch a few functions for Intervals
511
# ... this is far from proper and a bit specific for our use case here
612
# (that's why it's not contributed upstream).
713
# should be done e.g. by changing the rounding mode ...
814
erfc(i::Interval) = Interval(prevfloat(erfc(i.lo)), nextfloat(erfc(i.hi)))
15+
Base.nextfloat(x::Interval) = Interval(nextfloat(x.lo), nextfloat(x.hi))
16+
Base.prevfloat(x::Interval) = Interval(prevfloat(x.lo), prevfloat(x.hi))
917

1018
# This is done to avoid using sincospi(x), called by cispi(x),
1119
# which has not been implemented in IntervalArithmetic
1220
# see issue #513 on IntervalArithmetic repository
13-
cis2pi(x::Interval) = exp(2 * (pi * (im * x)))
21+
DFTK.cis2pi(x::Interval) = exp(2 * (pi * (im * x)))
1422

15-
Base.nextfloat(x::Interval) = Interval(nextfloat(x.lo), nextfloat(x.hi))
16-
Base.prevfloat(x::Interval) = Interval(prevfloat(x.lo), prevfloat(x.hi))
17-
value_type(::Type{<:Interval{T}}) where {T} = T
23+
DFTK.value_type(::Type{<:Interval{T}}) where {T} = T
1824

1925
function compute_Glims_fast(lattice::AbstractMatrix{<:Interval}, args...; kwargs...)
2026
# This is done to avoid a call like ceil(Int, ::Interval)
@@ -25,7 +31,7 @@ function compute_Glims_fast(lattice::AbstractMatrix{<:Interval}, args...; kwargs
2531
# their midpoints should be good.
2632
compute_Glims_fast(IntervalArithmetic.mid.(lattice), args...; kwargs...)
2733
end
28-
function compute_Glims_precise(::AbstractMatrix{<:Interval}, args...; kwargs...)
34+
function DFTK.compute_Glims_precise(::AbstractMatrix{<:Interval}, args...; kwargs...)
2935
error("fft_size_algorithm :precise not supported with intervals")
3036
end
3137

@@ -36,8 +42,8 @@ function _is_well_conditioned(A::AbstractArray{<:Interval}; kwargs...)
3642
end
3743

3844
function symmetry_operations(lattice::AbstractMatrix{<:Interval}, atoms, positions,
39-
magnetic_moments=[];
40-
tol_symmetry=max(SYMMETRY_TOLERANCE, maximum(radius, lattice)))
45+
magnetic_moments=[];
46+
tol_symmetry=max(SYMMETRY_TOLERANCE, maximum(radius, lattice)))
4147
@assert tol_symmetry < 1e-2
4248
symmetry_operations(IntervalArithmetic.mid.(lattice), atoms, positions, magnetic_moments;
4349
tol_symmetry)
@@ -50,10 +56,13 @@ function local_potential_fourier(el::ElementCohenBergstresser, q::T) where {T <:
5056
T(local_potential_fourier(el, IntervalArithmetic.mid(q)))
5157
end
5258

53-
function estimate_integer_lattice_bounds(M::AbstractMatrix{<:Interval}, δ, shift=zeros(3))
59+
function DFTK.estimate_integer_lattice_bounds(M::AbstractMatrix{<:Interval}, δ,
60+
shift=zeros(3))
5461
# As a general statement, with M a lattice matrix, then if ||Mx|| <= δ,
5562
# then xi = <ei, M^-1 Mx> = <M^-T ei, Mx> <= ||M^-T ei|| δ.
5663
# Below code does not support non-3D systems.
5764
xlims = [norm(inv(M')[:, i]) * δ + shift[i] for i = 1:3]
5865
map(x -> ceil(Int, x.hi), xlims)
5966
end
67+
68+
end
Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,10 @@
1-
function save_scfres_master(file::AbstractString, scfres::NamedTuple, ::Val{:jld2})
1+
module DFTKJLD2Ext
2+
using DFTK
3+
using DFTK: energy_hamiltonian, AbstractArchitecture, AbstractKgrid
4+
using JLD2
5+
using MPI
6+
7+
function DFTK.save_scfres_master(file::AbstractString, scfres::NamedTuple, ::Val{:jld2})
28
!mpi_master() && error(
39
"This function should only be called on MPI master after the k-point data has " *
410
"been gathered with `gather_kpts`."
@@ -19,7 +25,7 @@ function save_scfres_master(file::AbstractString, scfres::NamedTuple, ::Val{:jld
1925
end
2026

2127

22-
function load_scfres(jld::JLD2.JLDFile)
28+
function DFTK.load_scfres(jld::JLD2.JLDFile)
2329
basis = jld["basis"]
2430
scfdict = Dict{Symbol, Any}(
2531
=> jld["ρ"],
@@ -45,7 +51,7 @@ function load_scfres(jld::JLD2.JLDFile)
4551
scfdict[:ham] = ham
4652
(; (sym => scfdict[sym] for sym in jld["__propertynames"])...)
4753
end
48-
load_scfres(file::AbstractString) = JLD2.jldopen(load_scfres, file, "r")
54+
DFTK.load_scfres(file::AbstractString) = JLD2.jldopen(DFTK.load_scfres, file, "r")
4955

5056

5157
#
@@ -93,3 +99,5 @@ function Base.convert(::Type{PlaneWaveBasis{T,T,Arch,GT,RT,KGT}},
9399
MPI.COMM_WORLD,
94100
serial.architecture)
95101
end
102+
103+
end

0 commit comments

Comments
 (0)