55from __future__ import absolute_import , division , print_function
66
77import os
8- import pandas as pd
8+ import pathlib
9+
910import numpy as np
11+ import pandas as pd
1012from rdkit import Chem
11- from ..utils import logger
12- import pathlib
1313from rdkit .Chem .Scaffolds import MurckoScaffold
1414
15+ from ..utils import logger
16+
17+
1518class MolDataReader (object ):
1619 '''A class to read Mol Data.'''
20+
1721 def read_data (self , data = None , is_train = True , ** params ):
18- # TO DO
22+ # TO DO
1923 # 1. add anomaly detection & outlier removal.
2024 # 2. add support for other file format.
2125 # 3. add support for multi tasks.
@@ -26,7 +30,7 @@ def read_data(self, data=None, is_train=True, **params):
2630 1. if target_cols is not None, use target_cols as target columns.
2731 2. if target_cols is None, use all columns with prefix 'target_col_prefix' as target columns.
2832 3. use given target_cols as target columns placeholder with value -1.0 for predict
29-
33+
3034 :param data: The input molecular data. Can be a file path (str), a dictionary, or a list of SMILES strings.
3135 :param is_train: (bool) A flag indicating if the operation is for training. Determines data processing steps.
3236 :param params: A dictionary of additional parameters for data processing.
@@ -50,21 +54,21 @@ def read_data(self, data=None, is_train=True, **params):
5054 # load from dict
5155 if 'target' in data :
5256 label = np .array (data ['target' ])
53- if len (label .shape )== 1 or label .shape [1 ] == 1 :
57+ if len (label .shape ) == 1 or label .shape [1 ] == 1 :
5458 data [target_col_prefix ] = label .reshape (- 1 )
5559 else :
5660 for i in range (label .shape [1 ]):
57- data [target_col_prefix + str (i )] = label [:,i ]
61+ data [target_col_prefix + str (i )] = label [:, i ]
5862
5963 _ = data .pop ('target' , None )
6064 data = pd .DataFrame (data ).rename (columns = {smiles_col : 'SMILES' })
61-
65+
6266 elif isinstance (data , list ) or isinstance (data , np .ndarray ):
6367 # load from smiles list
6468 data = pd .DataFrame (data , columns = ['SMILES' ])
6569 else :
6670 raise ValueError ('Unknown data type: {}' .format (type (data )))
67-
71+
6872 #### parsing target columns
6973 #### 1. if target_cols is not None, use target_cols as target columns.
7074 #### 2. if target_cols is None, use all columns with prefix 'target_col_prefix' as target columns.
@@ -77,37 +81,45 @@ def read_data(self, data=None, is_train=True, **params):
7781 multiclass_cnt = None
7882 else :
7983 if target_cols is None :
80- target_cols = [item for item in data .columns if item .startswith (target_col_prefix )]
84+ target_cols = [
85+ item for item in data .columns if item .startswith (target_col_prefix )
86+ ]
8187 elif isinstance (target_cols , str ):
8288 target_cols = target_cols .split (',' )
8389 elif isinstance (target_cols , list ):
8490 pass
8591 else :
86- raise ValueError ('Unknown target_cols type: {}' .format (type (target_cols )))
87-
92+ raise ValueError (
93+ 'Unknown target_cols type: {}' .format (type (target_cols ))
94+ )
95+
8896 if is_train :
8997 if anomaly_clean :
90- data = self .anomaly_clean (data , task , target_cols )
98+ data = self .anomaly_clean (data , task , target_cols )
9199 if task == 'multiclass' :
92100 multiclass_cnt = int (data [target_cols ].max () + 1 )
93101 else :
94102 for col in target_cols :
95103 if col not in data .columns or data [col ].isnull ().any ():
96104 data [col ] = - 1.0
97-
105+
98106 targets = data [target_cols ].values .tolist ()
99107 num_classes = len (target_cols )
100-
108+
101109 dd = {
102110 'raw_data' : data ,
103111 'raw_target' : targets ,
104112 'num_classes' : num_classes ,
105113 'target_cols' : target_cols ,
106- 'multiclass_cnt' : multiclass_cnt if task == 'multiclass' and is_train else None
114+ 'multiclass_cnt' : (
115+ multiclass_cnt if task == 'multiclass' and is_train else None
116+ ),
107117 }
108118 if smiles_col in data .columns :
109- mask = data [smiles_col ].apply (lambda smi : self .check_smiles (smi , is_train , smi_strict ))
110- data = data [mask ]
119+ mask = data [smiles_col ].apply (
120+ lambda smi : self .check_smiles (smi , is_train , smi_strict )
121+ )
122+ data = data [mask ]
111123 dd ['smiles' ] = data [smiles_col ].tolist ()
112124 dd ['scaffolds' ] = data [smiles_col ].map (self .smi2scaffold ).tolist ()
113125 else :
@@ -127,7 +139,7 @@ def read_data(self, data=None, is_train=True, **params):
127139
128140 return dd
129141
130- def check_smiles (self ,smi , is_train , smi_strict ):
142+ def check_smiles (self , smi , is_train , smi_strict ):
131143 """
132144 Validates a SMILES string and decides whether it should be included based on training mode and strictness.
133145
@@ -144,9 +156,9 @@ def check_smiles(self,smi, is_train, smi_strict):
144156 return False
145157 else :
146158 raise ValueError (f'SMILES rule is illegal: { smi } ' )
147- return True
148-
149- def smi2scaffold (self ,smi ):
159+ return True
160+
161+ def smi2scaffold (self , smi ):
150162 """
151163 Converts a SMILES string to its corresponding scaffold.
152164
@@ -155,10 +167,12 @@ def smi2scaffold(self,smi):
155167 :return: (str) The scaffold of the SMILES string, or the original SMILES if conversion fails.
156168 """
157169 try :
158- return MurckoScaffold .MurckoScaffoldSmiles (smiles = smi , includeChirality = True )
170+ return MurckoScaffold .MurckoScaffoldSmiles (
171+ smiles = smi , includeChirality = True
172+ )
159173 except :
160174 return smi
161-
175+
162176 def anomaly_clean (self , data , task , target_cols ):
163177 """
164178 Performs anomaly cleaning on the data based on the specified task.
@@ -170,13 +184,18 @@ def anomaly_clean(self, data, task, target_cols):
170184 :return: (DataFrame) The cleaned dataset.
171185 :raises ValueError: If the provided task is not recognized.
172186 """
173- if task in ['classification' , 'multiclass' , 'multilabel_classification' , 'multilabel_regression' ]:
187+ if task in [
188+ 'classification' ,
189+ 'multiclass' ,
190+ 'multilabel_classification' ,
191+ 'multilabel_regression' ,
192+ ]:
174193 return data
175194 if task == 'regression' :
176195 return self .anomaly_clean_regression (data , target_cols )
177196 else :
178197 raise ValueError ('Unknown task: {}' .format (task ))
179-
198+
180199 def anomaly_clean_regression (self , data , target_cols ):
181200 """
182201 Performs anomaly cleaning specifically for regression tasks using a 3-sigma threshold.
@@ -189,6 +208,11 @@ def anomaly_clean_regression(self, data, target_cols):
189208 sz = data .shape [0 ]
190209 target_col = target_cols [0 ]
191210 _mean , _std = data [target_col ].mean (), data [target_col ].std ()
192- data = data [(data [target_col ] > _mean - 3 * _std ) & (data [target_col ] < _mean + 3 * _std )]
193- logger .info ('Anomaly clean with 3 sigma threshold: {} -> {}' .format (sz , data .shape [0 ]))
211+ data = data [
212+ (data [target_col ] > _mean - 3 * _std )
213+ & (data [target_col ] < _mean + 3 * _std )
214+ ]
215+ logger .info (
216+ 'Anomaly clean with 3 sigma threshold: {} -> {}' .format (sz , data .shape [0 ])
217+ )
194218 return data
0 commit comments