Commit 6c29e76

Merge pull request #25 from santi921/fix/timings-json-path-mismatch
Fix/timings json path mismatch
2 parents d7eae84 + 1e9db29 commit 6c29e76

10 files changed

Lines changed: 410 additions & 42 deletions

.gitignore

Lines changed: 16 additions & 0 deletions
@@ -38,5 +38,21 @@
 *wfx
 *xyz
 
+# working folders for dev
+examples/
+
 # W&B directories
 wandb/
+# ai documents
+docs/
+
+# Always track test fixtures regardless of extension rules above
+!tests/test_files/
+!tests/test_files/*/
+!tests/test_files/**/
+!tests/test_files/**/*.inp
+!tests/test_files/**/*.in
+!tests/test_files/**/*.json
+!tests/test_files/**/*.txt
+!tests/test_files/**/*.lmdb
+!tests/test_files/**/*.lmdb-lock

pyproject.toml

Lines changed: 1 addition & 0 deletions
@@ -48,6 +48,7 @@ full-runner-parsl-nersc = "qtaim_gen.source.scripts.full_runner_parsl_nersc:main"
 generator-single-runner = "qtaim_gen.source.scripts.generator_run:main"
 
 find-empty-compressed = "qtaim_gen.source.scripts.helpers.find_empty_compressed:main"
+count-orca-json = "qtaim_gen.source.scripts.helpers.count_orca_json:main"
 check-res-wfn = "qtaim_gen.source.scripts.helpers.check_res_wfn:main"
 check-res-rxn-json = "qtaim_gen.source.scripts.helpers.check_res_rxn_json:main"
 folder-xyz-molecules-to-pkl = "qtaim_gen.source.scripts.helpers.folder_xyz_molecules_to_pkl:main"
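
The new entry maps the count-orca-json console command to a main() function in count_orca_json.py. That module's body is not shown in this commit view; a minimal sketch of the shape a [project.scripts] target needs (the argument handling and counting logic below are hypothetical, only the module path and main() name come from the diff):

# Hypothetical body for qtaim_gen/source/scripts/helpers/count_orca_json.py;
# only the module path and the main() entry point come from the pyproject diff.
import argparse
import os

def main():
    parser = argparse.ArgumentParser(description="Count orca.json files under a root folder")
    parser.add_argument("root", help="directory tree to scan")
    args = parser.parse_args()
    # one hit per job folder that contains an orca.json
    count = sum(1 for _, _, files in os.walk(args.root) if "orca.json" in files)
    print(f"{count} folders contain orca.json")

if __name__ == "__main__":
    main()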

qtaim_gen/source/core/converter.py

Lines changed: 57 additions & 30 deletions
@@ -131,6 +131,7 @@ def __init__(self, config_dict: Dict[str, Any], config_path: str = None):
         self.config_dict = config_dict
         self.config_path = config_path
         self.restart = config_dict["restart"]
+        self._processed_source_keys: set = set()
 
 
         # Setup logging
@@ -169,7 +170,7 @@ def __init__(self, config_dict: Dict[str, Any], config_path: str = None):
         else:
             self.save_scaler = False
 
-        self.skip_keys = config_dict.get("filter_list", ["length", "scaled"])
+        self.skip_keys = list(config_dict.get("filter_list", ["length", "scaled"])) + ["processed_source_keys"]
 
         # Parallelization settings
         self.n_workers = config_dict.get("n_workers", 8)
@@ -249,14 +250,18 @@
         self.logger.info(f"Connected to output LMDB: {self.file}")
 
         if self.restart and os.path.exists(self.file):
-            # get all existing keys from the existing LMDB file and store in self.existing_keys to reference against
             with self.db.begin(write=False) as txn:
-                self.existing_keys = set()
-
-                cursor = txn.cursor()
-                for key, _ in cursor:
-                    if key.decode("ascii") not in self.skip_keys:
-                        self.existing_keys.add(key.decode("ascii"))
+                # prefer source-key metadata written by new-format converters
+                psk_raw = txn.get(b"processed_source_keys")
+                if psk_raw is not None:
+                    self.existing_keys = pickle.loads(psk_raw)
+                else:
+                    # backward compat: old-format LMDBs stored molecule IDs as keys
+                    self.existing_keys = set()
+                    cursor = txn.cursor()
+                    for key, _ in cursor:
+                        if key.decode("ascii") not in self.skip_keys:
+                            self.existing_keys.add(key.decode("ascii"))
 
 
         # handle scaled info
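
Why the metadata entry is needed: graph entries are now written under sequential integer keys (see the process() hunks below), so scanning LMDB keys no longer reveals which source molecules were already converted. A minimal round-trip sketch, assuming only lmdb and pickle as used above; the path and molecule IDs are illustrative:

# Minimal round-trip sketch of the processed_source_keys metadata entry.
# "graphs.lmdb" and the molecule IDs are illustrative.
import lmdb
import pickle

env = lmdb.open("graphs.lmdb", subdir=False, map_size=2**30)

# writer side: record which source molecule IDs were converted
with env.begin(write=True) as txn:
    txn.put(b"processed_source_keys", pickle.dumps({"mol-001", "mol-002"}, protocol=-1))

# restart side: prefer the metadata entry; fall back to a key scan for old LMDBs
with env.begin(write=False) as txn:
    psk_raw = txn.get(b"processed_source_keys")
    existing_keys = pickle.loads(psk_raw) if psk_raw is not None else set()

print(existing_keys)  # {'mol-001', 'mol-002'} (a set, order arbitrary)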
@@ -485,7 +490,11 @@ def scale_graphs_single(
                 if key_str not in self.config_dict["filter_list"]:
                     # process graph
                     try:
-                        graph = load_graph_from_serialized(pickle.loads(value))
+                        raw = pickle.loads(value)
+                        if isinstance(raw, dict):
+                            graph = load_graph_from_serialized(raw["molecule_graph"])
+                        else:
+                            graph = load_graph_from_serialized(raw)
                     except Exception as e:
                         self.logger.exception(f"Failed to load graph for key {key_str}: {e}")
                         continue
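
This dict-or-raw branch recurs below in merge_shards. A hypothetical helper (not part of this diff, which inlines the logic at each call site) capturing the backward-compatible read:

# Hypothetical helper capturing the repeated dict-or-raw branch; the actual
# diff inlines this logic at each call site.
import pickle

def load_graph_value(value, load_graph_from_serialized):
    raw = pickle.loads(value)
    if isinstance(raw, dict):
        # new format: {"molecule_graph": <serialized graph>}
        return load_graph_from_serialized(raw["molecule_graph"])
    # old format: the serialized graph was pickled directly
    return load_graph_from_serialized(raw)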
@@ -497,7 +506,7 @@
                     txn.put(
                         f"{key_str}".encode("ascii"),
                         pickle.dumps(
-                            serialize_graph(graph[0], ret=True), protocol=-1
+                            {"molecule_graph": serialize_graph(graph[0], ret=True)}, protocol=-1
                         ),
                     )
                 txn.commit()
@@ -690,9 +699,11 @@ def finalize(self, return_info=False, keys_to_iterate=None, processed_count=0):
             f"{lmdb_path}/label_scaler_iterative{shard_suffix}.pt"
         )
 
-        # last info on whether the graphs were scaled or not
+        # write metadata required by qtaim_embed's LMDBBaseDataset
         txn = self.db.begin(write=True)
+        txn.put("length".encode("ascii"), pickle.dumps(processed_count, protocol=-1))
         txn.put("scaled".encode("ascii"), pickle.dumps(False, protocol=-1))
+        txn.put("processed_source_keys".encode("ascii"), pickle.dumps(self._processed_source_keys, protocol=-1))
         txn.commit()
         self.db.close()
 
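finalize() now writes the three metadata entries a downstream consumer such as qtaim_embed's LMDBBaseDataset expects. A minimal sketch of reading them back; the key names come from the diff, the path is illustrative:

# Minimal sketch: reading back the metadata finalize() now writes.
# Key names come from the diff; "graphs.lmdb" is an illustrative path.
import lmdb
import pickle

env = lmdb.open("graphs.lmdb", subdir=False, readonly=True, lock=False)
with env.begin() as txn:
    length = pickle.loads(txn.get(b"length"))        # number of graph entries
    scaled = pickle.loads(txn.get(b"scaled"))        # False until scalers are applied
    source_keys = pickle.loads(txn.get(b"processed_source_keys"))  # set of molecule IDs
print(length, scaled, len(source_keys))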
@@ -852,18 +863,25 @@ def merge_shards(
         map_async=True
     )
 
-    # Copy all entries from shards
+    # Copy all entries from shards, re-numbering graph keys to avoid collisions
+    _merge_skip = {b"length", b"scaled", b"scaler_finalized", b"processed_source_keys"}
     total_copied = 0
+    global_idx = 0
     with merged_env.begin(write=True) as dst_txn:
         for i, lmdb_path in enumerate(shard_lmdbs):
             logger.info(f"Copying shard {i+1}/{len(shard_lmdbs)}")
             src_env = lmdb.open(lmdb_path, subdir=False, readonly=True, lock=False)
             with src_env.begin() as src_txn:
                 cursor = src_txn.cursor()
                 for key, value in cursor:
-                    dst_txn.put(key, value)
+                    if key in _merge_skip:
+                        continue
+                    dst_txn.put(f"{global_idx}".encode("ascii"), value)
+                    global_idx += 1
                     total_copied += 1
             src_env.close()
+        dst_txn.put(b"length", pickle.dumps(global_idx, protocol=-1))
+        dst_txn.put(b"scaled", pickle.dumps(False, protocol=-1))
 
     merged_env.close()
     logger.info(f"Merged {total_copied} entries")
@@ -919,7 +937,7 @@ def merge_shards(
     logger.info("Applying merged scalers to LMDB...")
     env = lmdb.open(output_path, subdir=False, map_size=map_size)
     count = 0
-    metadata_keys = {b'scaled', b'scaler_finalized', b'length'}
+    metadata_keys = {b'scaled', b'scaler_finalized', b'length', b'processed_source_keys'}
     with env.begin(write=True) as txn:
         cursor = txn.cursor()
         for key, value in cursor:
@@ -928,17 +946,20 @@
                 continue
 
             try:
-                # Deserialize: pickle.loads returns bytes, then deserialize to PyG HeteroData
-                serialized_bytes = pickle.loads(value)
-                graph = load_graph_from_serialized(serialized_bytes)
+                # Deserialize: pickle.loads may return dict or raw bytes depending on format
+                raw = pickle.loads(value)
+                if isinstance(raw, dict):
+                    graph = load_graph_from_serialized(raw["molecule_graph"])
+                else:
+                    graph = load_graph_from_serialized(raw)
 
                 # Apply scalers - feature scaler expects a list
                 graph = merged_feature_scaler([graph])
                 graph = merged_label_scaler(graph)
 
                 # Serialize and write back
                 serialized_bytes = serialize_graph(graph[0], ret=True)
-                txn.put(key, pickle.dumps(serialized_bytes, protocol=-1))
+                txn.put(key, pickle.dumps({"molecule_graph": serialized_bytes}, protocol=-1))
                 count += 1
             except Exception as e:
                 logger.warning(f"Failed to scale graph {key}: {e}")
@@ -1123,9 +1144,10 @@ def process(
                     self.feature_scaler_iterative.update([first_graph])
                     self.label_scaler_iterative.update([first_graph])
                     write_buffer.append((
-                        f"{key_str}".encode("ascii"),
-                        pickle.dumps(serialize_graph(first_graph, ret=True), protocol=-1),
+                        f"{processed_count}".encode("ascii"),
+                        pickle.dumps({"molecule_graph": serialize_graph(first_graph, ret=True)}, protocol=-1),
                     ))
+                    self._processed_source_keys.add(key_str)
                     processed_count += 1
                     first_key_idx = idx + 1
                     break
@@ -1162,9 +1184,10 @@ def process_key(key):
                 self.label_scaler_iterative.update([graph])
 
                 write_buffer.append((
-                    f"{key_str}".encode("ascii"),
-                    pickle.dumps(serialize_graph(graph, ret=True), protocol=-1),
+                    f"{processed_count}".encode("ascii"),
+                    pickle.dumps({"molecule_graph": serialize_graph(graph, ret=True)}, protocol=-1),
                 ))
+                self._processed_source_keys.add(key_str)
                 processed_count += 1
 
                 if len(write_buffer) >= self.batch_size:
@@ -1379,9 +1402,10 @@ def process(
                     self.feature_scaler_iterative.update([first_graph])
                     self.label_scaler_iterative.update([first_graph])
                     write_buffer.append((
-                        f"{key_str}".encode("ascii"),
-                        pickle.dumps(serialize_graph(first_graph, ret=True), protocol=-1),
+                        f"{processed_count}".encode("ascii"),
+                        pickle.dumps({"molecule_graph": serialize_graph(first_graph, ret=True)}, protocol=-1),
                     ))
+                    self._processed_source_keys.add(key_str)
                     processed_count += 1
                     first_key_idx = idx + 1
                     break
@@ -1418,9 +1442,10 @@ def process_key(key):
                 self.label_scaler_iterative.update([graph])
 
                 write_buffer.append((
-                    f"{key_str}".encode("ascii"),
-                    pickle.dumps(serialize_graph(graph, ret=True), protocol=-1),
+                    f"{processed_count}".encode("ascii"),
+                    pickle.dumps({"molecule_graph": serialize_graph(graph, ret=True)}, protocol=-1),
                 ))
+                self._processed_source_keys.add(key_str)
                 processed_count += 1
 
                 if len(write_buffer) >= self.batch_size:
@@ -2040,9 +2065,10 @@ def process(
                     self.feature_scaler_iterative.update([first_graph])
                     self.label_scaler_iterative.update([first_graph])
                     write_buffer.append((
-                        f"{key_str}".encode("ascii"),
-                        pickle.dumps(serialize_graph(first_graph, ret=True), protocol=-1),
+                        f"{processed_count}".encode("ascii"),
+                        pickle.dumps({"molecule_graph": serialize_graph(first_graph, ret=True)}, protocol=-1),
                     ))
+                    self._processed_source_keys.add(key_str)
                     processed_count += 1
                     first_key_idx = idx + 1
                     self.logger.info(f"Successfully initialized grapher with key {key_str}")
@@ -2087,9 +2113,10 @@ def process_key(key):
                 self.label_scaler_iterative.update([graph])
 
                 write_buffer.append((
-                    f"{key_str}".encode("ascii"),
-                    pickle.dumps(serialize_graph(graph, ret=True), protocol=-1),
+                    f"{processed_count}".encode("ascii"),
+                    pickle.dumps({"molecule_graph": serialize_graph(graph, ret=True)}, protocol=-1),
                ))
+                self._processed_source_keys.add(key_str)
                 processed_count += 1
 
                 # Batch commit
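
All four process() variants share the same buffered-write pattern: accumulate (key, value) pairs and commit once write_buffer reaches batch_size. A minimal self-contained sketch; the path, batch size, and stand-in graph payloads are illustrative:

# Minimal self-contained sketch of the buffered-write pattern; the path,
# batch size, and stand-in graph payloads are illustrative.
import lmdb
import pickle

db = lmdb.open("buffered.lmdb", subdir=False, map_size=2**30)
write_buffer, batch_size = [], 2

def flush():
    with db.begin(write=True) as txn:
        for key, value in write_buffer:
            txn.put(key, value)
    write_buffer.clear()

for processed_count, graph_bytes in enumerate([b"g0", b"g1", b"g2"]):  # stand-in graphs
    write_buffer.append((
        f"{processed_count}".encode("ascii"),  # sequential integer keys, as above
        pickle.dumps({"molecule_graph": graph_bytes}, protocol=-1),
    ))
    if len(write_buffer) >= batch_size:  # batch commit
        flush()
flush()  # final partial batch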

qtaim_gen/source/scripts/helpers/generator_to_embed.py

Lines changed: 4 additions & 2 deletions
@@ -221,14 +221,16 @@ def main():
         )
 
         # Scale using train-only fitting
-        if not args.skip_scaling:
+        skip_scaling = args.skip_scaling or config_dict.get("skip_scaling", False)
+        if not skip_scaling:
             scale_split_lmdbs(converter, split_paths)
         else:
             print("Skipping scaling step")
 
     else:
         # Original behavior: scale the single output LMDB
-        if not args.skip_scaling:
+        skip_scaling = args.skip_scaling or config_dict.get("skip_scaling", False)
+        if not skip_scaling:
             converter.scale_graph_lmdb()
         else:
             print("Skipping scaling step")

qtaim_gen/source/scripts/helpers/tracking_db.py

Lines changed: 23 additions & 3 deletions
@@ -117,8 +117,10 @@ def find_leaf_folders(path):
         "t_becke_fuzzy_density",
         "t_becke_fuzzy_spin",
         "t_bond",
-        "t_other_alie",
-        "t_other_geometry"
+        "t_other_alie",
+        "t_other_geometry",
+        "has_orca_json",
+        "val_orca",
     ]
     columns = list(set(columns))  # ensure uniqueness
 
@@ -417,7 +419,25 @@ def get_tabs(subset):
         f" {subset}: {get_tabs(subset)} {count_day} / {count_hr} / {dict_one_day_full_val.get(subset, 0)} / {dict_one_hour_full_val.get(subset, 0)}"
     )
 
-    # 7. print all counts from overall counts db
+    # 7. orca.json presence and validation per category
+    print("---" * 30)
+    print("orca.json presence and validation per category (subset):")
+    c.execute(
+        "SELECT subset, COUNT(DISTINCT job_id) FROM validation WHERE has_orca_json='True' GROUP BY subset"
+    )
+    orca_present = dict(c.fetchall())
+    c.execute(
+        "SELECT subset, COUNT(DISTINCT job_id) FROM validation WHERE val_orca='True' GROUP BY subset"
+    )
+    orca_valid = dict(c.fetchall())
+    for subset in sorted(set(list(orca_present.keys()) + list(orca_valid.keys()))):
+        present = orca_present.get(subset, 0)
+        valid = orca_valid.get(subset, 0)
+        total = counts_overall.get(subset, "N/A") if path_to_overall_counts_db else ""
+        total_str = f" / {total}" if total != "" else ""
+        print(f" {subset}: {get_tabs(subset)} has={present}{total_str} valid={valid}")
+
+    # 8. print all counts from overall counts db
     if path_to_overall_counts_db:
         print("---" * 30)
         print("Overall job counts per category (subset):")

qtaim_gen/source/utils/validation.py

Lines changed: 10 additions & 3 deletions
@@ -58,6 +58,8 @@ def get_val_breakdown_from_folder(
         "val_bond": None,
         "val_fuzzy": None,
         "val_other": None,
+        "has_orca_json": False,
+        "val_orca": None,
     }
 
     # check timings
@@ -117,9 +119,12 @@ def get_val_breakdown_from_folder(
 
     # check orca (optional)
     orca_file = os.path.join(folder, "orca.json")
-    if os.path.exists(orca_file) and os.path.getsize(orca_file) > 0:
-        tf_orca = validate_orca_dict(orca_file, n_atoms=n_atoms, logger=None)
-        info["val_orca"] = tf_orca
+    if os.path.exists(orca_file):
+        info["has_orca_json"] = True
+        if os.path.getsize(orca_file) > 0:
+            info["val_orca"] = validate_orca_dict(orca_file, n_atoms=n_atoms, logger=None)
+        else:
+            info["val_orca"] = False
 
     return info
 
@@ -743,6 +748,8 @@ def get_information_from_job_folder(folder: str, full_set: int) -> dict:
         "val_bond": None,
         "val_fuzzy": None,
         "val_other": None,
+        "has_orca_json": False,
+        "val_orca": None,
         "n_atoms": None,
         "spin": None,
         "charge": None,
