Skip to content

Commit 1053cfd

Browse files
committed
fix(dpdata): use public format APIs
Avoid importing dpdata backend parser modules directly. Use dpdata's registered public format interface for ABACUS STRU, VASP POSCAR, and SIESTA output handling so DP-GEN works with both dpdata 0.2.x and 1.x without depending on moved backend module paths. Authored by OpenClaw (model: gpt-5.5)
1 parent f4f74a2 commit 1053cfd

3 files changed

Lines changed: 25 additions & 35 deletions

File tree

dpgen/auto_test/lib/abacus.py

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,6 @@
44

55
import dpdata
66
import numpy as np
7-
from dpdata.abacus.stru import make_unlabeled_stru
8-
from dpdata.utils import uniq_atom_names
9-
from dpdata.vasp import poscar as dpdata_poscar
107

118
import dpgen.generator.lib.abacus_scf as abacus_scf
129

@@ -142,11 +139,7 @@ def poscar2stru(poscar, inter_param, stru="STRU"):
142139
- deepks_desc: a string of deepks descriptor file
143140
- stru: output filename, usally is 'STRU'.
144141
"""
145-
# if use dpdata.System, the structure will be rotated to make cell to be lower triangular
146-
with open(poscar) as fp:
147-
lines = [line.rstrip("\n") for line in fp]
148-
stru_data = dpdata_poscar.to_system_data(lines)
149-
stru_data = uniq_atom_names(stru_data)
142+
stru_data = dpdata.System(poscar, fmt="vasp/poscar").data
150143

151144
atom_mass = []
152145
pseudo = None
@@ -185,16 +178,15 @@ def poscar2stru(poscar, inter_param, stru="STRU"):
185178
if "deepks_desc" in inter_param:
186179
deepks_desc = "./pp_orb/{}\n".format(inter_param["deepks_desc"])
187180

188-
stru_string = make_unlabeled_stru(
189-
data=stru_data,
181+
dpdata.System(data=stru_data).to(
182+
"abacus/stru",
183+
stru,
190184
frame_idx=0,
191185
pp_file=pseudo,
192186
numerical_orbital=orb,
193187
numerical_descriptor=deepks_desc,
194188
mass=atom_mass,
195189
)
196-
with open(stru, "w") as fp:
197-
fp.write(stru_string)
198190

199191

200192
def stru_fix_atom(struf, fix_atom=[True, True, True]):

dpgen/generator/lib/abacus_scf.py

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
import copy
22
import os
33
import re
4+
import tempfile
45

6+
import dpdata
57
import numpy as np
6-
from dpdata.abacus.stru import get_frame_from_stru, make_unlabeled_stru
78

89
from dpgen.auto_test.lib import vasp
910

@@ -259,13 +260,19 @@ def make_abacus_scf_stru(
259260
if len(cells.shape) == 2:
260261
sys_data_copy["cells"] = np.array([cells])
261262
sys_data_copy["coords"] = np.array([coords])
262-
c = make_unlabeled_stru(
263-
sys_data_copy,
264-
0,
265-
pp_file=fp_pp_files,
266-
numerical_orbital=fp_orb_files,
267-
numerical_descriptor=fp_dpks_descriptor,
268-
)
263+
with tempfile.NamedTemporaryFile(
264+
mode="r+", prefix="dpgen-abacus-", suffix=".stru"
265+
) as fp:
266+
dpdata.System(data=sys_data_copy).to(
267+
"abacus/stru",
268+
fp.name,
269+
frame_idx=0,
270+
pp_file=fp_pp_files,
271+
numerical_orbital=fp_orb_files,
272+
numerical_descriptor=fp_dpks_descriptor,
273+
)
274+
fp.seek(0)
275+
c = fp.read()
269276

270277
return c
271278

@@ -302,7 +309,7 @@ def get_abacus_STRU(STRU):
302309
"dpks_descriptor": str,
303310
}
304311
"""
305-
data = get_frame_from_stru(STRU)
312+
data = dpdata.System(STRU, fmt="abacus/stru").data
306313
data["atom_masses"] = data.pop("masses")
307314
data["cells"] = data.pop("cells")[0]
308315
data["coords"] = data.pop("coords")[0]

dpgen/generator/run.py

Lines changed: 5 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -4585,9 +4585,10 @@ def post_fp_abacus_scf(iter_index, jdata):
45854585

45864586
all_sys = None
45874587
for ii, oo in zip(sys_input, sys_output):
4588-
_sys = dpdata.LabeledSystem(
4589-
oo, fmt="abacus/scf", type_map=jdata["type_map"]
4590-
)
4588+
_sys = dpdata.LabeledSystem(oo, fmt="abacus/scf")
4589+
if len(_sys) > 0:
4590+
_sys.data["atom_types"] = np.asarray(_sys.data["atom_types"], dtype=int)
4591+
_sys.apply_type_map(jdata["type_map"])
45914592
if len(_sys) > 0:
45924593
if all_sys is None:
45934594
all_sys = _sys
@@ -4626,17 +4627,7 @@ def post_fp_siesta(iter_index, jdata):
46264627
sys_output.sort()
46274628
sys_input.sort()
46284629
for idx, oo in enumerate(sys_output):
4629-
_sys = dpdata.LabeledSystem()
4630-
(
4631-
_sys.data["atom_names"],
4632-
_sys.data["atom_numbs"],
4633-
_sys.data["atom_types"],
4634-
_sys.data["cells"],
4635-
_sys.data["coords"],
4636-
_sys.data["energies"],
4637-
_sys.data["forces"],
4638-
_sys.data["virials"],
4639-
) = dpdata.siesta.output.obtain_frame(oo)
4630+
_sys = dpdata.LabeledSystem(oo, fmt="siesta/output")
46404631
if idx == 0:
46414632
all_sys = _sys
46424633
else:

0 commit comments

Comments
 (0)