@@ -406,7 +406,7 @@ def _conclude(self):
406406 self .predict_states ()
407407
408408 # Validate states and result prediction
409- self .state_validate ()
409+ # self.state_validate()
410410 if self .result_plots :
411411 # Plot prediction result
412412 self .predict_plot ()
@@ -535,7 +535,11 @@ def GMM(self, gmm_kwargs):
535535 log .info (f"Leaflet { leaflet } , { res } Gaussian Mixture Model is trained." )
536536 self .results ["GMM" ][res ] = temp_dict
537537 else :
538- gmm = mixture .GaussianMixture (n_components = 2 , ** gmm_kwargs ).fit (data [1 ].reshape (- 1 , data [1 ].shape [2 ]))
538+ gmm_data = data [1 ]
539+ features = gmm_data .shape [2 ]
540+ if res in self .sterol_heads .keys ():
541+ gmm_data = gmm_data [~ np .isnan (gmm_data )]
542+ gmm = mixture .GaussianMixture (n_components = 2 , ** gmm_kwargs ).fit (gmm_data .reshape (- 1 , features ))
539543 self .results ["GMM" ][res ] = gmm
540544 log .info (f"{ res } Gaussian Mixture Model is trained." )
541545
@@ -838,7 +842,11 @@ def HMM(self, hmm_kwargs):
838842 log .info (f"Leaflet { leaflet } , { resname } Gaussian Hidden Markov Model is trained." )
839843 self .results ["HMM" ][resname ] = temp_dict
840844 else :
841- hmm = self .fit_hmm (data = data [1 ], gmm = self .results ["GMM" ][resname ], hmm_kwargs = hmm_kwargs ,
845+ hmm_data = data [1 ]
846+ if resname in self .sterol_heads .keys ():
847+ features = hmm_data .shape [2 ]
848+ hmm_data = hmm_data [~ np .isnan (hmm_data )].reshape (- 1 , features )
849+ hmm = self .fit_hmm (data = hmm_data , gmm = self .results ["GMM" ][resname ], hmm_kwargs = hmm_kwargs ,
842850 n_repeats = self .n_init_hmm )
843851 self .results ["HMM" ][resname ] = hmm
844852 log .info (f"{ resname } Gaussian Hidden Markov Model is trained." )
@@ -848,7 +856,7 @@ def HMM(self, hmm_kwargs):
848856 # Make predictions based on HMM model
849857 self .predict_states ()
850858 # Validate states and result prediction
851- self .state_validate ()
859+ # self.state_validate()
852860 if self .result_plots :
853861 # Plot prediction result
854862 self .predict_plot ()
@@ -1035,10 +1043,6 @@ def predict_states(self):
10351043 mask_flats = np .isnan (data [:, :, 0 ])
10361044 # Just assign 0 to all NaNs and change prediction to 0 (disordered) later
10371045 data = np .nan_to_num (data , nan = 0 )
1038- # data[:,:,0][mask_apl] = 200.0
1039- # # Changing scc NaN to -2 for disordered prediction
1040- # mask_scc = np.isnan(data[:, :, 1])
1041- # data[:,:,1][mask_scc] = -2.0
10421046 lengths = np .repeat (shape [1 ], shape [0 ])
10431047 prediction = hmm .predict (data .reshape (- 1 , shape [2 ]), lengths = lengths ).reshape (shape [0 ],
10441048 shape [1 ])
@@ -1064,25 +1068,33 @@ def predict_states(self):
10641068 else :
10651069 # Symmetric membrane case
10661070 for resname , data in self .results .train_data_per_type .items ():
1067- shape = data [1 ].shape
1071+ predict_data = data [1 ]
1072+ shape = predict_data .shape
10681073 hmm = self .results ['HMM' ][resname ]
1074+ # Changing APL NaN to 200 for disordered prediction
1075+ mask_flats = np .isnan (predict_data [:, :, 0 ])
1076+ # Just assign 0 to all NaNs and change prediction to 0 (disordered) later
1077+ predict_data = np .nan_to_num (predict_data , nan = 0 )
10691078 # Lengths consists of number of frames and number of residues
10701079 lengths = np .repeat (shape [1 ], shape [0 ])
1071- prediction = hmm .predict (data [1 ].reshape (- 1 , shape [2 ]), lengths = lengths ).reshape (shape [0 ], shape [1 ])
1080+ prediction = hmm .predict (predict_data .reshape (- 1 , shape [2 ]), lengths = lengths ).reshape (shape [0 ], shape [1 ])
1081+ prediction = self .hmm_diff_checker (hmm .means_ , prediction )
1082+ # Change flat sterol predictions to 0 (disordered)
1083+ prediction [mask_flats ] = 0
10721084 # Save prediction result of each residue
10731085 self .results ['HMM_Pred' ][resname ] = prediction
10741086
1075- def state_validate (self ):
1076- """
1077- Validate state assignments of HMM model by checking means of the model of each residue.
1078- """
1079- if not self .asymmetric_membrane :
1080- # Asymmetric membrane validation is done in prediction step due to nature of it
1081- for resname , hmm in self .results ["HMM" ].items ():
1082- if hmm is not None :
1083- means = hmm .means_
1084- prediction_results = self .results ['HMM_Pred' ][resname ]
1085- self .results ['HMM_Pred' ][resname ] = self .hmm_diff_checker (means , prediction_results )
1087+ # def state_validate(self):
1088+ # """
1089+ # Validate state assignments of HMM model by checking means of the model of each residue.
1090+ # """
1091+ # if not self.asymmetric_membrane:
1092+ # # Asymmetric membrane validation is done in prediction step due to nature of it
1093+ # for resname, hmm in self.results["HMM"].items():
1094+ # if hmm is not None:
1095+ # means = hmm.means_
1096+ # prediction_results = self.results['HMM_Pred'][resname]
1097+ # self.results['HMM_Pred'][resname] = self.hmm_diff_checker(means, prediction_results)
10861098
10871099 @staticmethod
10881100 def hmm_diff_checker (means , prediction_results ):
0 commit comments