@@ -131,6 +131,7 @@ def __init__(self, config_dict: Dict[str, Any], config_path: str = None):
         self.config_dict = config_dict
         self.config_path = config_path
         self.restart = config_dict["restart"]
+        self._processed_source_keys: set = set()


         # Setup logging
@@ -169,7 +170,7 @@ def __init__(self, config_dict: Dict[str, Any], config_path: str = None):
         else:
             self.save_scaler = False

-        self.skip_keys = config_dict.get("filter_list", ["length", "scaled"])
+        self.skip_keys = list(config_dict.get("filter_list", ["length", "scaled"])) + ["processed_source_keys"]

         # Parallelization settings
         self.n_workers = config_dict.get("n_workers", 8)
@@ -249,14 +250,18 @@ def __init__(self, config_dict: Dict[str, Any], config_path: str = None):
         self.logger.info(f"Connected to output LMDB: {self.file}")

         if self.restart and os.path.exists(self.file):
-            # get all existing keys from the existing LMDB file and store in self.existing_keys to reference against
             with self.db.begin(write=False) as txn:
-                self.existing_keys = set()
-
-                cursor = txn.cursor()
-                for key, _ in cursor:
-                    if key.decode("ascii") not in self.skip_keys:
-                        self.existing_keys.add(key.decode("ascii"))
+                # prefer source-key metadata written by new-format converters
+                psk_raw = txn.get(b"processed_source_keys")
+                if psk_raw is not None:
+                    self.existing_keys = pickle.loads(psk_raw)
+                else:
+                    # backward compat: old-format LMDBs stored molecule IDs as keys
+                    self.existing_keys = set()
+                    cursor = txn.cursor()
+                    for key, _ in cursor:
+                        if key.decode("ascii") not in self.skip_keys:
+                            self.existing_keys.add(key.decode("ascii"))


         # handle scaled info
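Note for reviewers: a minimal standalone sketch of the restart-key resolution above, assuming the file layout this converter writes (subdir=False files, a pickled set under b"processed_source_keys"; the function name is hypothetical):

import pickle
import lmdb

def load_existing_keys(lmdb_file, skip_keys):
    # Prefer the new single metadata entry; fall back to scanning old-format keys.
    env = lmdb.open(lmdb_file, subdir=False, readonly=True, lock=False)
    try:
        with env.begin(write=False) as txn:
            psk_raw = txn.get(b"processed_source_keys")
            if psk_raw is not None:
                return pickle.loads(psk_raw)  # new format: one pickled set
            # old format: every non-metadata key is a molecule ID
            return {
                key.decode("ascii")
                for key, _ in txn.cursor()
                if key.decode("ascii") not in skip_keys
            }
    finally:
        env.close()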
@@ -485,7 +490,11 @@ def scale_graphs_single(
             if key_str not in self.config_dict["filter_list"]:
                 # process graph
                 try:
-                    graph = load_graph_from_serialized(pickle.loads(value))
+                    raw = pickle.loads(value)
+                    if isinstance(raw, dict):
+                        graph = load_graph_from_serialized(raw["molecule_graph"])
+                    else:
+                        graph = load_graph_from_serialized(raw)
                 except Exception as e:
                     self.logger.exception(f"Failed to load graph for key {key_str}: {e}")
                     continue
@@ -497,7 +506,7 @@ def scale_graphs_single(
                 txn.put(
                     f"{key_str}".encode("ascii"),
                     pickle.dumps(
-                        serialize_graph(graph[0], ret=True), protocol=-1
+                        {"molecule_graph": serialize_graph(graph[0], ret=True)}, protocol=-1
                     ),
                 )
                 txn.commit()
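These two hunks fix the read and write sides of the same format change. A sketch of the two on-disk record formats as a tolerant loader/writer pair, assuming serialize_graph and load_graph_from_serialized are this repo's helpers and that serialize_graph(..., ret=True) returns a picklable payload:

import pickle

def load_record(value):
    # Old format: pickle of the serialized graph payload directly.
    # New format: pickle of {"molecule_graph": <serialized graph payload>}.
    raw = pickle.loads(value)
    payload = raw["molecule_graph"] if isinstance(raw, dict) else raw
    return load_graph_from_serialized(payload)

def dump_record(graph):
    # Writes always use the new dict-wrapped format.
    return pickle.dumps(
        {"molecule_graph": serialize_graph(graph, ret=True)}, protocol=-1
    )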
@@ -690,9 +699,11 @@ def finalize(self, return_info=False, keys_to_iterate=None, processed_count=0):
                 f"{lmdb_path}/label_scaler_iterative{shard_suffix}.pt"
             )

-        # last info on whether the graphs were scaled or not
+        # write metadata required by qtaim_embed's LMDBBaseDataset
         txn = self.db.begin(write=True)
+        txn.put("length".encode("ascii"), pickle.dumps(processed_count, protocol=-1))
         txn.put("scaled".encode("ascii"), pickle.dumps(False, protocol=-1))
+        txn.put("processed_source_keys".encode("ascii"), pickle.dumps(self._processed_source_keys, protocol=-1))
         txn.commit()
         self.db.close()
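A reader-side sketch of the metadata contract finalize() now writes, using only lmdb and pickle; the path is hypothetical and the three entries are assumed present because this converter wrote them:

import pickle
import lmdb

env = lmdb.open("dataset.lmdb", subdir=False, readonly=True, lock=False)
try:
    with env.begin(write=False) as txn:
        length = pickle.loads(txn.get(b"length"))    # number of graphs
        scaled = pickle.loads(txn.get(b"scaled"))    # False until scalers applied
        source_keys = pickle.loads(txn.get(b"processed_source_keys"))
finally:
    env.close()
print(length, scaled, len(source_keys))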
@@ -852,18 +863,25 @@ def merge_shards(
         map_async=True
     )

-    # Copy all entries from shards
+    # Copy all entries from shards, re-numbering graph keys to avoid collisions
+    _merge_skip = {b"length", b"scaled", b"scaler_finalized", b"processed_source_keys"}
     total_copied = 0
+    global_idx = 0
     with merged_env.begin(write=True) as dst_txn:
         for i, lmdb_path in enumerate(shard_lmdbs):
             logger.info(f"Copying shard {i + 1}/{len(shard_lmdbs)}")
             src_env = lmdb.open(lmdb_path, subdir=False, readonly=True, lock=False)
             with src_env.begin() as src_txn:
                 cursor = src_txn.cursor()
                 for key, value in cursor:
-                    dst_txn.put(key, value)
+                    if key in _merge_skip:
+                        continue
+                    dst_txn.put(f"{global_idx}".encode("ascii"), value)
+                    global_idx += 1
                     total_copied += 1
             src_env.close()
+        dst_txn.put(b"length", pickle.dumps(global_idx, protocol=-1))
+        dst_txn.put(b"scaled", pickle.dumps(False, protocol=-1))

     merged_env.close()
     logger.info(f"Merged {total_copied} entries")
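Each shard numbers its graphs 0..n-1, so copying keys verbatim would let later shards overwrite earlier ones; the fresh global_idx avoids that and lets length be rewritten once for the merged file. A condensed sketch under the same assumptions (shard paths hypothetical):

import pickle
import lmdb

SKIP = {b"length", b"scaled", b"scaler_finalized", b"processed_source_keys"}

def merge(shard_paths, out_path, map_size=1 << 40):
    dst_env = lmdb.open(out_path, subdir=False, map_size=map_size)
    idx = 0
    with dst_env.begin(write=True) as dst_txn:
        for path in shard_paths:
            src_env = lmdb.open(path, subdir=False, readonly=True, lock=False)
            with src_env.begin() as src_txn:
                for key, value in src_txn.cursor():
                    if key in SKIP:
                        continue  # per-shard metadata must not be copied
                    dst_txn.put(str(idx).encode("ascii"), value)
                    idx += 1
            src_env.close()
        # rewrite metadata once for the merged file
        dst_txn.put(b"length", pickle.dumps(idx, protocol=-1))
        dst_txn.put(b"scaled", pickle.dumps(False, protocol=-1))
    dst_env.close()

merge(["shard_0.lmdb", "shard_1.lmdb"], "merged.lmdb")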
@@ -919,7 +937,7 @@ def merge_shards(
     logger.info("Applying merged scalers to LMDB...")
     env = lmdb.open(output_path, subdir=False, map_size=map_size)
     count = 0
-    metadata_keys = {b'scaled', b'scaler_finalized', b'length'}
+    metadata_keys = {b'scaled', b'scaler_finalized', b'length', b'processed_source_keys'}
     with env.begin(write=True) as txn:
         cursor = txn.cursor()
         for key, value in cursor:
@@ -928,17 +946,20 @@ def merge_shards(
                 continue

             try:
-                # Deserialize: pickle.loads returns bytes, then deserialize to PyG HeteroData
-                serialized_bytes = pickle.loads(value)
-                graph = load_graph_from_serialized(serialized_bytes)
+                # Deserialize: pickle.loads may return dict or raw bytes depending on format
+                raw = pickle.loads(value)
+                if isinstance(raw, dict):
+                    graph = load_graph_from_serialized(raw["molecule_graph"])
+                else:
+                    graph = load_graph_from_serialized(raw)

                 # Apply scalers - feature scaler expects a list
                 graph = merged_feature_scaler([graph])
                 graph = merged_label_scaler(graph)

                 # Serialize and write back
                 serialized_bytes = serialize_graph(graph[0], ret=True)
-                txn.put(key, pickle.dumps(serialized_bytes, protocol=-1))
+                txn.put(key, pickle.dumps({"molecule_graph": serialized_bytes}, protocol=-1))
                 count += 1
             except Exception as e:
                 logger.warning(f"Failed to scale graph {key}: {e}")
@@ -1123,9 +1144,10 @@ def process(
                     self.feature_scaler_iterative.update([first_graph])
                     self.label_scaler_iterative.update([first_graph])
                     write_buffer.append((
-                        f"{key_str}".encode("ascii"),
-                        pickle.dumps(serialize_graph(first_graph, ret=True), protocol=-1),
+                        f"{processed_count}".encode("ascii"),
+                        pickle.dumps({"molecule_graph": serialize_graph(first_graph, ret=True)}, protocol=-1),
                     ))
+                    self._processed_source_keys.add(key_str)
                     processed_count += 1
                     first_key_idx = idx + 1
                     break
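The same four-line change recurs in every process()/process_key() variant below: graphs are written under the sequential integer key processed_count (what the dataset indexes), while the source molecule ID goes into self._processed_source_keys (what restarts check). A sketch of the step factored into a helper (helper name hypothetical; serialize_graph is the repo's, as above):

import pickle

def buffer_graph(write_buffer, source_keys, graph, source_id, index):
    # Queue the graph under a sequential integer key; remember its source ID.
    write_buffer.append((
        str(index).encode("ascii"),
        pickle.dumps({"molecule_graph": serialize_graph(graph, ret=True)}, protocol=-1),
    ))
    source_keys.add(source_id)  # restart bookkeeping
    return index + 1  # becomes the caller's new processed_count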
@@ -1162,9 +1184,10 @@ def process_key(key):
                         self.label_scaler_iterative.update([graph])

                         write_buffer.append((
-                            f"{key_str}".encode("ascii"),
-                            pickle.dumps(serialize_graph(graph, ret=True), protocol=-1),
+                            f"{processed_count}".encode("ascii"),
+                            pickle.dumps({"molecule_graph": serialize_graph(graph, ret=True)}, protocol=-1),
                         ))
+                        self._processed_source_keys.add(key_str)
                         processed_count += 1

                         if len(write_buffer) >= self.batch_size:
@@ -1379,9 +1402,10 @@ def process(
                     self.feature_scaler_iterative.update([first_graph])
                     self.label_scaler_iterative.update([first_graph])
                     write_buffer.append((
-                        f"{key_str}".encode("ascii"),
-                        pickle.dumps(serialize_graph(first_graph, ret=True), protocol=-1),
+                        f"{processed_count}".encode("ascii"),
+                        pickle.dumps({"molecule_graph": serialize_graph(first_graph, ret=True)}, protocol=-1),
                     ))
+                    self._processed_source_keys.add(key_str)
                     processed_count += 1
                     first_key_idx = idx + 1
                     break
@@ -1418,9 +1442,10 @@ def process_key(key):
                         self.label_scaler_iterative.update([graph])

                         write_buffer.append((
-                            f"{key_str}".encode("ascii"),
-                            pickle.dumps(serialize_graph(graph, ret=True), protocol=-1),
+                            f"{processed_count}".encode("ascii"),
+                            pickle.dumps({"molecule_graph": serialize_graph(graph, ret=True)}, protocol=-1),
                         ))
+                        self._processed_source_keys.add(key_str)
                         processed_count += 1

                         if len(write_buffer) >= self.batch_size:
@@ -2040,9 +2065,10 @@ def process(
                     self.feature_scaler_iterative.update([first_graph])
                     self.label_scaler_iterative.update([first_graph])
                     write_buffer.append((
-                        f"{key_str}".encode("ascii"),
-                        pickle.dumps(serialize_graph(first_graph, ret=True), protocol=-1),
+                        f"{processed_count}".encode("ascii"),
+                        pickle.dumps({"molecule_graph": serialize_graph(first_graph, ret=True)}, protocol=-1),
                     ))
+                    self._processed_source_keys.add(key_str)
                     processed_count += 1
                     first_key_idx = idx + 1
                     self.logger.info(f"Successfully initialized grapher with key {key_str}")
@@ -2087,9 +2113,10 @@ def process_key(key):
                         self.label_scaler_iterative.update([graph])

                         write_buffer.append((
-                            f"{key_str}".encode("ascii"),
-                            pickle.dumps(serialize_graph(graph, ret=True), protocol=-1),
+                            f"{processed_count}".encode("ascii"),
+                            pickle.dumps({"molecule_graph": serialize_graph(graph, ret=True)}, protocol=-1),
                         ))
+                        self._processed_source_keys.add(key_str)
                         processed_count += 1

                         # Batch commit