Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ GeometryOptimization = "673bf261-a53d-43b9-876f-d3c1fc8329c2"
IntervalArithmetic = "d1acc4aa-44c8-5952-acd4-ba5d80a2a253"
JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1"
Makie = "ee78f7c6-11fb-53f2-987a-cfe4a2b5a57a"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
Wannier = "2b19380a-1f7e-4d7d-b1b8-8aa60b3321c9"
WriteVTK = "64499a7a-5c06-52f2-abe2-ccb03c286192"
Expand All @@ -66,6 +67,7 @@ DFTKIntervalArithmeticExt = "IntervalArithmetic"
DFTKJLD2Ext = "JLD2"
DFTKJSON3Ext = "JSON3"
DFTKPlotsExt = "Plots"
DFTKMakieExt = "Makie"
DFTKWannier90Ext = "wannier90_jll"
DFTKWannierExt = "Wannier"
DFTKWriteVTKExt = "WriteVTK"
Expand Down Expand Up @@ -113,6 +115,7 @@ Optim = "1"
PeriodicTable = "1"
PkgVersion = "0.3"
Plots = "1"
Makie = "0.24.8"
PrecompileTools = "1"
Preferences = "1"
Printf = "1"
Expand Down Expand Up @@ -156,6 +159,7 @@ JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1"
KrylovKit = "0b1a1467-8014-51b9-945f-bf0ae24f4b77"
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
Makie = "ee78f7c6-11fb-53f2-987a-cfe4a2b5a57a"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d"
QuadGK = "1fd47b50-473d-5c70-9696-f719f8f3bcdc"
Expand All @@ -167,4 +171,4 @@ WriteVTK = "64499a7a-5c06-52f2-abe2-ccb03c286192"
wannier90_jll = "c5400fa0-8d08-52c2-913f-1e3f656c1ce9"

[targets]
test = ["Test", "TestItemRunner", "AMDGPU", "ASEconvert", "AtomsBuilder", "Aqua", "AtomsIO", "AtomsIOPython", "CUDA", "CUDA_Runtime_jll", "ComponentArrays", "DoubleFloats", "FiniteDiff", "FiniteDifferences", "GenericLinearAlgebra", "GeometryOptimization", "IntervalArithmetic", "JLD2", "JSON3", "Logging", "Plots", "PythonCall", "QuadGK", "Random", "KrylovKit", "Wannier", "WriteVTK", "wannier90_jll"]
test = ["Test", "TestItemRunner", "AMDGPU", "ASEconvert", "AtomsBuilder", "Aqua", "AtomsIO", "AtomsIOPython", "CUDA", "CUDA_Runtime_jll", "ComponentArrays", "DoubleFloats", "FiniteDiff", "FiniteDifferences", "GenericLinearAlgebra", "GeometryOptimization", "IntervalArithmetic", "JLD2", "JSON3", "Logging", "Plots", "Makie", "PythonCall", "QuadGK", "Random", "KrylovKit", "Wannier", "WriteVTK", "wannier90_jll"]
144 changes: 144 additions & 0 deletions ext/DFTKMakieExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
module DFTKMakieExt

using DFTK
using DFTK: is_metal, data_for_plotting, spin_components, default_band_εrange
using Makie
using Unitful
using UnitfulAtomic
using LinearAlgebra

# Import DFTK stub functions so we can extend them natively
import DFTK: plot_bandstructure, plot_dos, plot_spin_3d, plot_spin_slice, plot_bandstructure!, plot_dos!, plot_spin_3d!, plot_spin_slice!

@info "DFTKMakieExt successfully loaded!"

function plot_spin_3d!(pos, scfres; density_threshold=0.05, arrow_scale=1.0, stride=1.0,
title_text="Magnetization", tip_len=0.8, tip_rad=0.35, line_rad=0.25)

data = DFTK.get_spin_3d_data(scfres; density_threshold=density_threshold, stride=stride)

ax = Axis3(pos[1, 1], title=title_text, aspect=:data,
xlabel="x (Bohr)", ylabel="y (Bohr)", zlabel="z (Bohr)",
elevation=pi/6, azimuth=pi/4)

if !isempty(data.X)
arrows3d!(ax, data.X, data.Y, data.Z, data.U, data.V, data.W;
lengthscale=arrow_scale, color=data.mags, colormap=:plasma,
colorrange=(0, data.max_mag),
tiplength=tip_len,
tipradius=tip_rad,
shaftradius=line_rad)
end
Colorbar(pos[1, 2], limits=(0, data.max_mag), colormap=:plasma, label="Magnetization (μB)")
return ax
end

function plot_spin_3d(scfres; kwargs...)
fig = Figure(size=(1000, 800), fontsize=20)
plot_spin_3d!(fig[1, 1], scfres; kwargs...)
return fig
end

