Commit f39a081
feat(dpa3): decouple charge_spin from fparam (#5431)
`add_chg_spin_ebd=True` previously hijacked `fparam` to smuggle the
[charge, spin] scalars into DPA3, forcing users to set `numb_fparam=2`
on the fitting net and blocking real frame parameters from coexisting
with charge/spin. This PR plumbs `charge_spin: Tensor | None` as a
first-class kwarg through every forward chain and adds an optional
`default_chg_spin` fallback on the DPA3 descriptor.
Backends covered: pt, dpmodel, pt_expt. The pd backend is left
untouched.
The C/C++/LAMMPS layer is unchanged.
## Forward chain
`Calculator / deep_eval / dp test / lmdb_data / training.get_data`
-> `wrapper.forward`
-> `ener_model.forward / forward_lower`
-> `make_model.forward_common / forward_common_lower`
-> `base_atomic_model.forward_common_atomic`
-> `dp_atomic_model.forward_atomic` # default_chg_spin fallback here
-> `descriptor.forward` # only DPA3 consumes it
All other descriptors (se_e2_a, se_r, se_t, se_t_tebd, dpa1, dpa2,
hybrid) only forward the kwarg through their signatures.
## New API surface
On `BaseAtomicModel` and the wrapped model:
- `has_chg_spin_ebd() -> bool`
- `get_dim_chg_spin() -> int` # 2 for DPA3, else 0
- `has_default_chg_spin() -> bool`
- `get_default_chg_spin() -> list[float] | Tensor | None`
DPA3 descriptor gains a `default_chg_spin: list[float] | None = None`
constructor arg (length 2, validated; round-trips through `serialize`).
`descrpt_dpa3_args` exposes the matching `Argument` and the
`add_chg_spin_ebd` doc no longer references fparam.
## Training data
`charge_spin` is registered as a `DataRequirementItem(ndof=2,
atomic=False, must=not has_default_cs, default=cs_default)`. The
`get_data` path drops it (along with fparam) on frames where
`find_charge_spin == 0`, so missing per-frame data falls back to
`default_chg_spin` when one is configured.
## pt_expt specifics
`forward_common_atomic`, `forward_common_lower_exportable`, the
`make_fx`-traced inner `fn`, `_trace_and_compile`, and all wrapping
energy/spin/dipole/dos/polar/property/dp_linear/dp_zbl model variants
gained a `charge_spin` arg in lockstep so the export and inductor-
compiled paths keep matching signatures. `deep_eval` no longer reuses
`fparam` for charge/spin — it constructs `charge_spin_t` (with the
metadata default-fallback) and passes it explicitly.
## Tests
Three `cs_mode` cases are exercised everywhere it matters:
`no_chg_spin`, `explicit_chg_spin`, `default_chg_spin`.
- pt UT (`source/tests/pt/model/test_dpa3.py::test_consistency`)
rewritten over the three modes; default mode also asserts that the
default-fallback descriptor matches an explicit `[5,1]` peer.
- pt_expt UT (`source/tests/pt_expt/descriptor/test_dpa3.py`) gains
`test_consistency_chg_spin` covering explicit and default modes
against dpmodel.
- Universal tests: `DescriptorParamDPA3` learns `default_chg_spin`,
parametrize gains `(None, [5.0, 1.0])`, and the
`add_chg_spin_ebd` skip rule in `test_model.py` is replaced —
the universal driver does not feed `charge_spin`, so chg_spin runs
rely on the `default_chg_spin` fallback. 622 DPA3 model cases pass.
- Consistent tests: `descriptor/common.py` threads `charge_spin`
through every `eval_*` (pd ignores it). `test_dpa3.py` swaps
`self.fparam` for `self.charge_spin`. `test_ener.py::
TestEnerChgSpinEbdFparam` is reparametrized over the three modes
and no longer touches `numb_fparam` / `default_fparam`.
## Smoke
`examples/water/dpa3 dp --pt train input_torch_dynamic.json
--skip-neighbor-stat` runs to batch 600 with monotonically decreasing
loss.
## Test plan
- [x] pytest source/tests/pt/model/test_dpa3.py -v
- [x] pytest source/tests/pt_expt/descriptor/test_dpa3.py -v
- [x] pytest source/tests/consistent/descriptor/test_dpa3.py -v
- [x] pytest
source/tests/consistent/model/test_ener.py::TestEnerChgSpinEbdFparam -v
- [x] pytest source/tests/universal/dpmodel/model/test_model.py -k "DPA3
and 5"
- [x] examples/water/dpa3 smoke training (600 batches, loss decreasing)
<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit
* **New Features**
* Optional per-frame charge-spin input supported end-to-end: data
readers, batching, inference, training, export/tracing;
prediction/training calls accept and forward it.
* Models/descriptors expose capability-query and default-value helpers
for charge-spin embeddings; exportable/traced APIs honor defaults.
* **Tests**
* Tests updated/expanded to validate charge-spin embedding behavior and
cross-backend consistency.
* **Chores**
* Configuration normalization warns on legacy charge/spin packed into
legacy parameters and documents migration.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
---------
Signed-off-by: Duo <50307526+iProzd@users.noreply.github.com>
Co-authored-by: Anyang Peng <137014849+anyangml@users.noreply.github.com>1 parent 9245a7b commit f39a081
101 files changed
Lines changed: 1642 additions & 249 deletions
File tree
- deepmd
- dpmodel
- atomic_model
- descriptor
- model
- utils
- entrypoints
- infer
- jax
- atomic_model
- jax2tf
- model
- pt_expt
- descriptor
- infer
- model
- train
- utils
- pt
- infer
- model
- atomic_model
- descriptor
- model
- modifier
- train
- utils
- utils
- source/tests
- consistent
- descriptor
- model
- pd/model
- pt_expt
- descriptor
- infer
- model
- pt
- model
- universal/dpmodel
- descriptor
- model
Some content is hidden
Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
137 | 137 | | |
138 | 138 | | |
139 | 139 | | |
| 140 | + | |
140 | 141 | | |
141 | | - | |
| 142 | + | |
| 143 | + | |
| 144 | + | |
| 145 | + | |
| 146 | + | |
| 147 | + | |
142 | 148 | | |
143 | 149 | | |
144 | 150 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
156 | 156 | | |
157 | 157 | | |
158 | 158 | | |
| 159 | + | |
| 160 | + | |
| 161 | + | |
| 162 | + | |
| 163 | + | |
| 164 | + | |
| 165 | + | |
| 166 | + | |
| 167 | + | |
| 168 | + | |
| 169 | + | |
| 170 | + | |
| 171 | + | |
| 172 | + | |
| 173 | + | |
| 174 | + | |
159 | 175 | | |
160 | 176 | | |
161 | 177 | | |
| |||
232 | 248 | | |
233 | 249 | | |
234 | 250 | | |
| 251 | + | |
235 | 252 | | |
236 | 253 | | |
237 | 254 | | |
| |||
284 | 301 | | |
285 | 302 | | |
286 | 303 | | |
| 304 | + | |
287 | 305 | | |
288 | 306 | | |
289 | 307 | | |
| |||
312 | 330 | | |
313 | 331 | | |
314 | 332 | | |
| 333 | + | |
315 | 334 | | |
316 | 335 | | |
317 | 336 | | |
| |||
320 | 339 | | |
321 | 340 | | |
322 | 341 | | |
| 342 | + | |
323 | 343 | | |
324 | 344 | | |
325 | 345 | | |
| |||
524 | 544 | | |
525 | 545 | | |
526 | 546 | | |
| 547 | + | |
527 | 548 | | |
528 | 549 | | |
529 | 550 | | |
| |||
543 | 564 | | |
544 | 565 | | |
545 | 566 | | |
| 567 | + | |
| 568 | + | |
546 | 569 | | |
547 | 570 | | |
548 | 571 | | |
| |||
564 | 587 | | |
565 | 588 | | |
566 | 589 | | |
| 590 | + | |
567 | 591 | | |
568 | 592 | | |
569 | 593 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
79 | 79 | | |
80 | 80 | | |
81 | 81 | | |
| 82 | + | |
| 83 | + | |
| 84 | + | |
| 85 | + | |
| 86 | + | |
| 87 | + | |
| 88 | + | |
| 89 | + | |
| 90 | + | |
| 91 | + | |
| 92 | + | |
| 93 | + | |
| 94 | + | |
| 95 | + | |
| 96 | + | |
| 97 | + | |
| 98 | + | |
| 99 | + | |
| 100 | + | |
| 101 | + | |
| 102 | + | |
| 103 | + | |
82 | 104 | | |
83 | 105 | | |
84 | 106 | | |
| |||
158 | 180 | | |
159 | 181 | | |
160 | 182 | | |
| 183 | + | |
161 | 184 | | |
162 | 185 | | |
163 | 186 | | |
| |||
178 | 201 | | |
179 | 202 | | |
180 | 203 | | |
| 204 | + | |
| 205 | + | |
181 | 206 | | |
182 | 207 | | |
183 | 208 | | |
| |||
188 | 213 | | |
189 | 214 | | |
190 | 215 | | |
191 | | - | |
192 | | - | |
193 | | - | |
194 | | - | |
195 | | - | |
196 | | - | |
197 | | - | |
198 | | - | |
199 | | - | |
200 | | - | |
201 | | - | |
202 | | - | |
203 | | - | |
204 | | - | |
205 | | - | |
206 | | - | |
207 | | - | |
208 | | - | |
209 | | - | |
210 | | - | |
211 | | - | |
212 | | - | |
213 | | - | |
214 | | - | |
| 216 | + | |
| 217 | + | |
| 218 | + | |
| 219 | + | |
| 220 | + | |
| 221 | + | |
| 222 | + | |
| 223 | + | |
| 224 | + | |
| 225 | + | |
| 226 | + | |
| 227 | + | |
| 228 | + | |
| 229 | + | |
| 230 | + | |
215 | 231 | | |
216 | 232 | | |
217 | 233 | | |
218 | 234 | | |
219 | 235 | | |
220 | 236 | | |
221 | | - | |
222 | 237 | | |
| 238 | + | |
223 | 239 | | |
224 | 240 | | |
225 | 241 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
225 | 225 | | |
226 | 226 | | |
227 | 227 | | |
| 228 | + | |
228 | 229 | | |
229 | 230 | | |
230 | 231 | | |
| |||
286 | 287 | | |
287 | 288 | | |
288 | 289 | | |
| 290 | + | |
289 | 291 | | |
290 | 292 | | |
291 | 293 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
138 | 138 | | |
139 | 139 | | |
140 | 140 | | |
| 141 | + | |
141 | 142 | | |
142 | 143 | | |
143 | 144 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
254 | 254 | | |
255 | 255 | | |
256 | 256 | | |
| 257 | + | |
257 | 258 | | |
258 | 259 | | |
259 | 260 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
509 | 509 | | |
510 | 510 | | |
511 | 511 | | |
| 512 | + | |
512 | 513 | | |
513 | 514 | | |
514 | 515 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
842 | 842 | | |
843 | 843 | | |
844 | 844 | | |
| 845 | + | |
845 | 846 | | |
846 | 847 | | |
847 | 848 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
377 | 377 | | |
378 | 378 | | |
379 | 379 | | |
| 380 | + | |
380 | 381 | | |
381 | 382 | | |
382 | 383 | | |
| |||
433 | 434 | | |
434 | 435 | | |
435 | 436 | | |
| 437 | + | |
| 438 | + | |
| 439 | + | |
| 440 | + | |
| 441 | + | |
436 | 442 | | |
437 | 443 | | |
438 | 444 | | |
| |||
499 | 505 | | |
500 | 506 | | |
501 | 507 | | |
| 508 | + | |
| 509 | + | |
| 510 | + | |
| 511 | + | |
| 512 | + | |
| 513 | + | |
| 514 | + | |
| 515 | + | |
| 516 | + | |
| 517 | + | |
| 518 | + | |
| 519 | + | |
502 | 520 | | |
503 | 521 | | |
504 | 522 | | |
| |||
647 | 665 | | |
648 | 666 | | |
649 | 667 | | |
| 668 | + | |
650 | 669 | | |
651 | 670 | | |
652 | 671 | | |
| |||
702 | 721 | | |
703 | 722 | | |
704 | 723 | | |
705 | | - | |
| 724 | + | |
706 | 725 | | |
707 | 726 | | |
708 | 727 | | |
709 | 728 | | |
710 | | - | |
711 | | - | |
| 729 | + | |
| 730 | + | |
712 | 731 | | |
713 | 732 | | |
714 | 733 | | |
| |||
753 | 772 | | |
754 | 773 | | |
755 | 774 | | |
| 775 | + | |
756 | 776 | | |
757 | 777 | | |
758 | 778 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
123 | 123 | | |
124 | 124 | | |
125 | 125 | | |
| 126 | + | |
| 127 | + | |
| 128 | + | |
| 129 | + | |
| 130 | + | |
| 131 | + | |
| 132 | + | |
| 133 | + | |
| 134 | + | |
| 135 | + | |
| 136 | + | |
| 137 | + | |
| 138 | + | |
| 139 | + | |
| 140 | + | |
| 141 | + | |
| 142 | + | |
| 143 | + | |
| 144 | + | |
| 145 | + | |
| 146 | + | |
| 147 | + | |
| 148 | + | |
| 149 | + | |
| 150 | + | |
| 151 | + | |
| 152 | + | |
| 153 | + | |
| 154 | + | |
| 155 | + | |
| 156 | + | |
| 157 | + | |
| 158 | + | |
| 159 | + | |
126 | 160 | | |
127 | 161 | | |
128 | 162 | | |
| |||
287 | 321 | | |
288 | 322 | | |
289 | 323 | | |
| 324 | + | |
290 | 325 | | |
291 | 326 | | |
292 | 327 | | |
| |||
344 | 379 | | |
345 | 380 | | |
346 | 381 | | |
347 | | - | |
| 382 | + | |
| 383 | + | |
| 384 | + | |
| 385 | + | |
| 386 | + | |
| 387 | + | |
| 388 | + | |
348 | 389 | | |
349 | 390 | | |
350 | 391 | | |
| |||
0 commit comments