Skip to content

Commit b6071b4

Browse files
committed
add parallel clustering option to base
1 parent 9d53f05 commit b6071b4

File tree

2 files changed

+10
-3
lines changed

2 files changed

+10
-3
lines changed

domhmm/analysis/base.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,8 @@ class LeafletAnalysisBase(AnalysisBase):
7676
User-specific HMM (e.g., pre-trained on another simulation)
7777
do_clustering: bool
7878
Perform the hierarchical clustering for each frame
79+
parallel_clustering: bool
80+
Perform the hierarchical clustering in parallel
7981
n_init_hmm: int
8082
Number of repeats for HMM model trainings
8183
@@ -128,6 +130,7 @@ def __init__(
128130
n_init_hmm: int = 2,
129131
save_plots: bool = False,
130132
do_clustering: bool = True,
133+
parallel_clustering: bool = False,
131134
**kwargs
132135
):
133136
# the below line must be kept to initialize the AnalysisBase class!
@@ -158,6 +161,7 @@ def __init__(
158161
self.n_init_hmm = n_init_hmm
159162
self.save_plots = save_plots
160163
self.do_clustering = do_clustering
164+
self.parallel_clustering = parallel_clustering
161165

162166
assert heads.keys() == tails.keys(), "Heads and tails don't contain same residue names"
163167

domhmm/analysis/domhmm.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -404,7 +404,10 @@ def _conclude(self):
404404
pass
405405
else:
406406
log.info("Clustering is starting.")
407-
self.result_clustering()
407+
if self.parallel_clustering:
408+
self.result_clustering_parallel()
409+
else:
410+
self.result_clustering_serial()
408411

409412
if self.result_plots:
410413
self.clustering_plot()
@@ -1450,7 +1453,7 @@ def _process_frame_leaflet(args):
14501453

14511454
return (j, frame_number, cluster_result)
14521455

1453-
def result_clustering(self):
1456+
def result_clustering_parallel(self):
14541457
"""
14551458
Runs hierarchical clustering for each frame and saves result (parallelized).
14561459
"""
@@ -1472,7 +1475,7 @@ def result_clustering(self):
14721475
}
14731476
tasks.append((i, j, frame_data))
14741477

1475-
print(f"Total CPU count is {mp.cpu_count()}")
1478+
log.info(f"{mp.cpu_count()} CPU cores will be used for hierarchical clustering")
14761479
with mp.Pool(processes=mp.cpu_count()) as pool:
14771480
results = list(tqdm(pool.imap(self._process_frame_leaflet, tasks), total=len(tasks)))
14781481

0 commit comments

Comments
 (0)