function plot_spin_slice!(pos, scfres; axis=:z, stride=2, scale=1.5,
title_text="Spin Slice", tip_len=14, tip_wid=12, line_wid=2.0)

data = DFTK.get_spin_slice_data(scfres; axis=axis, stride=stride, scale=scale)
if !data.has_spin; error("No spin data found."); end

ax = Axis(pos[1, 1], title=title_text, xlabel=data.xl, ylabel=data.yl, aspect=DataAspect())
hm = heatmap!(ax, data.X_axis, data.Y_axis, data.h_data,
colormap=:balance, colorrange=(-data.clim_val, data.clim_val))
Colorbar(pos[1, 2], hm, label="Out-of-Plane Spin")

if !isempty(data.X_ar)
arrows2d!(ax, data.X_ar, data.Y_ar, data.U_ar, data.V_ar;
tiplength=tip_len, tipwidth=tip_wid, shaftwidth=line_wid, color=:black)
end
return ax
end

function plot_spin_slice(scfres; kwargs...)
fig = Figure(size=(800, 700), fontsize=18)
plot_spin_slice!(fig[1, 1], scfres; kwargs...)
return fig
end

function plot_bandstructure!(pos, band_data; title_text="Band Structure", y_limits=nothing)
data = DFTK.data_for_plotting(band_data)
eshift = something(band_data.εF, 0.0)
to_unit = ustrip(auconvert(u"eV", 1.0))

ax = Makie.Axis(pos, title=title_text, xlabel="Wave Vector",
ylabel=isnothing(band_data.εF) ? "Energy (eV)" : "Energy - εF (eV)")

for σ = 1:data.n_spin, iband = 1:data.n_bands, branch in data.kbranches
energies = (data.eigenvalues[:, iband, σ][branch] .- eshift) .* to_unit
Makie.lines!(ax, data.kdistances[branch], energies,
color = σ == 1 ? :blue : :red, linewidth=2)
end

for branch in data.kbranches[1:end-1]
Makie.vlines!(ax, [data.kdistances[last(branch)]], color=:black, linewidth=1)
end

ax.xticks = (data.ticks.distances, data.ticks.labels)
Makie.xlims!(ax, 0, data.kdistances[end])

if !isnothing(band_data.εF)
Makie.hlines!(ax, [0.0], color=:green, linewidth=2, linestyle=:dash)
end

# --- NEW: Apply custom Y-axis limits if provided ---
if !isnothing(y_limits)
Makie.ylims!(ax, y_limits[1], y_limits[2])
end

return ax
end

function plot_bandstructure(band_data; kwargs...)
fig = Figure(size=(800, 600), fontsize=18)
plot_bandstructure!(fig[1, 1], band_data; kwargs...)
return fig
end

function plot_dos!(pos, scfres; n_points=500, title_text="Density of States")
εF = scfres.εF
eshift = something(εF, 0.0)
to_unit = ustrip(auconvert(u"eV", 1.0))

εrange = DFTK.default_band_εrange(scfres.eigenvalues; εF=εF)
εs = range(εrange[1], εrange[2], length=n_points)

Dεs = DFTK.compute_dos.(εs, Ref(scfres.basis), Ref(scfres.eigenvalues);
smearing=scfres.basis.model.smearing,
temperature=scfres.basis.model.temperature)

energies = (εs .- eshift) .* to_unit
n_spin = scfres.basis.model.n_spin_components

ax = Axis(pos, title=title_text,
xlabel=isnothing(εF) ? "Energy (eV)" : "Energy - εF (eV)",
ylabel="Density of States")

for σ = 1:n_spin
D = [Dσ[σ] for Dσ in Dεs]
lines!(ax, energies, D, label="Spin $σ", color = σ == 1 ? :blue : :red, linewidth=2)
end

if !isnothing(εF)
vlines!(ax, [0.0], color=:green, linewidth=2, linestyle=:dash)
end
if n_spin > 1
axislegend(ax)
end
return ax
end

function plot_dos(scfres; kwargs...)
fig = Figure(size=(800, 600), fontsize=18)
plot_dos!(fig[1, 1], scfres; kwargs...)
return fig
end

end # module
112 changes: 109 additions & 3 deletions ext/DFTKPlotsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@ using AtomsBase
using Brillouin: KPath
using DFTK
using DFTK: is_metal, data_for_plotting, spin_components, default_band_εrange
import DFTK: plot_dos, plot_bandstructure, plot_ldos, plot_pdos
import DFTK: plot_dos, plot_bandstructure, plot_ldos, plot_pdos, plot_spin_slice
using Plots
using Unitful
using UnitfulAtomic
using LinearAlgebra


