Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 18 additions & 10 deletions causallearn/search/ConstraintBased/PC.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ def pc(
verbose: bool = False,
show_progress: bool = True,
node_names: List[str] | None = None,
max_k: int = None,
**kwargs
):
if data.shape[0] < data.shape[1]:
Expand All @@ -41,11 +42,11 @@ def pc(
return mvpc_alg(data=data, node_names=node_names, alpha=alpha, indep_test=indep_test, correction_name=correction_name, stable=stable,
uc_rule=uc_rule, uc_priority=uc_priority, background_knowledge=background_knowledge,
verbose=verbose,
show_progress=show_progress, **kwargs)
show_progress=show_progress, max_k=max_k, **kwargs)
else:
return pc_alg(data=data, node_names=node_names, alpha=alpha, indep_test=indep_test, stable=stable, uc_rule=uc_rule,
uc_priority=uc_priority, background_knowledge=background_knowledge, verbose=verbose,
show_progress=show_progress, **kwargs)
show_progress=show_progress, max_k=max_k, **kwargs)


def pc_alg(
Expand All @@ -59,6 +60,7 @@ def pc_alg(
background_knowledge: BackgroundKnowledge | None = None,
verbose: bool = False,
show_progress: bool = True,
max_k=None,
**kwargs
) -> CausalGraph:
"""
Expand Down Expand Up @@ -103,7 +105,7 @@ def pc_alg(
indep_test = CIT(data, indep_test, **kwargs)
cg_1 = SkeletonDiscovery.skeleton_discovery(data, alpha, indep_test, stable,
background_knowledge=background_knowledge, verbose=verbose,
show_progress=show_progress, node_names=node_names)
show_progress=show_progress, node_names=node_names, max_k=max_k)

if background_knowledge is not None:
orient_by_background_knowledge(cg_1, background_knowledge)
Expand Down Expand Up @@ -142,14 +144,15 @@ def mvpc_alg(
data: ndarray,
node_names: List[str] | None,
alpha: float,
indep_test: str,
indep_test: Any,
correction_name: str,
stable: bool,
uc_rule: int,
uc_priority: int,
background_knowledge: BackgroundKnowledge | None = None,
verbose: bool = False,
show_progress: bool = True,
max_k: int | None = None,
**kwargs,
) -> CausalGraph:
"""
Expand Down Expand Up @@ -197,14 +200,14 @@ def mvpc_alg(
start = time.time()
indep_test = CIT(data, indep_test, **kwargs)
## Step 1: detect the direct causes of missingness indicators
prt_m = get_parent_missingness_pairs(data, alpha, indep_test, stable)
prt_m = get_parent_missingness_pairs(data, alpha, indep_test, stable, max_k=max_k)
# print('Finish detecting the parents of missingness indicators. ')

## Step 2:
## a) Run PC algorithm with the 1st step skeleton;
cg_pre = SkeletonDiscovery.skeleton_discovery(data, alpha, indep_test, stable,
background_knowledge=background_knowledge,
verbose=verbose, show_progress=show_progress, node_names=node_names)
verbose=verbose, show_progress=show_progress, node_names=node_names, max_k=max_k)
if background_knowledge is not None:
orient_by_background_knowledge(cg_pre, background_knowledge)

Expand Down Expand Up @@ -251,7 +254,7 @@ def mvpc_alg(

#######################################################################################################################
## *********** Functions for Step 1 ***********
def get_parent_missingness_pairs(data: ndarray, alpha: float, indep_test, stable: bool = True) -> Dict[str, list]:
def get_parent_missingness_pairs(data: ndarray, alpha: float, indep_test, stable: bool = True, max_k: int | None = None) -> Dict[str, list]:
"""
Detect the parents of missingness indicators
If a missingness indicator has no parent, it will not be included in the result
Expand All @@ -272,7 +275,7 @@ def get_parent_missingness_pairs(data: ndarray, alpha: float, indep_test, stable
## Get the index of parents of missingness indicators
# If the missingness indicator has no parent, then it will not be collected in prt_m
for missingness_i in missingness_index:
parent_of_missingness_i = detect_parent(missingness_i, data, alpha, indep_test, stable)
parent_of_missingness_i = detect_parent(missingness_i, data, alpha, indep_test, stable, max_k=max_k)
if not isempty(parent_of_missingness_i):
parent_missingness_pairs['prt'].append(parent_of_missingness_i)
parent_missingness_pairs['m'].append(missingness_i)
Expand All @@ -299,7 +302,7 @@ def get_missingness_index(data: ndarray) -> List[int]:
return missingness_index


def detect_parent(r: int, data_: ndarray, alpha: float, indep_test, stable: bool = True) -> ndarray:
def detect_parent(r: int, data_: ndarray, alpha: float, indep_test: Any, stable: bool = True, max_k: int | None = None) -> ndarray:
"""Detect the parents of a missingness indicator
:param r: the missingness indicator
:param data_: data set (numpy ndarray)
Expand Down Expand Up @@ -334,15 +337,19 @@ def detect_parent(r: int, data_: ndarray, alpha: float, indep_test, stable: bool

no_of_var = data.shape[1]
cg = CausalGraph(no_of_var)
cg.set_ind_test(CIT(data, indep_test.method))
cg.set_ind_test(indep_test)


node_ids = range(no_of_var)
pair_of_variables = list(permutations(node_ids, 2))

depth = -1
while cg.max_degree() - 1 > depth:
depth += 1
if max_k is not None and depth > max_k:
break
edge_removal = []

for (x, y) in pair_of_variables:

## *********** Adaptation 2 ***********
Expand Down Expand Up @@ -495,3 +502,4 @@ def matrix_diff(cg1: CausalGraph, cg2: CausalGraph) -> (float, List[Tuple[int, i
diff_ls.append((i, j))
count += 1
return count / 2, diff_ls

4 changes: 3 additions & 1 deletion causallearn/utils/PCUtils/SkeletonDiscovery.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def skeleton_discovery(
background_knowledge: BackgroundKnowledge | None = None,
verbose: bool = False,
show_progress: bool = True,
node_names: List[str] | None = None,
node_names: List[str] | None = None, max_k=None,
) -> CausalGraph:
"""
Perform skeleton discovery
Expand Down Expand Up @@ -63,6 +63,8 @@ def skeleton_discovery(
pbar = tqdm(total=no_of_var) if show_progress else None
while cg.max_degree() - 1 > depth:
depth += 1
if max_k is not None and depth > max_k:
break
edge_removal = []
if show_progress:
pbar.reset()
Expand Down