diff --git a/src/Model.jl b/src/Model.jl
index cd4c5cedd..ddf93f374 100644
--- a/src/Model.jl
+++ b/src/Model.jl
@@ -110,6 +110,8 @@ function Model(lattice::AbstractMatrix{T},
symmetries=default_symmetries(lattice, atoms, positions, magnetic_moments,
spin_polarization, terms),
) where {T <: Real}
+ T2 = promote_type(T, typeof(temperature), eltype(magnetic_moments))
+
# Validate εF and n_electrons
if !isnothing(εF) # fixed Fermi level
if !isnothing(n_electrons)
@@ -133,7 +135,7 @@ function Model(lattice::AbstractMatrix{T},
atom_groups = [findall(Ref(pot) .== atoms) for pot in Set(atoms)]
# Special handling of 1D and 2D systems, and sanity checks
- lattice = Mat3{T}(lattice)
+ lattice = Mat3{T2}(lattice)
n_dim = count(!iszero, eachcol(lattice))
n_dim > 0 || error("Check your lattice; we do not do 0D systems")
for i = n_dim+1:3
@@ -159,7 +161,7 @@ function Model(lattice::AbstractMatrix{T},
)
n_spin = length(spin_components(spin_polarization))
- temperature = T(austrip(temperature))
+ temperature = T2(austrip(temperature))
temperature < 0 && error("temperature must be non-negative")
if !allunique(string.(nameof.(typeof.(terms))))
@@ -175,7 +177,7 @@ function Model(lattice::AbstractMatrix{T},
end
@assert !isempty(symmetries) # Identity has to be always present.
- Model{T,value_type(T)}(model_name,
+ Model{T2,value_type(T2)}(model_name,
lattice, recip_lattice, n_dim, inv_lattice, inv_recip_lattice,
unit_cell_volume, recip_cell_volume,
n_electrons, εF, spin_polarization, n_spin, temperature, smearing,
diff --git a/src/workarounds/forwarddiff_rules.jl b/src/workarounds/forwarddiff_rules.jl
index 26f946db2..b9edebb15 100644
--- a/src/workarounds/forwarddiff_rules.jl
+++ b/src/workarounds/forwarddiff_rules.jl
@@ -166,7 +166,7 @@ function construct_value(model::Model{T}) where {T <: Dual}
newpositions;
model.model_name,
model.n_electrons,
- magnetic_moments=[], # Symmetries given explicitly
+ magnetic_moments=value_type(T)[], # Symmetries given explicitly
terms=model.term_types,
temperature=ForwardDiff.value(model.temperature),
model.smearing,
# Very basic setup, useful for testing
using DFTK
using PseudoPotentialData
"""
    silicon_density(temperature; Ecut=15, kgrid=[1, 1, 1])

Run an LDA self-consistent field calculation for bulk silicon (two Si atoms at
`±1/8` in the primitive cell, lattice constant 10.26 Bohr) at the given
`temperature`. Return the first entry of the converged density `scfres.ρ`.

`temperature` is passed unchanged to `model_DFT`, so it may be a
`ForwardDiff.Dual`. This lets the function be differentiated with
`ForwardDiff.derivative`.

Named `silicon_density` rather than `get` to avoid shadowing `Base.get`.
Shadowing it errors when `Base.get` was already used in the same session.
"""
function silicon_density(temperature; Ecut=15, kgrid=[1, 1, 1])
    a = 10.26  # Silicon lattice constant in Bohr
    lattice = a / 2 * [[0 1 1.];
                       [1 0 1.];
                       [1 1 0.]]
    Si = ElementPsp(:Si, PseudoFamily("cp2k.nc.sr.lda.v0_1.semicore.gth"))
    atoms = [Si, Si]
    positions = [ones(3) / 8, -ones(3) / 8]
    model = model_DFT(lattice, atoms, positions;
                      functionals=LDA(), temperature=temperature)
    basis = PlaneWaveBasis(model; Ecut=Ecut, kgrid=kgrid)
    scfres = self_consistent_field(basis)
    return scfres.ρ[1]
end

# Plain (non-AD) evaluation for comparison: silicon_density(0.01)
using ForwardDiff
res = ForwardDiff.derivative(silicon_density, 0.01)
This returns zero, because there is no explicit variation `dH` with respect to the temperature. I think we might need to generalize `solve_ΩplusK_split` to accept a `docc` and return a `docc`.
Running it needs the following patch:
Then:
This returns zero, because there is no explicit variation `dH` with respect to the temperature. I think we might need to generalize `solve_ΩplusK_split` to accept a `docc` and return a `docc`.