@@ -815,59 +815,75 @@ def build(self, psms: List) -> None:
815815 f"Insufficient high-scoring PSMs to train CID model (need at least { self .config .get ('min_num_psms_model' , 50 )} , actual { len (psms )} )"
816816 )
817817
818- # Group PSMs by charge state
819- charge_psms = defaultdict (list )
820- for psm in psms :
821- charge_psms [psm .charge ].append (psm )
822-
823- # Check if each charge state has sufficient PSMs for modeling
824- min_psms_per_charge = self .config .get ("min_num_psms_model" , 50 )
825- bad_charges = set ()
826-
827- logger .info ("PSMs for modeling:" )
828- logger .info ("------------------" )
829- for charge in sorted (charge_psms .keys ()):
830- n_psms = len (charge_psms [charge ])
831- logger .info (f"+{ charge } : { n_psms } PSMs" )
832- if n_psms < min_psms_per_charge :
833- bad_charges .add (charge )
834- logger .warning (
835- f"Charge state +{ charge } has insufficient PSMs for modeling (need { min_psms_per_charge } , got { n_psms } )"
836- )
837-
838- # Remove charge states with insufficient PSMs
839- for bad_charge in bad_charges :
840- del charge_psms [bad_charge ]
841-
842- if not charge_psms :
843- raise RuntimeError (
844- f"No charge states have sufficient PSMs for modeling (minimum { min_psms_per_charge } PSMs per charge state required)"
845- )
818+ # Check if charge splitting is disabled
819+ disable_split_by_charge = self .config .get ("disable_split_by_charge" , False )
846820
847- logger .info (
848- f"Will build models for charge states: { sorted (charge_psms .keys ())} "
849- )
821+ if disable_split_by_charge :
822+ # Train a single global model using all PSMs
823+ logger .info ("Training global CID model (charge splitting disabled)" )
824+ logger .info (f"Total PSMs for modeling: { len (psms )} " )
850825
851- # Get thread count configuration
852- num_threads = self .config .get ("num_threads" , os .cpu_count () or 4 )
853- logger .info (f"Using { num_threads } threads for CID model training..." )
826+ # Use a dummy charge state (e.g., 0) to represent the global model
827+ global_charge = 0
828+ charge , model_data = self ._build_charge_model (global_charge , psms )
829+ self .charge_models [global_charge ] = model_data
854830
855- # Parallel processing for each charge state
856- with ThreadPoolExecutor (max_workers = num_threads ) as executor :
857- futures = []
858- for charge , charge_psm_list in charge_psms .items ():
859- future = executor .submit (
860- self ._build_charge_model , charge , charge_psm_list
831+ logger .info (f"Global CID model trained with { len (psms )} PSMs" )
832+ else :
833+ # Original behavior: split by charge state
834+ # Group PSMs by charge state
835+ charge_psms = defaultdict (list )
836+ for psm in psms :
837+ charge_psms [psm .charge ].append (psm )
838+
839+ # Check if each charge state has sufficient PSMs for modeling
840+ min_psms_per_charge = self .config .get ("min_num_psms_model" , 50 )
841+ bad_charges = set ()
842+
843+ logger .info ("PSMs for modeling:" )
844+ logger .info ("------------------" )
845+ for charge in sorted (charge_psms .keys ()):
846+ n_psms = len (charge_psms [charge ])
847+ logger .info (f"+{ charge } : { n_psms } PSMs" )
848+ if n_psms < min_psms_per_charge :
849+ bad_charges .add (charge )
850+ logger .warning (
851+ f"Charge state +{ charge } has insufficient PSMs for modeling (need { min_psms_per_charge } , got { n_psms } )"
852+ )
853+
854+ # Remove charge states with insufficient PSMs
855+ for bad_charge in bad_charges :
856+ del charge_psms [bad_charge ]
857+
858+ if not charge_psms :
859+ raise RuntimeError (
860+ f"No charge states have sufficient PSMs for modeling (minimum { min_psms_per_charge } PSMs per charge state required)"
861861 )
862- futures .append (future )
863862
864- # Wait for all tasks to complete
865- for future in as_completed (futures ):
866- try :
867- charge , model_data = future .result ()
868- self .charge_models [charge ] = model_data
869- except Exception as e :
870- logger .error (f"CID model training error: { str (e )} " )
863+ logger .info (
864+ f"Will build models for charge states: { sorted (charge_psms .keys ())} "
865+ )
866+
867+ # Get thread count configuration
868+ num_threads = self .config .get ("num_threads" , os .cpu_count () or 4 )
869+ logger .info (f"Using { num_threads } threads for CID model training..." )
870+
871+ # Parallel processing for each charge state
872+ with ThreadPoolExecutor (max_workers = num_threads ) as executor :
873+ futures = []
874+ for charge , charge_psm_list in charge_psms .items ():
875+ future = executor .submit (
876+ self ._build_charge_model , charge , charge_psm_list
877+ )
878+ futures .append (future )
879+
880+ # Wait for all tasks to complete
881+ for future in as_completed (futures ):
882+ try :
883+ charge , model_data = future .result ()
884+ self .charge_models [charge ] = model_data
885+ except Exception as e :
886+ logger .error (f"CID model training error: { str (e )} " )
871887
872888 def _build_charge_model (
873889 self , charge : int , charge_psm_list : List
@@ -937,6 +953,11 @@ def get_charge_model(self, charge: int) -> Optional[ModelData_CID]:
937953 Returns:
938954 ModelData_CID instance, returns None if not found
939955 """
956+ # Check if using global model (charge splitting disabled)
957+ if 0 in self .charge_models :
958+ # Global model exists, use it for all charges
959+ return self .charge_models [0 ]
960+
940961 # First try to get the exact charge model
941962 if charge in self .charge_models :
942963 return self .charge_models [charge ]
@@ -1057,47 +1078,63 @@ def build(self, psms: List) -> None:
10571078 f"Insufficient high-scoring PSMs to train HCD model (need at least { self .config .get ('min_num_psms_model' , 50 )} , actual { len (psms )} )"
10581079 )
10591080
1060- charge_psms = defaultdict (list )
1061- for psm in psms :
1062- charge_psms [psm .charge ].append (psm )
1063-
1064- # Check if each charge state has sufficient PSMs for modeling
1065- min_psms_per_charge = self .config .get ("min_num_psms_model" , 50 )
1066- bad_charges = set ()
1067-
1068- logger .info ("PSMs for modeling:" )
1069- logger .info ("------------------" )
1070- for charge in sorted (charge_psms .keys ()):
1071- n_psms = len (charge_psms [charge ])
1072- logger .info (f"+{ charge } : { n_psms } PSMs" )
1073- if n_psms < min_psms_per_charge :
1074- bad_charges .add (charge )
1075- logger .warning (
1076- f"Charge state +{ charge } has insufficient PSMs for modeling (need { min_psms_per_charge } , got { n_psms } )"
1077- )
1081+ # Check if charge splitting is disabled
1082+ disable_split_by_charge = self .config .get ("disable_split_by_charge" , False )
10781083
1079- # Remove charge states with insufficient PSMs
1080- for bad_charge in bad_charges :
1081- del charge_psms [bad_charge ]
1084+ if disable_split_by_charge :
1085+ # Train a single global model using all PSMs
1086+ logger .info ("Training global HCD model (charge splitting disabled)" )
1087+ logger .info (f"Total PSMs for modeling: { len (psms )} " )
10821088
1083- if not charge_psms :
1084- raise RuntimeError (
1085- f"No charge states have sufficient PSMs for modeling (minimum { min_psms_per_charge } PSMs per charge state required)"
1086- )
1089+ # Use a dummy charge state (e.g., 0) to represent the global model
1090+ global_charge = 0
1091+ charge , model_data = self . _build_charge_model ( global_charge , psms )
1092+ self . charge_models [ global_charge ] = model_data
10871093
1088- logger .info (
1089- f"Will build models for charge states: { sorted (charge_psms .keys ())} "
1090- )
1094+ logger .info (f"Global HCD model trained with { len (psms )} PSMs" )
1095+ else :
1096+ # Original behavior: split by charge state
1097+ charge_psms = defaultdict (list )
1098+ for psm in psms :
1099+ charge_psms [psm .charge ].append (psm )
1100+
1101+ # Check if each charge state has sufficient PSMs for modeling
1102+ min_psms_per_charge = self .config .get ("min_num_psms_model" , 50 )
1103+ bad_charges = set ()
1104+
1105+ logger .info ("PSMs for modeling:" )
1106+ logger .info ("------------------" )
1107+ for charge in sorted (charge_psms .keys ()):
1108+ n_psms = len (charge_psms [charge ])
1109+ logger .info (f"+{ charge } : { n_psms } PSMs" )
1110+ if n_psms < min_psms_per_charge :
1111+ bad_charges .add (charge )
1112+ logger .warning (
1113+ f"Charge state +{ charge } has insufficient PSMs for modeling (need { min_psms_per_charge } , got { n_psms } )"
1114+ )
1115+
1116+ # Remove charge states with insufficient PSMs
1117+ for bad_charge in bad_charges :
1118+ del charge_psms [bad_charge ]
1119+
1120+ if not charge_psms :
1121+ raise RuntimeError (
1122+ f"No charge states have sufficient PSMs for modeling (minimum { min_psms_per_charge } PSMs per charge state required)"
1123+ )
10911124
1092- num_threads = self .config .get ("num_threads" , os .cpu_count () or 4 )
1093- logger .info (f"Using { num_threads } threads for HCD model training..." )
1125+ logger .info (
1126+ f"Will build models for charge states: { sorted (charge_psms .keys ())} "
1127+ )
10941128
1095- with ThreadPoolExecutor (max_workers = num_threads ) as executor :
1096- futures = []
1097- for charge , charge_psm_list in charge_psms .items ():
1098- future = executor .submit (
1099- self ._build_charge_model , charge , charge_psm_list
1100- )
1129+ num_threads = self .config .get ("num_threads" , os .cpu_count () or 4 )
1130+ logger .info (f"Using { num_threads } threads for HCD model training..." )
1131+
1132+ with ThreadPoolExecutor (max_workers = num_threads ) as executor :
1133+ futures = []
1134+ for charge , charge_psm_list in charge_psms .items ():
1135+ future = executor .submit (
1136+ self ._build_charge_model , charge , charge_psm_list
1137+ )
11011138 futures .append (future )
11021139
11031140 for future in as_completed (futures ):
@@ -1177,6 +1214,11 @@ def get_charge_model(self, charge: int) -> Optional[ModelData_HCD]:
11771214 Returns:
11781215 ModelData_HCD instance, returns None if not found
11791216 """
1217+ # Check if using global model (charge splitting disabled)
1218+ if 0 in self .charge_models :
1219+ # Global model exists, use it for all charges
1220+ return self .charge_models [0 ]
1221+
11801222 # First try to get the exact charge model
11811223 if charge in self .charge_models :
11821224 return self .charge_models [charge ]
0 commit comments