function plot_bandstructure(basis::PlaneWaveBasis,
Expand All @@ -21,8 +22,7 @@ function plot_bandstructure(band_data::NamedTuple;
unit=u"hartree", kwargs_plot=(; ), kwargs...)
# TODO Replace by a plot recipe once BandData is its own type.

mpi_nprocs(band_data.basis.comm_kpts) > 1 &&
error("Band structure plotting with MPI not supported yet")
mpi_nprocs() > 1 && error("Band structure plotting with MPI not supported yet")

if !haskey(band_data, :kinter)
@warn("Calling plot_bandstructure without first computing the band data " *
Expand Down Expand Up @@ -184,4 +184,110 @@ function plot_pdos(basis::PlaneWaveBasis{T}, eigenvalues, ψ; iatom=nothing, lab
end
plot_pdos(scfres; kwargs...) = plot_pdos(scfres.basis, scfres.eigenvalues, scfres.ψ; scfres.εF, kwargs...)

"""
plot_spin_slice(scfres; axis=:z, slice_index=nothing, stride=1, scale=1.0, title="")

Plots a 2D slice of the spin density.
- `axis`: The normal axis to the slice (`:x`, `:y`, or `:z`).
- `slice_index`: The grid index of the slice (defaults to the middle of the cell).
- `stride`: Subsampling factor for arrows (use 2 or 3 to reduce clutter).
- `scale`: Length multiplier for the magnetic arrows.
"""
function plot_spin_slice(basis, ρ; axis=:z, slice_index=nothing, scale=0.5, stride=1, title="", kwargs...)
model = basis.model

# 1. Handle Non-Spin Cases
if model.spin_polarization in (:none, :spinless)
return Plots.plot(title="No Spin Polarization", grid=false, showaxis=false)
end

# 2. Extract Components
if model.spin_polarization == :collinear
# For collinear, we only have Z-component magnetization (Up - Down)
# We set Mx and My to zero so the code structure remains generic.
mx = zeros(size(ρ,1), size(ρ,2), size(ρ,3))
my = zeros(size(ρ,1), size(ρ,2), size(ρ,3))
mz = ρ[:, :, :, 1] .- ρ[:, :, :, 2]
else
mx, my, mz = ρ[:, :, :, 2], ρ[:, :, :, 3], ρ[:, :, :, 4]
end

# 3. Slice the Data based on the requested axis
dims = size(mx)

if axis == :z
k = isnothing(slice_index) ? dims[3]÷2 : slice_index
# Heatmap = Out-of-plane (Mz), Arrows = In-plane (Mx, My)
h_data = mz[:, :, k]
u_data = mx[:, :, k]
v_data = my[:, :, k]
xl, yl = "x (Bohr)", "y (Bohr)"
heatmap_title = "Color: Mz (Out-of-Plane)"

elseif axis == :y
j = isnothing(slice_index) ? dims[2]÷2 : slice_index
# Heatmap = Out-of-plane (My), Arrows = In-plane (Mx, Mz)
h_data = my[:, j, :]
u_data = mx[:, j, :]
v_data = mz[:, j, :]
xl, yl = "x (Bohr)", "z (Bohr)"
heatmap_title = "Color: My (Out-of-Plane)"

elseif axis == :x
i = isnothing(slice_index) ? dims[1]÷2 : slice_index
# Heatmap = Out-of-plane (Mx), Arrows = In-plane (My, Mz)
h_data = mx[i, :, :]
u_data = my[i, :, :]
v_data = mz[i, :, :]
xl, yl = "y (Bohr)", "z (Bohr)"
heatmap_title = "Color: Mx (Out-of-Plane)"
end

nx, ny = size(h_data)

# 4. Generate the Heatmap (The "Background" Scalar Field)
# We use :balance (Blue-White-Red) to clearly show Positive vs Negative domains
limit = maximum(abs.(h_data))
if limit < 1e-6; limit = 1.0; end

p = Plots.heatmap(1:nx, 1:ny, h_data',
c=:balance,
clims=(-limit, limit),
title=title,
xlabel=xl, ylabel=yl,
aspect_ratio=:equal,
colorbar_title=heatmap_title,
right_margin=15Plots.mm,
size=(800, 700),
dpi=300,
kwargs...
)

# 5. Generate the Quiver Arrows (The "In-Plane" Vector Field)
# We subsample using 'stride' to prevent the plot from becoming a black blob
X, Y, U, V = Float64[], Float64[], Float64[], Float64[]

for x in 1:stride:nx, y in 1:stride:ny
u, v = u_data[x, y], v_data[x, y]
mag = sqrt(u^2 + v^2)

# Only draw arrows if there is significant magnetization
if mag > 1e-4
push!(X, x)
push!(Y, y)
push!(U, u * scale * 5) # Scale factor for visibility
push!(V, v * scale * 5)
end
end

if !isempty(X)
Plots.quiver!(p, X, Y, quiver=(U, V), color=:black, linewidth=1.2)
end

return p
end

# Tuple Catcher
plot_spin_slice(scfres; kwargs...) = plot_spin_slice(scfres.basis, scfres.ρ; kwargs...)

end
15 changes: 15 additions & 0 deletions src/DFTK.jl
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,21 @@ include("workarounds/forwarddiff_rules.jl")
include("gpu/linalg.jl")
include("gpu/gpu_arrays.jl")

# NEW ADDED for `spin_polarization = :full` feature
#\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\
# Spin comparsion for `spin_polarization = :full` calculation support
export get_spin_3d_data
export get_spin_slice_data
include("postprocess/spin_extraction.jl")

export plot_spin_3d
export plot_spin_slice
export plot_spin_3d!
export plot_spin_slice!
export plot_bandstructure!
export plot_dos!
#\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\

# Precompilation block with a basic workflow

function precompilation_workflow(lattice, atoms, positions, magnetic_moments;
Expand Down
16 changes: 9 additions & 7 deletions src/Model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,11 @@ struct Model{T <: Real, VT <: Real}
# :collinear Spin is polarized, but everywhere in the same direction.
# αα ̸= ββ, αβ = βα = 0, 2 spin components treated
# :full Generic magnetization, non-uniform direction.
# αβ, βα, αα, ββ all nonzero, different (not supported)
# αβ, βα, αα, ββ all nonzero, different. Wavefunctions are 2-component spinors.
# :spinless No spin at all ("spinless fermions", "mathematicians' electrons").
# The difference with :none is that the occupations are 1 instead of 2
spin_polarization::Symbol
n_spin_components::Int # 2 if :collinear, 1 otherwise
n_spin_components::Int # 2 if :collinear or :full, 1 otherwise

# If temperature==0, no fractional occupations are used.
# If temperature is nonzero, the occupations are
Expand Down Expand Up @@ -100,7 +100,7 @@ a [`PlaneWaveBasis`](@ref) from a `Model`.
else `None()`):
Smearing function used to compute occupations from Kohn-Sham eigenvalues.
- `spin_polarization::Symbol` (default: `determine_spin_polarization(magnetic_moments)`):
Controls spin treatment; allowed values are `:none`, `:collinear`, `:spinless`.
Controls spin treatment; allowed values are `:none`, `:collinear`, `:full`, `:spinless`.
- `symmetries` (default: `true`):
- `true`: run automatic symmetry detection with [`default_symmetries`](@ref).
- `false`: disable all symmetries.
Expand Down Expand Up @@ -189,7 +189,9 @@ function Model(lattice::AbstractMatrix{Tstatic},
# Spin polarization
spin_polarization in (:none, :collinear, :full, :spinless) ||
error("Only :none, :collinear, :full and :spinless allowed for spin_polarization")
spin_polarization == :full && error("Full spin polarization not yet supported")

# NOTE: The explicit block throwing an error for `:full` has been removed.

!isempty(magnetic_moments) && !(spin_polarization in (:collinear, :full)) && @warn(
"Non-empty magnetic_moments on a Model without spin polarization detected."
)
Expand Down Expand Up @@ -350,7 +352,7 @@ end
Maximal occupation of a state (2 for non-spin-polarized electrons, 1 otherwise).
"""
function filled_occupation(model)
if model.spin_polarization in (:spinless, :collinear)
if model.spin_polarization in (:spinless, :collinear, :full)
return 1
elseif model.spin_polarization == :none
return 2
Expand All @@ -367,7 +369,7 @@ function spin_components(spin_polarization::Symbol)
spin_polarization == :collinear && return (:up, :down )
spin_polarization == :none && return (:both, )
spin_polarization == :spinless && return (:spinless, )
spin_polarization == :full && return (:undefined, )
spin_polarization == :full && return (:full, )
end
spin_components(model::Model) = spin_components(model.spin_polarization)

Expand Down Expand Up @@ -434,4 +436,4 @@ comatrix_cart_to_red(model::Model) = _gen_matmatmul(model.lattice', model.in
matrix_red_to_cart(model::Model, Ared) = matrix_red_to_cart(model)(Ared)
matrix_cart_to_red(model::Model, Acart) = matrix_cart_to_red(model)(Acart)
comatrix_red_to_cart(model::Model, Bred) = comatrix_red_to_cart(model)(Bred)
comatrix_cart_to_red(model::Model, Bcart) = comatrix_cart_to_red(model)(Bcart)
comatrix_cart_to_red(model::Model, Bcart) = comatrix_cart_to_red(model)(Bcart)
Loading