Skip to content

Commit 9d9fd12

Browse files
committed
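Limit torch threads during parallel generation; also vectorize bipartite edge probabilities and cache SCM topological generations and collation metadata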
1 parent 9e05cfb commit 9d9fd12

3 files changed

Lines changed: 24 additions & 30 deletions

File tree

plurel/bipartite.py

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -81,18 +81,10 @@ def get_bipartite_hsbm(size_a: int, size_b: int, hierarchy_a: list, hierarchy_b:
8181
)
8282

8383
for b_idx, b_node in tqdm(enumerate(nodes_b), desc="adding edges in bi_hsbm", leave=False):
84-
probs = np.array(
85-
[
86-
get_nodes_connect_prob(
87-
node_idx_a=a_idx,
88-
node_idx_b=b_idx,
89-
probs_at_levels=probs_at_levels,
90-
cluster_at_levels_a=cluster_at_levels_a,
91-
cluster_at_levels_b=cluster_at_levels_b,
92-
)
93-
for a_idx in range(size_a)
94-
]
95-
)
84+
probs = np.ones(size_a)
85+
for l_idx in range(len(probs_at_levels)):
86+
cluster_b = cluster_at_levels_b[b_idx, l_idx]
87+
probs *= probs_at_levels[l_idx][cluster_at_levels_a[:, l_idx], cluster_b]
9688
try:
9789
probs = probs / probs.sum()
9890
except ValueError:

plurel/scm.py

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -342,6 +342,8 @@ def __init__(
342342
self.validate_foreign_scms()
343343
self.initialize_dag()
344344
self.initialize_nodes_and_edges()
345+
self._topological_generations = list(nx.topological_generations(self.dag.graph))
346+
self._collation_cache: dict[str, list[tuple]] = {}
345347

346348
def initialize_dag(self):
347349
dag_class = self.scm_params.scm_layout_choices.sample_uniform()
@@ -453,8 +455,7 @@ def propagate(self, row_idx: int, foreign_row_idxs: list[int], foreign_scms: lis
453455
)
454456
foreign_scms_row_embds.append(foreign_row_embds)
455457

456-
topological_gens = nx.topological_generations(self.dag.graph)
457-
for gen in topological_gens:
458+
for gen in self._topological_generations:
458459
for node in gen:
459460
node_stype = self.dag.graph.nodes[node]["_stype"]
460461
if node in self.source_nodes:
@@ -503,7 +504,7 @@ def generate_row(self, row_idx: int):
503504
foreign_scm = self.foreign_scm_info[foreign_table_name]
504505
foreign_scms.append(foreign_scm)
505506
bi_g = self.bi_fk_pk_graph_map[foreign_table_name]
506-
parent_node_name = list(bi_g.in_edges(f"b{row_idx}"))[0][0]
507+
parent_node_name = next(iter(bi_g.predecessors(f"b{row_idx}")))
507508
foreign_row_idx = bi_g.nodes[parent_node_name]["node_idx"]
508509
foreign_row_idxs.append(foreign_row_idx)
509510
row[fkey_col] = foreign_row_idx
@@ -589,24 +590,23 @@ def generate_df(
589590
return self.df
590591

591592
def collate_feature_embeddings(self, row_idx: int, child_table_name: int):
592-
col_to_stype = {}
593-
col_to_num_categories = {}
594-
col_to_collation_encoder = {}
595-
for node in sorted(self.col_nodes):
596-
col_name = self.dag.graph.nodes[node]["col_name"]
597-
col_to_stype[col_name] = self.dag.graph.nodes[node]["_stype"]
598-
col_to_num_categories[col_name] = self.dag.graph.nodes[node]["num_categories"]
599-
col_to_collation_encoder[col_name] = self.dag.graph.nodes[node]["collation_encoders"][
600-
(self.table_name, child_table_name)
593+
if child_table_name not in self._collation_cache:
594+
# (col_name, _stype, encoder) — stable across all rows for this child_table_name
595+
self._collation_cache[child_table_name] = [
596+
(
597+
self.dag.graph.nodes[node]["col_name"],
598+
self.dag.graph.nodes[node]["_stype"],
599+
self.dag.graph.nodes[node]["collation_encoders"][
600+
(self.table_name, child_table_name)
601+
],
602+
)
603+
for node in sorted(self.col_nodes)
601604
]
602-
row = self.df.iloc[row_idx].to_dict()
605+
col_entries = self._collation_cache[child_table_name]
603606
row_embds = []
604607
# for p->f embedding propagation
605-
for col_name, value in row.items():
606-
if col_name not in col_to_stype:
607-
continue
608-
_stype = col_to_stype[col_name]
608+
for col_name, _stype, encoder in col_entries:
609+
value = self.df.at[row_idx, col_name]
609610
value_tensor = self.strategy.tensorize_col(value=value, _stype=_stype)
610-
encoder = col_to_collation_encoder[col_name]
611611
row_embds.append(encoder(value_tensor).squeeze())
612612
return row_embds

scripts/synthetic_gen.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from multiprocessing import Pool, cpu_count
55
from pathlib import Path
66

7+
import torch
78
from tqdm import tqdm
89

910
from plurel.config import Config
@@ -12,6 +13,7 @@
1213

1314

1415
def generate_rel_synthetic_db(seed: int, preprocess: bool = False):
16+
torch.set_num_threads(1)
1517
set_random_seed(0)
1618
db_name = f"rel-synthetic-{seed}"
1719
print(f"Creating dataset: {db_name}")

0 commit comments

Comments
 (0)