28 commits
85fe716
feat: Implement XAS (X-ray Absorption Spectroscopy) model, fitting, l…
anyangml Mar 24, 2026
9e9c6a3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 24, 2026
8fd99ad
feat: Reimplement XAS loss with per-atom property fitting, removing p…
anyangml Mar 24, 2026
9352c4f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 24, 2026
9bc38d7
feat: Add X-ray Absorption Spectroscopy (XAS) training examples
anyangml Mar 24, 2026
c8a4005
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 24, 2026
e157ed7
feat: Implement XAS energy normalization in the XAS loss function and…
anyangml Mar 25, 2026
8c21612
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 25, 2026
250168b
fix:device
anyangml Mar 25, 2026
8ab20b2
fix: filter loss-related keys from state dict in inference and ignore…
anyangml Mar 30, 2026
38c3a04
fix: update XAS reference extraction path and ignore tests directory …
anyangml Mar 30, 2026
17ffd5b
feat: add weighted loss and smoothness regularization to XAS training…
anyangml Mar 30, 2026
829048e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 30, 2026
f81f2a7
feat: add normalize_fparam option to fitting net and ignore tests dir…
anyangml Mar 30, 2026
3161398
chore: ignore tests directory in git tracking
anyangml Mar 30, 2026
ed8a87c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 30, 2026
73398f6
feat: add intensity_norm option to XAS loss for scale-invariant train…
anyangml Mar 31, 2026
eaae746
Merge branch 'feat/support-xas-spectrum' of github.com:anyangml/deepm…
anyangml Mar 31, 2026
a663c33
feat: add per-type/edge energy standard deviation normalization to XA…
anyangml Apr 1, 2026
f2d37ed
refactor: normalize energy predictions using global standard deviatio…
anyangml Apr 1, 2026
da895d0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 1, 2026
94d2a5a
fix: change XAS loss reduction from mean to sum for atomic contributions
anyangml Apr 1, 2026
5f15806
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 1, 2026
b427dda
fix: use global stat
anyangml Apr 22, 2026
3080b36
fix: seltype in inference
anyangml Apr 22, 2026
d213c7b
fix dptest
anyangml Apr 22, 2026
ddd0610
fix: finetune
anyangml Apr 23, 2026
8e17e79
chore: refactor
anyangml Apr 24, 2026
1 change: 1 addition & 0 deletions .gitignore
@@ -71,3 +71,4 @@ frozen_model.*

# Test system directories
system/
tests/
3 changes: 3 additions & 0 deletions deepmd/__about__.py
@@ -0,0 +1,3 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
# Auto-generated stub for development use
__version__ = "dev"
15 changes: 8 additions & 7 deletions deepmd/dpmodel/model/model.py
@@ -104,13 +104,14 @@ def get_standard_model(data: dict) -> EnergyModel:
else:
raise RuntimeError(f"Unknown fitting type: {fitting_net_type}")

model_kwargs: dict = {
"descriptor": descriptor,
"fitting": fitting,
"type_map": data["type_map"],
"atom_exclude_types": atom_exclude_types,
"pair_exclude_types": pair_exclude_types,
}
model = modelcls(**model_kwargs)
return model


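The kwargs-dict form above makes it easy to add model-class-specific keys conditionally before instantiation. A minimal self-contained sketch of the pattern — ToyModel, the config dict, and the xas_options key are hypothetical stand-ins, not part of this PR:

# Sketch of the conditional-kwargs pattern (all names are illustrative).
class ToyModel:
    def __init__(self, **kwargs: object) -> None:
        self.kwargs = kwargs

data = {"type_map": ["C", "O"], "xas_options": None}  # hypothetical config
model_kwargs: dict = {"type_map": data["type_map"]}
if data.get("xas_options") is not None:  # hypothetical extension point
    model_kwargs["xas_options"] = data["xas_options"]
model = ToyModel(**model_kwargs)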
70 changes: 67 additions & 3 deletions deepmd/entrypoints/test.py
@@ -887,13 +887,19 @@ def test_property(
high_prec=True,
)

is_xas = var_name == "xas"

if dp.get_dim_fparam() > 0:
data.add(
"fparam", dp.get_dim_fparam(), atomic=False, must=True, high_prec=False
)
if dp.get_dim_aparam() > 0:
data.add("aparam", dp.get_dim_aparam(), atomic=True, must=True, high_prec=False)

# XAS requires sel_type.npy (per-frame absorbing element type index)
if is_xas:
data.add("sel_type", 1, atomic=False, must=True, high_prec=False)

test_data = data.get_test()
mixed_type = data.mixed_type
natoms = len(test_data["type"][0])
@@ -918,21 +924,79 @@
else:
aparam = None

# XAS: per-atom outputs are needed to average over absorbing-element atoms
eval_atomic = has_atom_property or is_xas
ret = dp.eval(
coord,
box,
atype,
fparam=fparam,
aparam=aparam,
atomic=eval_atomic,
mixed_type=mixed_type,
)

if is_xas:
# ret[1]: per-atom property [numb_test, natoms, task_dim]
atom_prop = ret[1].reshape([numb_test, natoms, dp.task_dim])
if mixed_type:
atype_frames = atype # [numb_test, natoms]
else:
atype_frames = np.tile(atype, (numb_test, 1)) # [numb_test, natoms]
sel_type_int = test_data["sel_type"][:numb_test, 0].astype(int)
property = np.zeros([numb_test, dp.task_dim], dtype=atom_prop.dtype)
for i in range(numb_test):
t = sel_type_int[i]
mask = atype_frames[i] == t # [natoms]
property[i] = atom_prop[i][mask].sum(axis=0) # sum, not mean

# Add back the per-(type, edge) energy reference so output is in
# absolute eV (matching label format). xas_e_ref is computed before
# training (DPXASAtomicModel.compute_or_load_out_stat) and stored as
# an atomic-model buffer, so it is saved in the model checkpoint.
try:
# dp is DeepProperty (wrapper); the PT backend is dp.deep_eval,
# and its ModelWrapper is dp.deep_eval.dp.
xas_e_ref = dp.deep_eval.dp.model["Default"].atomic_model.xas_e_ref
except AttributeError:
xas_e_ref = None
if xas_e_ref is not None and fparam is not None:
import torch as _torch

edge_idx_all = (
_torch.tensor(fparam.reshape(numb_test, -1)).argmax(dim=-1).numpy()
)
e_ref_np = xas_e_ref.cpu().numpy() # [ntypes, nfparam, 2]
for i in range(numb_test):
t = sel_type_int[i]
e = int(edge_idx_all[i])
property[i, :2] += e_ref_np[t, e]

# Restore intensity dims: pred_abs = pred * intensity_std + intensity_ref
try:
am = dp.deep_eval.dp.model["Default"].atomic_model
xas_intensity_ref = getattr(am, "xas_intensity_ref", None)
xas_intensity_std = getattr(am, "xas_intensity_std", None)
except AttributeError:
xas_intensity_ref = None
xas_intensity_std = None
if xas_intensity_ref is not None and xas_intensity_std is not None and fparam is not None:
import torch as _torch

edge_idx_all = (
_torch.tensor(fparam.reshape(numb_test, -1)).argmax(dim=-1).numpy()
)
int_ref_np = xas_intensity_ref.cpu().numpy() # [ntypes, nfparam, n_pts]
int_std_np = xas_intensity_std.cpu().numpy() # [ntypes, nfparam, n_pts]
for i in range(numb_test):
t = sel_type_int[i]
e = int(edge_idx_all[i])
property[i, 2:] = property[i, 2:] * int_std_np[t, e] + int_ref_np[t, e]
else:
property = ret[0]

property = property.reshape([numb_test, dp.task_dim])

if has_atom_property and not is_xas:
aproperty = ret[1]
aproperty = aproperty.reshape([numb_test, natoms * dp.task_dim])

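To make the XAS branch above concrete, here is a minimal NumPy sketch of the same post-processing: a per-frame masked sum of per-atom outputs over the absorbing element, then adding back the per-(type, edge) energy reference and un-standardising the intensity dims. Shapes follow the comments in the diff; all numbers are toy data, not real statistics.

import numpy as np

nf, natoms, task_dim = 2, 4, 5        # toy sizes: 2 energy dims + 3 intensity points
ntypes, nfparam = 2, 3
rng = np.random.default_rng(0)

atom_prop = rng.normal(size=(nf, natoms, task_dim))   # per-atom model output
atype = rng.integers(0, ntypes, size=(nf, natoms))    # per-frame atom types
sel_type = rng.integers(0, ntypes, size=nf)           # absorbing element per frame
fparam = np.eye(nfparam)[rng.integers(0, nfparam, size=nf)]  # one-hot edge vector
e_ref = rng.normal(size=(ntypes, nfparam, 2))
intensity_ref = rng.normal(size=(ntypes, nfparam, task_dim - 2))
intensity_std = np.abs(rng.normal(size=(ntypes, nfparam, task_dim - 2))) + 1e-3

edge_idx = fparam.argmax(axis=-1)  # recover edge index from one-hot fparam
prop = np.zeros((nf, task_dim))
for i in range(nf):
    mask = atype[i] == sel_type[i]            # atoms of the absorbing element
    prop[i] = atom_prop[i][mask].sum(axis=0)  # sum, not mean
    t, e = sel_type[i], edge_idx[i]
    prop[i, :2] += e_ref[t, e]                # back to absolute eV
    prop[i, 2:] = prop[i, 2:] * intensity_std[t, e] + intensity_ref[t, e]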
6 changes: 6 additions & 0 deletions deepmd/pt/infer/deep_eval.py
@@ -170,6 +170,12 @@ def __init__(
if not self.input_param.get("hessian_mode") and not no_jit:
model = torch.jit.script(model)
self.dp = ModelWrapper(model)
# Filter out loss-related keys that may be present in old training checkpoints.
# This is for backward compatibility with checkpoints saved before the
# XASLoss refactor that removed persistent buffers from the loss module.
state_dict = {
k: v for k, v in state_dict.items() if not k.startswith("loss.")
}
self.dp.load_state_dict(state_dict)
elif str(self.model_path).endswith(".pth"):
extra_files = {"data_modifier.pth": ""}
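The same loss-key filtering appears in both inference entry points (here and in inference.py below). A toy illustration of its effect — the keys are made up, except loss.e_ref, which the comment below names as an old XASLoss buffer:

# Toy state dict: model weights are kept, stale loss buffers are dropped.
state_dict = {
    "model.Default.fitting_net.weight": 1.0,  # kept (hypothetical key)
    "loss.e_ref": 2.0,                        # dropped (old XASLoss buffer)
}
state_dict = {k: v for k, v in state_dict.items() if not k.startswith("loss.")}
assert list(state_dict) == ["model.Default.fitting_net.weight"]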
3 changes: 3 additions & 0 deletions deepmd/pt/infer/inference.py
@@ -73,4 +73,7 @@ def __init__(
self.wrapper = ModelWrapper(self.model) # inference only
if JIT:
self.wrapper = torch.jit.script(self.wrapper)
# Drop loss-related keys (e.g. loss buffers like XASLoss.e_ref) that
# are not part of the inference-only wrapper.
state_dict = {k: v for k, v in state_dict.items() if not k.startswith("loss.")}
self.wrapper.load_state_dict(state_dict)
4 changes: 4 additions & 0 deletions deepmd/pt/loss/__init__.py
@@ -21,6 +21,9 @@
from .tensor import (
TensorLoss,
)
from .xas import (
XASLoss,
)

__all__ = [
"DOSLoss",
@@ -31,4 +34,5 @@
"PropertyLoss",
"TaskLoss",
"TensorLoss",
"XASLoss",
]
191 changes: 191 additions & 0 deletions deepmd/pt/loss/xas.py
@@ -0,0 +1,191 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import logging
from typing import (
Any,
)

import torch
import torch.nn.functional as F

from deepmd.pt.loss.loss import (
TaskLoss,
)
from deepmd.pt.utils import (
env,
)
from deepmd.utils.data import (
DataRequirementItem,
)

log = logging.getLogger(__name__)


class XASLoss(TaskLoss):
"""Loss for XAS spectrum fitting via property fitting + sel_type reduction.

The model outputs per-atom property vectors (atom_xas). For each frame
this loss sums the contributions of atoms matching ``sel_type`` (read from
``sel_type.npy`` per system) and computes a loss against the per-frame XAS
label.

Normalisation statistics (``xas_e_ref``, ``xas_intensity_ref/std``,
``out_bias``, ``out_std``) are computed once before training by
:meth:`DPXASAtomicModel.compute_or_load_out_stat` via the standard
:meth:`compute_or_load_stat` pipeline and stored as model buffers.

Parameters
----------
task_dim : int
Output dimension of the fitting net (e.g. 102 = E_min + E_max + 100 pts).
nfparam : int
Length of the fparam one-hot vector (= number of edge types).
var_name : str
Property name, must match ``property_name`` in the fitting config.
loss_func : str
One of ``smooth_mae``, ``mae``, ``mse``, ``rmse``.
metric : list[str]
Metrics to display during training (absolute scale).
beta : float
Beta parameter for smooth_l1 loss.
pref_energy : float
Weight multiplier for the two energy dimensions (E_min, E_max).
pref_spectrum : float
Weight multiplier for the intensity dimensions (index 2 onward).
smooth_reg : float
Coefficient of the second-order smoothness regulariser applied to the
predicted intensity dimensions in standardised space. 0.0 disables (default).
"""

def __init__(
self,
task_dim: int,
nfparam: int,
var_name: str = "xas",
loss_func: str = "smooth_mae",
metric: list[str] = ["mae"],
beta: float = 1.0,
pref_energy: float = 1.0,
pref_spectrum: float = 1.0,
smooth_reg: float = 0.0,
**kwargs: Any,
) -> None:
super().__init__()
self.task_dim = task_dim
self.nfparam = nfparam
self.var_name = var_name
self.loss_func = loss_func
self.metric = metric
self.beta = beta
self.pref_energy = pref_energy
self.pref_spectrum = pref_spectrum
self.smooth_reg = smooth_reg

def forward(
self,
input_dict: dict[str, torch.Tensor],
model: torch.nn.Module,
label: dict[str, torch.Tensor],
natoms: int,
learning_rate: float = 0.0,
mae: bool = False,
) -> tuple[dict[str, torch.Tensor], torch.Tensor, dict[str, torch.Tensor]]:
model_pred = model(**input_dict)

# per-atom outputs: [nf, nloc, task_dim]
atom_prop = model_pred[f"atom_{self.var_name}"]
atype = input_dict["atype"] # [nf, nloc]

sel_type = label["sel_type"][:, 0].long() # [nf]

nf, nloc, td = atom_prop.shape
mask_3d = atype.unsqueeze(-1) == sel_type.view(nf, 1, 1) # [nf, nloc, 1]
pred = (atom_prop * mask_3d).sum(dim=1) # [nf, task_dim]

label_xas = label[self.var_name] # [nf, task_dim]

# --- per-(type, edge) stat lookup from model buffers ---
fparam = input_dict.get("fparam")
if fparam is not None and fparam.numel() > 0:
edge_idx = fparam.reshape(nf, -1).argmax(dim=-1).clamp(0, self.nfparam - 1)
else:
edge_idx = torch.zeros(nf, dtype=torch.long, device=pred.device)

am = model.atomic_model
e_ref = am.xas_e_ref # [ntypes, nfparam, 2]
intensity_ref = am.xas_intensity_ref # [ntypes, nfparam, n_pts]
intensity_std = am.xas_intensity_std # [ntypes, nfparam, n_pts]

_dev = e_ref.device
_sel = sel_type.to(_dev)
_eidx = edge_idx.to(_dev)

e_ref_frame = e_ref[_sel, _eidx].to(pred.device) # [nf, 2]
intensity_ref_frame = intensity_ref[_sel, _eidx].to(pred.device) # [nf, n_pts]
intensity_std_frame = intensity_std[_sel, _eidx].to(pred.device) # [nf, n_pts]

# Normalised targets:
# energy dims → chemical shift: label - e_ref
# intensity dims → standardised: (label - ref) / std
label_energy_norm = label_xas[:, :2] - e_ref_frame
label_intens_norm = (label_xas[:, 2:] - intensity_ref_frame) / intensity_std_frame

def _elem_loss(p: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
if self.loss_func == "smooth_mae":
return F.smooth_l1_loss(p, t, reduction="sum", beta=self.beta)
elif self.loss_func == "mae":
return F.l1_loss(p, t, reduction="sum")
elif self.loss_func == "mse":
return F.mse_loss(p, t, reduction="sum")
elif self.loss_func == "rmse":
return torch.sqrt(F.mse_loss(p, t, reduction="mean"))
else:
raise RuntimeError(f"Unknown loss function: {self.loss_func}")

loss = torch.zeros(1, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE)[0]
loss += self.pref_energy * _elem_loss(pred[:, :2], label_energy_norm)
loss += self.pref_spectrum * _elem_loss(pred[:, 2:], label_intens_norm)

# Smoothness regulariser on standardised intensity dims (scale-invariant).
n_pts = self.task_dim - 2
if self.smooth_reg > 0.0 and n_pts >= 3:
pi = pred[:, 2:] # [nf, n_pts] in standardised space
curv = pi[:, 2:] - 2.0 * pi[:, 1:-1] + pi[:, :-2]
loss += self.smooth_reg * (curv**2).mean()

# --- metrics (reported on absolute scale) ---
pred_abs = pred.clone()
pred_abs[:, :2] = pred[:, :2] + e_ref_frame
pred_abs[:, 2:] = pred[:, 2:] * intensity_std_frame + intensity_ref_frame

more_loss: dict[str, torch.Tensor] = {}
if "mae" in self.metric:
more_loss["mae"] = F.l1_loss(
pred_abs, label_xas, reduction="mean"
).detach()
if "rmse" in self.metric:
more_loss["rmse"] = torch.sqrt(
F.mse_loss(pred_abs, label_xas, reduction="mean")
).detach()

model_pred[self.var_name] = pred_abs
return model_pred, loss, more_loss

@property
def label_requirement(self) -> list[DataRequirementItem]:
"""Declare required data files: xas label + sel_type."""
return [
DataRequirementItem(
self.var_name,
ndof=self.task_dim,
atomic=False,
must=True,
high_prec=True,
),
DataRequirementItem(
"sel_type",
ndof=1,
atomic=False,
must=True,
high_prec=False,
),
]
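A small self-contained sketch of the normalisation this loss trains against and the inverse map used for the absolute-scale metrics — the two form an exact round trip. Toy tensors only; shapes follow the docstring above:

# Round-trip check for the XASLoss normalisation (toy data, illustrative only).
import torch

nf, n_pts = 3, 4
label_energy = torch.randn(nf, 2)             # (E_min, E_max) per frame
label_intens = torch.randn(nf, n_pts)
e_ref_frame = torch.randn(nf, 2)              # gathered per-(type, edge) stats
intens_ref_frame = torch.randn(nf, n_pts)
intens_std_frame = torch.rand(nf, n_pts) + 0.1

# Forward normalisation (what the loss compares predictions against):
energy_norm = label_energy - e_ref_frame
intens_norm = (label_intens - intens_ref_frame) / intens_std_frame

# Inverse map (what the metrics and `dp test` report):
energy_abs = energy_norm + e_ref_frame
intens_abs = intens_norm * intens_std_frame + intens_ref_frame
assert torch.allclose(energy_abs, label_energy)
assert torch.allclose(intens_abs, label_intens)

# Second-order smoothness regulariser on standardised intensities:
curv = intens_norm[:, 2:] - 2.0 * intens_norm[:, 1:-1] + intens_norm[:, :-2]
smooth_penalty = (curv**2).mean()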
2 changes: 2 additions & 0 deletions deepmd/pt/model/atomic_model/__init__.py
@@ -41,6 +41,7 @@
)
from .property_atomic_model import (
DPPropertyAtomicModel,
DPXASAtomicModel,
)

__all__ = [
@@ -51,6 +52,7 @@
"DPEnergyAtomicModel",
"DPPolarAtomicModel",
"DPPropertyAtomicModel",
"DPXASAtomicModel",
"DPZBLLinearEnergyAtomicModel",
"LinearEnergyAtomicModel",
"PairTabAtomicModel",
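The XASLoss docstring above notes that the per-(type, edge) statistics live as buffers on DPXASAtomicModel. The PR's actual implementation is not shown in this diff; a hedged sketch of what such registration could look like, with buffer names and shapes taken from the loss code:

# Illustrative buffer registration only — not the PR's DPXASAtomicModel.
import torch

class XASStatBuffers(torch.nn.Module):  # hypothetical stand-in
    def __init__(self, ntypes: int, nfparam: int, n_pts: int) -> None:
        super().__init__()
        # Buffers travel with state_dict and .to(device) but are not trained.
        self.register_buffer("xas_e_ref", torch.zeros(ntypes, nfparam, 2))
        self.register_buffer("xas_intensity_ref", torch.zeros(ntypes, nfparam, n_pts))
        self.register_buffer("xas_intensity_std", torch.ones(ntypes, nfparam, n_pts))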