
Commit f39a081

iProzd and anyangml authored
feat(dpa3): decouple charge_spin from fparam (#5431)
`add_chg_spin_ebd=True` previously hijacked `fparam` to smuggle the [charge, spin] scalars into DPA3, forcing users to set `numb_fparam=2` on the fitting net and blocking real frame parameters from coexisting with charge/spin. This PR plumbs `charge_spin: Tensor | None` as a first-class kwarg through every forward chain and adds an optional `default_chg_spin` fallback on the DPA3 descriptor.

Backends covered: pt, dpmodel, pt_expt. The pd backend is left untouched. The C/C++/LAMMPS layer is unchanged.

## Forward chain

`Calculator / deep_eval / dp test / lmdb_data / training.get_data`
-> `wrapper.forward`
-> `ener_model.forward / forward_lower`
-> `make_model.forward_common / forward_common_lower`
-> `base_atomic_model.forward_common_atomic`
-> `dp_atomic_model.forward_atomic` # default_chg_spin fallback here
-> `descriptor.forward` # only DPA3 consumes it

All other descriptors (se_e2_a, se_r, se_t, se_t_tebd, dpa1, dpa2, hybrid) only forward the kwarg through their signatures.

## New API surface

On `BaseAtomicModel` and the wrapped model:

- `has_chg_spin_ebd() -> bool`
- `get_dim_chg_spin() -> int` # 2 for DPA3, else 0
- `has_default_chg_spin() -> bool`
- `get_default_chg_spin() -> list[float] | Tensor | None`

DPA3 descriptor gains a `default_chg_spin: list[float] | None = None` constructor arg (length 2, validated; round-trips through `serialize`). `descrpt_dpa3_args` exposes the matching `Argument` and the `add_chg_spin_ebd` doc no longer references fparam.

## Training data

`charge_spin` is registered as a `DataRequirementItem(ndof=2, atomic=False, must=not has_default_cs, default=cs_default)`. The `get_data` path drops it (along with fparam) on frames where `find_charge_spin == 0`, so missing per-frame data falls back to `default_chg_spin` when one is configured.

## pt_expt specifics

`forward_common_atomic`, `forward_common_lower_exportable`, the `make_fx`-traced inner `fn`, `_trace_and_compile`, and all wrapping energy/spin/dipole/dos/polar/property/dp_linear/dp_zbl model variants gained a `charge_spin` arg in lockstep so the export and inductor-compiled paths keep matching signatures. `deep_eval` no longer reuses `fparam` for charge/spin; it constructs `charge_spin_t` (with the metadata default-fallback) and passes it explicitly.

## Tests

Three `cs_mode` cases are exercised everywhere it matters: `no_chg_spin`, `explicit_chg_spin`, `default_chg_spin`.

- pt UT (`source/tests/pt/model/test_dpa3.py::test_consistency`) rewritten over the three modes; default mode also asserts that the default-fallback descriptor matches an explicit `[5,1]` peer.
- pt_expt UT (`source/tests/pt_expt/descriptor/test_dpa3.py`) gains `test_consistency_chg_spin` covering explicit and default modes against dpmodel.
- Universal tests: `DescriptorParamDPA3` learns `default_chg_spin`, parametrize gains `(None, [5.0, 1.0])`, and the `add_chg_spin_ebd` skip rule in `test_model.py` is replaced: the universal driver does not feed `charge_spin`, so chg_spin runs rely on the `default_chg_spin` fallback. 622 DPA3 model cases pass.
- Consistent tests: `descriptor/common.py` threads `charge_spin` through every `eval_*` (pd ignores it). `test_dpa3.py` swaps `self.fparam` for `self.charge_spin`. `test_ener.py::TestEnerChgSpinEbdFparam` is reparametrized over the three modes and no longer touches `numb_fparam` / `default_fparam`.

## Smoke

`examples/water/dpa3 dp --pt train input_torch_dynamic.json --skip-neighbor-stat` runs to batch 600 with monotonically decreasing loss.

## Test plan

- [x] pytest source/tests/pt/model/test_dpa3.py -v
- [x] pytest source/tests/pt_expt/descriptor/test_dpa3.py -v
- [x] pytest source/tests/consistent/descriptor/test_dpa3.py -v
- [x] pytest source/tests/consistent/model/test_ener.py::TestEnerChgSpinEbdFparam -v
- [x] pytest source/tests/universal/dpmodel/model/test_model.py -k "DPA3 and 5"
- [x] examples/water/dpa3 smoke training (600 batches, loss decreasing)

<!-- This is an auto-generated comment: release notes by coderabbit.ai -->

## Summary by CodeRabbit

* **New Features**
  * Optional per-frame charge-spin input supported end-to-end: data readers, batching, inference, training, export/tracing; prediction/training calls accept and forward it.
  * Models/descriptors expose capability-query and default-value helpers for charge-spin embeddings; exportable/traced APIs honor defaults.
* **Tests**
  * Tests updated/expanded to validate charge-spin embedding behavior and cross-backend consistency.
* **Chores**
  * Configuration normalization warns on legacy charge/spin packed into legacy parameters and documents migration.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: Duo <50307526+iProzd@users.noreply.github.com>
Co-authored-by: Anyang Peng <137014849+anyangml@users.noreply.github.com>
1 parent 9245a7b commit f39a081

101 files changed

Lines changed: 1642 additions & 249 deletions


deepmd/calculator.py

Lines changed: 7 additions & 1 deletion
@@ -137,8 +137,14 @@ def calculate(

         fparam = self.atoms.info.get("fparam", None)
         aparam = self.atoms.info.get("aparam", None)
+        charge_spin = self.atoms.info.get("charge_spin", None)
         e, f, v = self.dp.eval(
-            coords=coord, cells=cell, atom_types=atype, fparam=fparam, aparam=aparam
+            coords=coord,
+            cells=cell,
+            atom_types=atype,
+            fparam=fparam,
+            aparam=aparam,
+            charge_spin=charge_spin,
         )[:3]
         self.results["energy"] = e[0][0]
         # see https://gitlab.com/ase/ase/-/merge_requests/2485
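The hunk above reads `charge_spin` from `atoms.info` in the same way as `fparam` and `aparam`, so a missing key simply becomes `None`. A minimal sketch of that lookup pattern follows. It assumes a plain dict in place of `atoms.info`, and `fake_eval` and `collect_optional_inputs` are hypothetical helpers for illustration, not part of ASE or DeePMD-kit.

```python
from __future__ import annotations

from typing import Any


def fake_eval(**kwargs: Any) -> dict[str, Any]:
    """Hypothetical stand-in for ``self.dp.eval``; it only echoes its kwargs."""
    return kwargs


def collect_optional_inputs(info: dict[str, Any]) -> dict[str, Any]:
    """Mirror the lookup in ``calculate``: absent keys become ``None``."""
    return {
        "fparam": info.get("fparam", None),
        "aparam": info.get("aparam", None),
        "charge_spin": info.get("charge_spin", None),
    }


# Illustrative [charge, spin] pair; the values carry no physical meaning here.
info_with_cs = {"charge_spin": [-1.0, 2.0]}
info_without_cs: dict[str, Any] = {}

with_cs = fake_eval(**collect_optional_inputs(info_with_cs))
without_cs = fake_eval(**collect_optional_inputs(info_without_cs))
```

When the key is absent, `None` is passed on; per the commit message, a configured `default_chg_spin` is then applied further down the forward chain.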

deepmd/dpmodel/atomic_model/base_atomic_model.py

Lines changed: 24 additions & 0 deletions
@@ -156,6 +156,22 @@ def get_default_fparam(self) -> list[float] | None:
         """Get the default frame parameters."""
         return None

+    def has_chg_spin_ebd(self) -> bool:
+        """Check if the model has charge spin embedding."""
+        return False
+
+    def get_dim_chg_spin(self) -> int:
+        """Get the dimension of charge_spin input."""
+        return 0
+
+    def has_default_chg_spin(self) -> bool:
+        """Check if the model has default charge_spin values."""
+        return False
+
+    def get_default_chg_spin(self) -> list[float] | None:
+        """Get the default charge_spin values."""
+        return None
+
     def reinit_atom_exclude(
         self,
         exclude_types: list[int] = [],
@@ -232,6 +248,7 @@ def forward_common_atomic(
         fparam: Array | None = None,
         aparam: Array | None = None,
         comm_dict: dict | None = None,
+        charge_spin: Array | None = None,
     ) -> dict[str, Array]:
         """Common interface for atomic inference.

@@ -284,6 +301,7 @@ def forward_common_atomic(
             fparam=fparam,
             aparam=aparam,
             comm_dict=comm_dict,
+            charge_spin=charge_spin,
         )
         ret_dict = self.apply_out_stat(ret_dict, atype)

@@ -312,6 +330,7 @@ def call(
         mapping: Array | None = None,
         fparam: Array | None = None,
         aparam: Array | None = None,
+        charge_spin: Array | None = None,
     ) -> dict[str, Array]:
         return self.forward_common_atomic(
             extended_coord,
@@ -320,6 +339,7 @@ def call(
             mapping=mapping,
             fparam=fparam,
             aparam=aparam,
+            charge_spin=charge_spin,
         )

     def get_intensive(self) -> bool:
@@ -524,6 +544,7 @@ def model_forward(
             box: np.ndarray | None,
             fparam: np.ndarray | None = None,
             aparam: np.ndarray | None = None,
+            charge_spin: np.ndarray | None = None,
         ) -> dict[str, np.ndarray]:
             # Get reference array to determine the target array type and device
             # Use out_bias as reference since it's always present
@@ -543,6 +564,8 @@ def model_forward(
                 fparam = xp.asarray(fparam, device=device)
             if aparam is not None:
                 aparam = xp.asarray(aparam, device=device)
+            if charge_spin is not None:
+                charge_spin = xp.asarray(charge_spin, device=device)

             (
                 extended_coord,
@@ -564,6 +587,7 @@ def model_forward(
                 mapping=mapping,
                 fparam=fparam,
                 aparam=aparam,
+                charge_spin=charge_spin,
             )
             # Convert outputs back to numpy arrays
             return {kk: to_numpy_array(vv) for kk, vv in atomic_ret.items()}
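The four query methods added above let a caller decide whether `charge_spin` must be present in the data. The sketch below shows one way such a decision might look, echoing the `must=not has_default_cs` rule from the commit message. `StubModel` and `charge_spin_requirement` are hypothetical names, and a plain dict stands in for `DataRequirementItem`; none of this is the project's actual training code.

```python
from __future__ import annotations


class StubModel:
    """Hypothetical stand-in exposing the four query methods added above."""

    def __init__(
        self, add_chg_spin_ebd: bool, default_chg_spin: list[float] | None
    ) -> None:
        self._enabled = add_chg_spin_ebd
        self._default = default_chg_spin

    def has_chg_spin_ebd(self) -> bool:
        return self._enabled

    def get_dim_chg_spin(self) -> int:
        return 2 if self._enabled else 0

    def has_default_chg_spin(self) -> bool:
        return self._enabled and self._default is not None

    def get_default_chg_spin(self) -> list[float] | None:
        return self._default if self.has_default_chg_spin() else None


def charge_spin_requirement(model: StubModel) -> dict | None:
    """Describe the ``charge_spin`` data requirement, or None if unused."""
    if not model.has_chg_spin_ebd():
        return None
    has_default = model.has_default_chg_spin()
    return {
        "key": "charge_spin",
        "ndof": model.get_dim_chg_spin(),
        "atomic": False,
        # Data is mandatory only when no default can fill the gap.
        "must": not has_default,
        "default": model.get_default_chg_spin(),
    }
```

With this shape, a model without the embedding yields no requirement at all, and a model with a default marks the data as optional.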

deepmd/dpmodel/atomic_model/dp_atomic_model.py

Lines changed: 41 additions & 25 deletions
@@ -79,6 +79,28 @@ def __init__(
         )
         super().init_out_stat()

+    def has_chg_spin_ebd(self) -> bool:
+        """Check if the model has charge spin embedding."""
+        return self.add_chg_spin_ebd
+
+    def get_dim_chg_spin(self) -> int:
+        """Get the dimension of charge_spin input."""
+        if self.add_chg_spin_ebd:
+            return self.descriptor.get_dim_chg_spin()
+        return 0
+
+    def has_default_chg_spin(self) -> bool:
+        """Check if the model has default charge_spin values."""
+        if self.add_chg_spin_ebd:
+            return self.descriptor.has_default_chg_spin()
+        return False
+
+    def get_default_chg_spin(self) -> list[float] | None:
+        """Get the default charge_spin values."""
+        if self.add_chg_spin_ebd and self.descriptor.has_default_chg_spin():
+            return self.descriptor.get_default_chg_spin()
+        return None
+
     def fitting_output_def(self) -> FittingOutputDef:
         """Get the output def of the fitting net."""
         return self.fitting_net.output_def()
@@ -158,6 +180,7 @@ def forward_atomic(
         fparam: Array | None = None,
         aparam: Array | None = None,
         comm_dict: dict | None = None,
+        charge_spin: Array | None = None,
     ) -> dict[str, Array]:
         """Models' atomic predictions.

@@ -178,6 +201,8 @@ def forward_atomic(
         comm_dict
             MPI communication metadata for parallel inference. ``None`` for
             non-parallel inference (default). Forwarded to the descriptor.
+        charge_spin
+            charge and spin parameter for descriptor. nf x 2

         Returns
         -------
@@ -188,38 +213,29 @@ def forward_atomic(
         nframes, nloc, nnei = nlist.shape
         atype = xp_take_first_n(extended_atype, 1, nloc)

-        # Handle default fparam if fitting net supports it
-        if (
-            hasattr(self.fitting_net, "get_dim_fparam")
-            and self.fitting_net.get_dim_fparam() > 0
-            and fparam is None
-        ):
-            # use default fparam
-            from deepmd.dpmodel.array_api import (
-                array_api_compat,
-            )
-
-            default_fparam = self.fitting_net.get_default_fparam()
-            assert default_fparam is not None
-            xp = array_api_compat.array_namespace(extended_coord)
-            default_fparam_array = xp.asarray(
-                default_fparam,
-                dtype=extended_coord.dtype,
-                device=array_api_compat.device(extended_coord),
-            )
-            fparam_input_for_des = xp.tile(
-                xp.reshape(default_fparam_array, (1, -1)), (nframes, 1)
-            )
-        else:
-            fparam_input_for_des = fparam
+        # Handle default charge_spin if descriptor supports it
+        if self.add_chg_spin_ebd and charge_spin is None:
+            default_cs = self.descriptor.get_default_chg_spin()
+            if default_cs is not None:
+                from deepmd.dpmodel.array_api import (
+                    array_api_compat,
+                )
+
+                xp = array_api_compat.array_namespace(extended_coord)
+                cs_array = xp.asarray(
+                    default_cs,
+                    dtype=extended_coord.dtype,
+                    device=array_api_compat.device(extended_coord),
+                )
+                charge_spin = xp.tile(xp.reshape(cs_array, (1, -1)), (nframes, 1))

         descriptor, rot_mat, g2, h2, sw = self.descriptor(
             extended_coord,
             extended_atype,
             nlist,
             mapping=mapping,
-            fparam=fparam_input_for_des if self.add_chg_spin_ebd else None,
             comm_dict=comm_dict,
+            charge_spin=charge_spin if self.add_chg_spin_ebd else None,
         )
         ret = self.fitting_net(
             descriptor,
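The new fallback in `forward_atomic` broadcasts a length-2 default to one row per frame when no explicit `charge_spin` is given. A minimal NumPy sketch of that step follows. It assumes NumPy in place of the array-API namespace used in the diff, and `resolve_charge_spin` is a hypothetical helper name introduced for illustration.

```python
from __future__ import annotations

import numpy as np


def resolve_charge_spin(
    charge_spin: np.ndarray | None,
    default_cs: list[float] | None,
    nframes: int,
    dtype: type = np.float64,
) -> np.ndarray | None:
    """Return the explicit input if given, else the default tiled per frame."""
    if charge_spin is not None:
        # An explicit per-frame input always wins over the default.
        return charge_spin
    if default_cs is None:
        # No input and no default: the caller receives None.
        return None
    cs_array = np.asarray(default_cs, dtype=dtype)
    # Shape (2,) -> (1, 2) -> (nframes, 2), matching the tile/reshape in the diff.
    return np.tile(np.reshape(cs_array, (1, -1)), (nframes, 1))


explicit = np.array([[0.0, 1.0], [2.0, 3.0]])
from_default = resolve_charge_spin(None, [5.0, 1.0], nframes=3)
```

Note that in the diff, a DPA3 descriptor with `add_chg_spin_ebd=True` still asserts that `charge_spin` is not `None`, so the "no input and no default" branch ends in an assertion failure there.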

deepmd/dpmodel/atomic_model/linear_atomic_model.py

Lines changed: 2 additions & 0 deletions
@@ -225,6 +225,7 @@ def forward_atomic(
         fparam: Array | None = None,
         aparam: Array | None = None,
         comm_dict: dict | None = None,
+        charge_spin: Array | None = None,
     ) -> dict[str, Array]:
         """Return atomic prediction.

@@ -286,6 +287,7 @@ def forward_atomic(
                     fparam,
                     aparam,
                     comm_dict,
+                    charge_spin=charge_spin,
                 )["energy"]
             )
         weights = self._compute_weight(extended_coord, extended_atype, nlists_)

deepmd/dpmodel/atomic_model/make_base_atomic_model.py

Lines changed: 1 addition & 0 deletions
@@ -138,6 +138,7 @@ def fwd(
             mapping: t_tensor | None = None,
             fparam: t_tensor | None = None,
             aparam: t_tensor | None = None,
+            charge_spin: t_tensor | None = None,
         ) -> dict[str, t_tensor]:
             pass
143144

deepmd/dpmodel/atomic_model/pairtab_atomic_model.py

Lines changed: 1 addition & 0 deletions
@@ -254,6 +254,7 @@ def forward_atomic(
         fparam: Array | None = None,
         aparam: Array | None = None,
         comm_dict: dict | None = None,
+        charge_spin: Array | None = None,
     ) -> dict[str, Array]:
         del comm_dict  # pairtab is local; no MPI ghost exchange needed.
         xp = array_api_compat.array_namespace(extended_coord, extended_atype, nlist)

deepmd/dpmodel/descriptor/dpa1.py

Lines changed: 1 addition & 0 deletions
@@ -509,6 +509,7 @@ def call(
         mapping: Array | None = None,
         fparam: Array | None = None,
         comm_dict: dict | None = None,
+        charge_spin: Array | None = None,
     ) -> Array:
         """Compute the descriptor.
514515

deepmd/dpmodel/descriptor/dpa2.py

Lines changed: 1 addition & 0 deletions
@@ -842,6 +842,7 @@ def call(
         mapping: Array | None = None,
         fparam: Array | None = None,
         comm_dict: dict | None = None,
+        charge_spin: Array | None = None,
     ) -> tuple[Array, Array, Array, Array, Array]:
         """Compute the descriptor.
847848

deepmd/dpmodel/descriptor/dpa3.py

Lines changed: 23 additions & 3 deletions
@@ -377,6 +377,7 @@ def __init__(
         use_loc_mapping: bool = True,
         type_map: list[str] | None = None,
         add_chg_spin_ebd: bool = False,
+        default_chg_spin: list[float] | None = None,
     ) -> None:
         super().__init__()

@@ -433,6 +434,11 @@ def init_subclass_params(sub_data: dict | Any, sub_class: type) -> Any:

         self.use_econf_tebd = use_econf_tebd
         self.add_chg_spin_ebd = add_chg_spin_ebd
+        if default_chg_spin is not None and len(default_chg_spin) != 2:
+            raise ValueError(
+                "default_chg_spin must have exactly 2 values [charge, spin]"
+            )
+        self.default_chg_spin = default_chg_spin
         self.use_tebd_bias = use_tebd_bias
         self.use_loc_mapping = use_loc_mapping
         self.type_map = type_map
@@ -499,6 +505,18 @@ def get_rcut(self) -> float:
         """Returns the cut-off radius."""
         return self.rcut

+    def get_dim_chg_spin(self) -> int:
+        """Returns the dimension of charge_spin input."""
+        return 2 if self.add_chg_spin_ebd else 0
+
+    def has_default_chg_spin(self) -> bool:
+        """Returns whether default charge_spin values are set."""
+        return self.default_chg_spin is not None
+
+    def get_default_chg_spin(self) -> list[float] | None:
+        """Returns the default charge_spin values."""
+        return self.default_chg_spin
+
     def get_rcut_smth(self) -> float:
         """Returns the radius where the neighbor information starts to smoothly decay to 0."""
         return self.rcut_smth
@@ -647,6 +665,7 @@ def call(
         mapping: Array | None = None,
         fparam: Array | None = None,
         comm_dict: dict | None = None,
+        charge_spin: Array | None = None,
     ) -> tuple[Array, Array, Array, Array, Array]:
         """Compute the descriptor.

@@ -702,13 +721,13 @@ def call(
         )

         if self.add_chg_spin_ebd:
-            assert fparam is not None
+            assert charge_spin is not None
             assert self.chg_embedding is not None
             assert self.spin_embedding is not None
             chg_tebd = self.chg_embedding.call()
             spin_tebd = self.spin_embedding.call()
-            charge = xp.astype(fparam[:, 0], xp.int64) + 100
-            spin = xp.astype(fparam[:, 1], xp.int64)
+            charge = xp.astype(charge_spin[:, 0], xp.int64) + 100
+            spin = xp.astype(charge_spin[:, 1], xp.int64)
             chg_ebd = xp.reshape(
                 xp.take(chg_tebd, xp.reshape(charge, (-1,)), axis=0),
                 (nframes, self.tebd_dim),
@@ -753,6 +772,7 @@ def serialize(self) -> dict:
             "use_tebd_bias": self.use_tebd_bias,
             "use_loc_mapping": self.use_loc_mapping,
             "add_chg_spin_ebd": self.add_chg_spin_ebd,
+            "default_chg_spin": self.default_chg_spin,
             "type_map": self.type_map,
             "type_embedding": self.type_embedding.serialize(),
         }
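The `call` hunk above turns each frame's `[charge, spin]` pair into two row indices: the charge is cast to an integer and shifted by 100 so that negative charges map to non-negative rows, while the spin is used as an index directly. The NumPy sketch below illustrates only that lookup. The table sizes (201 and 101 rows), `tebd_dim = 4`, and the helper name `lookup_chg_spin` are assumptions made for illustration, and random tables stand in for the learned embeddings. How DPA3 combines the two embeddings afterwards is not shown in this excerpt and is not reproduced here.

```python
from __future__ import annotations

import numpy as np

rng = np.random.default_rng(0)
tebd_dim = 4  # illustrative embedding width (assumption)
chg_table = rng.normal(size=(201, tebd_dim))  # assumed row count
spin_table = rng.normal(size=(101, tebd_dim))  # assumed row count


def lookup_chg_spin(charge_spin: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    """Gather one charge row and one spin row per frame."""
    nframes = charge_spin.shape[0]
    # Shift by 100 so that, for example, a charge of -2 selects row 98.
    charge = charge_spin[:, 0].astype(np.int64) + 100
    spin = charge_spin[:, 1].astype(np.int64)
    chg_ebd = np.reshape(
        np.take(chg_table, np.reshape(charge, (-1,)), axis=0),
        (nframes, tebd_dim),
    )
    spin_ebd = np.reshape(
        np.take(spin_table, np.reshape(spin, (-1,)), axis=0),
        (nframes, tebd_dim),
    )
    return chg_ebd, spin_ebd


frames = np.array([[5.0, 1.0], [-2.0, 0.0]])
chg_ebd, spin_ebd = lookup_chg_spin(frames)
```

Because the values are cast to integers before indexing, non-integer charge or spin inputs would be truncated rather than interpolated in this sketch.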

deepmd/dpmodel/descriptor/hybrid.py

Lines changed: 42 additions & 1 deletion
@@ -123,6 +123,40 @@ def get_rcut(self) -> float:
         """Returns the cut-off radius."""
         return np.max([descrpt.get_rcut() for descrpt in self.descrpt_list]).item()

+    def get_dim_chg_spin(self) -> int:
+        """Returns the dimension of charge_spin input (0 if not supported)."""
+        return max(
+            (descrpt.get_dim_chg_spin() for descrpt in self.descrpt_list), default=0
+        )
+
+    def has_default_chg_spin(self) -> bool:
+        """Returns whether the descriptor has a default charge_spin value."""
+        default_chg_spin = None
+        found_chg_spin = False
+        for descrpt in self.descrpt_list:
+            if descrpt.get_dim_chg_spin() == 0:
+                continue
+            found_chg_spin = True
+            if not descrpt.has_default_chg_spin():
+                return False
+            child_default_chg_spin = descrpt.get_default_chg_spin()
+            if child_default_chg_spin is None:
+                return False
+            if default_chg_spin is None:
+                default_chg_spin = child_default_chg_spin
+            elif child_default_chg_spin != default_chg_spin:
+                return False
+        return found_chg_spin
+
+    def get_default_chg_spin(self) -> list[float] | None:
+        """Returns the default charge_spin value, or None."""
+        if not self.has_default_chg_spin():
+            return None
+        for descrpt in self.descrpt_list:
+            if descrpt.get_dim_chg_spin() > 0:
+                return descrpt.get_default_chg_spin()
+        return None
+
     def get_rcut_smth(self) -> float:
         """Returns the radius where the neighbor information starts to smoothly decay to 0."""
         # may not be a good idea...
@@ -287,6 +321,7 @@ def call(
         mapping: Array | None = None,
         fparam: Array | None = None,
         comm_dict: dict | None = None,
+        charge_spin: Array | None = None,
     ) -> tuple[
         Array,
         Array | None,
@@ -344,7 +379,13 @@ def call(
                 assert nl_distinguish_types is not None
                 nl = nl_distinguish_types[:, :, nci]
             odescriptor, gr, _g2, _h2, _sw = descrpt(
-                coord_ext, atype_ext, nl, mapping, comm_dict=comm_dict
+                coord_ext,
+                atype_ext,
+                nl,
+                mapping,
+                fparam=fparam,
+                comm_dict=comm_dict,
+                charge_spin=charge_spin,
             )
             out_descriptor.append(odescriptor)
             if gr is not None:
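The hybrid descriptor reports a default only when every child that consumes `charge_spin` carries one and all of those defaults agree. The sketch below condenses that rule into a single function for illustration. `StubDescriptor` and `hybrid_default_chg_spin` are hypothetical names, and the function merges the `has_default_chg_spin` / `get_default_chg_spin` pair from the diff rather than reproducing the code line by line.

```python
from __future__ import annotations


class StubDescriptor:
    """Hypothetical child descriptor exposing the three query methods."""

    def __init__(self, dim: int, default: list[float] | None = None) -> None:
        self.dim = dim
        self.default = default

    def get_dim_chg_spin(self) -> int:
        return self.dim

    def has_default_chg_spin(self) -> bool:
        return self.default is not None

    def get_default_chg_spin(self) -> list[float] | None:
        return self.default


def hybrid_default_chg_spin(children: list[StubDescriptor]) -> list[float] | None:
    """Return the shared default, or None if absent, partial, or conflicting."""
    # Children with dimension 0 do not consume charge_spin and are ignored.
    consumers = [c for c in children if c.get_dim_chg_spin() > 0]
    if not consumers:
        return None
    defaults = [c.get_default_chg_spin() for c in consumers]
    if any(d is None for d in defaults):
        return None  # at least one consumer has no default
    if any(d != defaults[0] for d in defaults[1:]):
        return None  # consumers disagree on the default
    return defaults[0]


plain = StubDescriptor(0)
a = StubDescriptor(2, [5.0, 1.0])
b = StubDescriptor(2, [5.0, 1.0])
conflicting = StubDescriptor(2, [0.0, 1.0])
missing = StubDescriptor(2, None)
```

Refusing to pick a default when children disagree avoids silently feeding one child a value configured for another.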
