@@ -342,6 +342,8 @@ def __init__(
342342 self .validate_foreign_scms ()
343343 self .initialize_dag ()
344344 self .initialize_nodes_and_edges ()
345+ self ._topological_generations = list (nx .topological_generations (self .dag .graph ))
346+ self ._collation_cache : dict [str , list [tuple ]] = {}
345347
346348 def initialize_dag (self ):
347349 dag_class = self .scm_params .scm_layout_choices .sample_uniform ()
@@ -453,8 +455,7 @@ def propagate(self, row_idx: int, foreign_row_idxs: list[int], foreign_scms: lis
453455 )
454456 foreign_scms_row_embds .append (foreign_row_embds )
455457
456- topological_gens = nx .topological_generations (self .dag .graph )
457- for gen in topological_gens :
458+ for gen in self ._topological_generations :
458459 for node in gen :
459460 node_stype = self .dag .graph .nodes [node ]["_stype" ]
460461 if node in self .source_nodes :
@@ -503,7 +504,7 @@ def generate_row(self, row_idx: int):
503504 foreign_scm = self .foreign_scm_info [foreign_table_name ]
504505 foreign_scms .append (foreign_scm )
505506 bi_g = self .bi_fk_pk_graph_map [foreign_table_name ]
506- parent_node_name = list ( bi_g .in_edges (f"b{ row_idx } " ))[ 0 ][ 0 ]
507+ parent_node_name = next ( iter ( bi_g .predecessors (f"b{ row_idx } " )))
507508 foreign_row_idx = bi_g .nodes [parent_node_name ]["node_idx" ]
508509 foreign_row_idxs .append (foreign_row_idx )
509510 row [fkey_col ] = foreign_row_idx
@@ -589,24 +590,23 @@ def generate_df(
589590 return self .df
590591
591592 def collate_feature_embeddings (self , row_idx : int , child_table_name : int ):
592- col_to_stype = {}
593- col_to_num_categories = {}
594- col_to_collation_encoder = {}
595- for node in sorted (self .col_nodes ):
596- col_name = self .dag .graph .nodes [node ]["col_name" ]
597- col_to_stype [col_name ] = self .dag .graph .nodes [node ]["_stype" ]
598- col_to_num_categories [col_name ] = self .dag .graph .nodes [node ]["num_categories" ]
599- col_to_collation_encoder [col_name ] = self .dag .graph .nodes [node ]["collation_encoders" ][
600- (self .table_name , child_table_name )
593+ if child_table_name not in self ._collation_cache :
594+ # (col_name, _stype, encoder) — stable across all rows for this child_table_name
595+ self ._collation_cache [child_table_name ] = [
596+ (
597+ self .dag .graph .nodes [node ]["col_name" ],
598+ self .dag .graph .nodes [node ]["_stype" ],
599+ self .dag .graph .nodes [node ]["collation_encoders" ][
600+ (self .table_name , child_table_name )
601+ ],
602+ )
603+ for node in sorted (self .col_nodes )
601604 ]
602- row = self .df . iloc [ row_idx ]. to_dict ()
605+ col_entries = self ._collation_cache [ child_table_name ]
603606 row_embds = []
604607 # for p->f embedding propagation
605- for col_name , value in row .items ():
606- if col_name not in col_to_stype :
607- continue
608- _stype = col_to_stype [col_name ]
608+ for col_name , _stype , encoder in col_entries :
609+ value = self .df .at [row_idx , col_name ]
609610 value_tensor = self .strategy .tensorize_col (value = value , _stype = _stype )
610- encoder = col_to_collation_encoder [col_name ]
611611 row_embds .append (encoder (value_tensor ).squeeze ())
612612 return row_embds
0 commit comments