Skip to content

Commit c719d5b

Browse files
committed
update symmetric membrane implementation for flat sterol
1 parent 608d4fc commit c719d5b

File tree

1 file changed

+33
-21
lines changed

1 file changed

+33
-21
lines changed

domhmm/analysis/domhmm.py

Lines changed: 33 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -406,7 +406,7 @@ def _conclude(self):
406406
self.predict_states()
407407

408408
# Validate states and result prediction
409-
self.state_validate()
409+
# self.state_validate()
410410
if self.result_plots:
411411
# Plot prediction result
412412
self.predict_plot()
@@ -535,7 +535,11 @@ def GMM(self, gmm_kwargs):
535535
log.info(f"Leaflet {leaflet}, {res} Gaussian Mixture Model is trained.")
536536
self.results["GMM"][res] = temp_dict
537537
else:
538-
gmm = mixture.GaussianMixture(n_components=2, **gmm_kwargs).fit(data[1].reshape(-1, data[1].shape[2]))
538+
gmm_data = data[1]
539+
features = gmm_data.shape[2]
540+
if res in self.sterol_heads.keys():
541+
gmm_data = gmm_data[~np.isnan(gmm_data)]
542+
gmm = mixture.GaussianMixture(n_components=2, **gmm_kwargs).fit(gmm_data.reshape(-1, features))
539543
self.results["GMM"][res] = gmm
540544
log.info(f"{res} Gaussian Mixture Model is trained.")
541545

@@ -838,7 +842,11 @@ def HMM(self, hmm_kwargs):
838842
log.info(f"Leaflet {leaflet}, {resname} Gaussian Hidden Markov Model is trained.")
839843
self.results["HMM"][resname] = temp_dict
840844
else:
841-
hmm = self.fit_hmm(data=data[1], gmm=self.results["GMM"][resname], hmm_kwargs=hmm_kwargs,
845+
hmm_data = data[1]
846+
if resname in self.sterol_heads.keys():
847+
features = hmm_data.shape[2]
848+
hmm_data = hmm_data[~np.isnan(hmm_data)].reshape(-1, features)
849+
hmm = self.fit_hmm(data=hmm_data, gmm=self.results["GMM"][resname], hmm_kwargs=hmm_kwargs,
842850
n_repeats=self.n_init_hmm)
843851
self.results["HMM"][resname] = hmm
844852
log.info(f"{resname} Gaussian Hidden Markov Model is trained.")
@@ -848,7 +856,7 @@ def HMM(self, hmm_kwargs):
848856
# Make predictions based on HMM model
849857
self.predict_states()
850858
# Validate states and result prediction
851-
self.state_validate()
859+
# self.state_validate()
852860
if self.result_plots:
853861
# Plot prediction result
854862
self.predict_plot()
@@ -1035,10 +1043,6 @@ def predict_states(self):
10351043
mask_flats = np.isnan(data[:, :, 0])
10361044
# Just assign 0 to all NaNs and change prediction to 0 (disordered) later
10371045
data = np.nan_to_num(data, nan=0)
1038-
# data[:,:,0][mask_apl] = 200.0
1039-
# # Changing scc NaN to -2 for disordered prediction
1040-
# mask_scc = np.isnan(data[:, :, 1])
1041-
# data[:,:,1][mask_scc] = -2.0
10421046
lengths = np.repeat(shape[1], shape[0])
10431047
prediction = hmm.predict(data.reshape(-1, shape[2]), lengths=lengths).reshape(shape[0],
10441048
shape[1])
@@ -1064,25 +1068,33 @@ def predict_states(self):
10641068
else:
10651069
# Symmetric membrane case
10661070
for resname, data in self.results.train_data_per_type.items():
1067-
shape = data[1].shape
1071+
predict_data = data[1]
1072+
shape = predict_data.shape
10681073
hmm = self.results['HMM'][resname]
1074+
# Mask entries whose first feature (APL) is NaN, i.e. flat sterols, so their prediction can be set to disordered below
1075+
mask_flats = np.isnan(predict_data[:, :, 0])
1076+
# Just assign 0 to all NaNs and change prediction to 0 (disordered) later
1077+
predict_data = np.nan_to_num(predict_data, nan=0)
10691078
# Lengths holds shape[0] entries, each equal to shape[1] (built from the number of frames and the number of residues)
10701079
lengths = np.repeat(shape[1], shape[0])
1071-
prediction = hmm.predict(data[1].reshape(-1, shape[2]), lengths=lengths).reshape(shape[0], shape[1])
1080+
prediction = hmm.predict(predict_data.reshape(-1, shape[2]), lengths=lengths).reshape(shape[0], shape[1])
1081+
prediction = self.hmm_diff_checker(hmm.means_, prediction)
1082+
# Change flat sterol predictions to 0 (disordered)
1083+
prediction[mask_flats] = 0
10721084
# Save prediction result of each residue
10731085
self.results['HMM_Pred'][resname] = prediction
10741086

1075-
def state_validate(self):
1076-
"""
1077-
Validate state assignments of HMM model by checking means of the model of each residue.
1078-
"""
1079-
if not self.asymmetric_membrane:
1080-
# Asymmetric membrane validation is done in prediction step due to nature of it
1081-
for resname, hmm in self.results["HMM"].items():
1082-
if hmm is not None:
1083-
means = hmm.means_
1084-
prediction_results = self.results['HMM_Pred'][resname]
1085-
self.results['HMM_Pred'][resname] = self.hmm_diff_checker(means, prediction_results)
1087+
# def state_validate(self):
1088+
# """
1089+
# Validate state assignments of HMM model by checking means of the model of each residue.
1090+
# """
1091+
# if not self.asymmetric_membrane:
1092+
# # Asymmetric membrane validation is done in prediction step due to nature of it
1093+
# for resname, hmm in self.results["HMM"].items():
1094+
# if hmm is not None:
1095+
# means = hmm.means_
1096+
# prediction_results = self.results['HMM_Pred'][resname]
1097+
# self.results['HMM_Pred'][resname] = self.hmm_diff_checker(means, prediction_results)
10861098

10871099
@staticmethod
10881100
def hmm_diff_checker(means, prediction_results):

0 commit comments

Comments
 (0)