Skip to content

Commit 4f627da

Browse files
authored
Yn ddp1 (#326)
* [style] highlight warning logs * [feat] use a linear head for the classification head * [chore] update log file name * [perf] Add dropout for classification tasks. Set dropout=0 for regression * [feat] DDP for train * [style] isort + black * [fix] DDP predict did not return results in the correct order and count * [fix] DDP caused redundant log files * [feat] add DDP API for MolTrain: use_ddp * [fix] UniMolRepr API changed from use_gpu to use_cuda, to accommodate upcoming DDP support for the predict and repr modules * [feat] set default skewed scaler to log1p
1 parent 781bf59 commit 4f627da

30 files changed

+1793
-867
lines changed
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
1-
from .train import MolTrain
21
from .predict import MolPredict
3-
from .predictor import UniMolRepr
2+
from .predictor import UniMolRepr
3+
from .train import MolTrain
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
from .model_config import MODEL_CONFIG, MODEL_CONFIG_V2
1+
from .model_config import MODEL_CONFIG, MODEL_CONFIG_V2

unimol_tools/unimol_tools/config/default.yaml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,5 +17,7 @@ learning_rate: 1e-4
1717
warmup_ratio: 0.03
1818
batch_size: 16
1919
max_norm: 5.0
20-
cuda: True
21-
amp: True
20+
use_cuda: True
21+
use_amp: True
22+
use_ddp: True
23+
use_gpu: 0, 1

unimol_tools/unimol_tools/config/model_config.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
MODEL_CONFIG = {
2-
"weight":{
2+
"weight": {
33
"protein": "poc_pre_220816.pt",
44
"molecule_no_h": "mol_pre_no_h_220816.pt",
55
"molecule_all_h": "mol_pre_all_h_220816.pt",
66
"crystal": "mp_all_h_230313.pt",
77
"oled": "oled_pre_no_h_230101.pt",
88
},
9-
"dict":{
9+
"dict": {
1010
"protein": "poc.dict.txt",
1111
"molecule_no_h": "mol.dict.txt",
1212
"molecule_all_h": "mol.dict.txt",
@@ -23,4 +23,4 @@
2323
'570m': 'modelzoo/570M/checkpoint.pt',
2424
'1.1B': 'modelzoo/1.1B/checkpoint.pt',
2525
},
26-
}
26+
}
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
11
from .datahub import DataHub
2-
from .dictionary import Dictionary
2+
from .dictionary import Dictionary

unimol_tools/unimol_tools/data/conformer.py

Lines changed: 184 additions & 78 deletions
Large diffs are not rendered by default.

unimol_tools/unimol_tools/data/datahub.py

Lines changed: 44 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -3,20 +3,23 @@
33
# LICENSE file in the root directory of this source tree.
44

55
from __future__ import absolute_import, division, print_function
6+
67
import numpy as np
8+
9+
from ..utils import logger
10+
from .conformer import ConformerGen, UniMolV2Feature
711
from .datareader import MolDataReader
812
from .datascaler import TargetScaler
9-
from .conformer import ConformerGen, UniMolV2Feature
1013
from .split import Splitter
11-
from ..utils import logger
1214

1315

1416
class DataHub(object):
1517
"""
1618
The DataHub class is responsible for storing and preprocessing data for machine learning tasks.
17-
It initializes with configuration options to handle different types of tasks such as regression,
19+
It initializes with configuration options to handle different types of tasks such as regression,
1820
classification, and others. It also supports data scaling and handling molecular data.
1921
"""
22+
2023
def __init__(self, data=None, is_train=True, save_path=None, **params):
2124
"""
2225
Initializes the DataHub instance with data and configuration for the ML task.
@@ -35,44 +38,54 @@ def __init__(self, data=None, is_train=True, save_path=None, **params):
3538
self.ss_method = params.get('target_normalize', 'none')
3639
self._init_data(**params)
3740
self._init_split(**params)
38-
41+
3942
def _init_data(self, **params):
4043
"""
4144
Initializes and preprocesses the data based on the task and parameters provided.
4245
43-
This method handles reading raw data, scaling targets, and transforming data for use with
44-
molecular inputs. It tailors the preprocessing steps based on the task type, such as regression
46+
This method handles reading raw data, scaling targets, and transforming data for use with
47+
molecular inputs. It tailors the preprocessing steps based on the task type, such as regression
4548
or classification.
4649
4750
:param params: Additional parameters for data processing.
4851
:raises ValueError: If the task type is unknown.
4952
"""
5053
self.data = MolDataReader().read_data(self.data, self.is_train, **params)
51-
self.data['target_scaler'] = TargetScaler(self.ss_method, self.task, self.save_path)
52-
if self.task == 'regression':
53-
target = np.array(self.data['raw_target']).reshape(-1,1).astype(np.float32)
54+
self.data['target_scaler'] = TargetScaler(
55+
self.ss_method, self.task, self.save_path
56+
)
57+
if self.task == 'regression':
58+
target = np.array(self.data['raw_target']).reshape(-1, 1).astype(np.float32)
5459
if self.is_train:
5560
self.data['target_scaler'].fit(target, self.save_path)
5661
self.data['target'] = self.data['target_scaler'].transform(target)
5762
else:
5863
self.data['target'] = target
5964
elif self.task == 'classification':
60-
target = np.array(self.data['raw_target']).reshape(-1,1).astype(np.int32)
65+
target = np.array(self.data['raw_target']).reshape(-1, 1).astype(np.int32)
6166
self.data['target'] = target
62-
elif self.task =='multiclass':
63-
target = np.array(self.data['raw_target']).reshape(-1,1).astype(np.int32)
67+
elif self.task == 'multiclass':
68+
target = np.array(self.data['raw_target']).reshape(-1, 1).astype(np.int32)
6469
self.data['target'] = target
6570
if not self.is_train:
66-
self.data['multiclass_cnt'] = self.multiclass_cnt
71+
self.data['multiclass_cnt'] = self.multiclass_cnt
6772
elif self.task == 'multilabel_regression':
68-
target = np.array(self.data['raw_target']).reshape(-1,self.data['num_classes']).astype(np.float32)
73+
target = (
74+
np.array(self.data['raw_target'])
75+
.reshape(-1, self.data['num_classes'])
76+
.astype(np.float32)
77+
)
6978
if self.is_train:
7079
self.data['target_scaler'].fit(target, self.save_path)
71-
self.data['target'] = self.data['target_scaler'].transform(target)
80+
self.data['target'] = self.data['target_scaler'].transform(target)
7281
else:
7382
self.data['target'] = target
7483
elif self.task == 'multilabel_classification':
75-
target = np.array(self.data['raw_target']).reshape(-1,self.data['num_classes']).astype(np.int32)
84+
target = (
85+
np.array(self.data['raw_target'])
86+
.reshape(-1, self.data['num_classes'])
87+
.astype(np.int32)
88+
)
7689
self.data['target'] = target
7790
elif self.task == 'repr':
7891
self.data['target'] = self.data['raw_target']
@@ -81,23 +94,30 @@ def _init_data(self, **params):
8194

8295
if params.get('model_name', None) == 'unimolv1':
8396
if 'atoms' in self.data and 'coordinates' in self.data:
84-
no_h_list = ConformerGen(**params).transform_raw(self.data['atoms'], self.data['coordinates'])
97+
no_h_list = ConformerGen(**params).transform_raw(
98+
self.data['atoms'], self.data['coordinates']
99+
)
85100
else:
86-
smiles_list = self.data["smiles"]
101+
smiles_list = self.data["smiles"]
87102
no_h_list = ConformerGen(**params).transform(smiles_list)
88103
elif params.get('model_name', None) == 'unimolv2':
89104
if 'atoms' in self.data and 'coordinates' in self.data:
90-
no_h_list = UniMolV2Feature(**params).transform_raw(self.data['atoms'], self.data['coordinates'])
105+
no_h_list = UniMolV2Feature(**params).transform_raw(
106+
self.data['atoms'], self.data['coordinates']
107+
)
91108
else:
92-
smiles_list = self.data["smiles"]
109+
smiles_list = self.data["smiles"]
93110
no_h_list = UniMolV2Feature(**params).transform(smiles_list)
94111

95112
self.data['unimol_input'] = no_h_list
96113

97114
def _init_split(self, **params):
98115

99-
self.split_method = params.get('split_method','5fold_random')
100-
kfold, method = int(self.split_method.split('fold')[0]), self.split_method.split('_')[-1] # Nfold_xxxx
116+
self.split_method = params.get('split_method', '5fold_random')
117+
kfold, method = (
118+
int(self.split_method.split('fold')[0]),
119+
self.split_method.split('_')[-1],
120+
) # Nfold_xxxx
101121
self.kfold = params.get('kfold', kfold)
102122
self.method = params.get('split', method)
103123
self.split_seed = params.get('split_seed', 42)
@@ -110,8 +130,8 @@ def _init_split(self, **params):
110130
logger.info(f"Kfold is 1, all data is used for training.")
111131
else:
112132
logger.info(f"Split method: {self.method}, fold: {self.kfold}")
113-
nfolds = np.zeros(len(split_nfolds[0][0])+len(split_nfolds[0][1]), dtype=int)
133+
nfolds = np.zeros(len(split_nfolds[0][0]) + len(split_nfolds[0][1]), dtype=int)
114134
for enu, (tr_idx, te_idx) in enumerate(split_nfolds):
115135
nfolds[te_idx] = enu
116136
self.data['split_nfolds'] = split_nfolds
117-
return split_nfolds
137+
return split_nfolds

unimol_tools/unimol_tools/data/datareader.py

Lines changed: 52 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -5,17 +5,21 @@
55
from __future__ import absolute_import, division, print_function
66

77
import os
8-
import pandas as pd
8+
import pathlib
9+
910
import numpy as np
11+
import pandas as pd
1012
from rdkit import Chem
11-
from ..utils import logger
12-
import pathlib
1313
from rdkit.Chem.Scaffolds import MurckoScaffold
1414

15+
from ..utils import logger
16+
17+
1518
class MolDataReader(object):
1619
'''A class to read Mol Data.'''
20+
1721
def read_data(self, data=None, is_train=True, **params):
18-
# TO DO
22+
# TO DO
1923
# 1. add anomaly detection & outlier removal.
2024
# 2. add support for other file format.
2125
# 3. add support for multi tasks.
@@ -26,7 +30,7 @@ def read_data(self, data=None, is_train=True, **params):
2630
1. if target_cols is not None, use target_cols as target columns.
2731
2. if target_cols is None, use all columns with prefix 'target_col_prefix' as target columns.
2832
3. for prediction, any given target_cols that are missing or contain nulls are filled with the placeholder value -1.0.
29-
33+
3034
:param data: The input molecular data. Can be a file path (str), a dictionary, or a list of SMILES strings.
3135
:param is_train: (bool) A flag indicating if the operation is for training. Determines data processing steps.
3236
:param params: A dictionary of additional parameters for data processing.
@@ -50,21 +54,21 @@ def read_data(self, data=None, is_train=True, **params):
5054
# load from dict
5155
if 'target' in data:
5256
label = np.array(data['target'])
53-
if len(label.shape)==1 or label.shape[1] == 1:
57+
if len(label.shape) == 1 or label.shape[1] == 1:
5458
data[target_col_prefix] = label.reshape(-1)
5559
else:
5660
for i in range(label.shape[1]):
57-
data[target_col_prefix + str(i)] = label[:,i]
61+
data[target_col_prefix + str(i)] = label[:, i]
5862

5963
_ = data.pop('target', None)
6064
data = pd.DataFrame(data).rename(columns={smiles_col: 'SMILES'})
61-
65+
6266
elif isinstance(data, list) or isinstance(data, np.ndarray):
6367
# load from smiles list
6468
data = pd.DataFrame(data, columns=['SMILES'])
6569
else:
6670
raise ValueError('Unknown data type: {}'.format(type(data)))
67-
71+
6872
#### parsing target columns
6973
#### 1. if target_cols is not None, use target_cols as target columns.
7074
#### 2. if target_cols is None, use all columns with prefix 'target_col_prefix' as target columns.
@@ -77,37 +81,45 @@ def read_data(self, data=None, is_train=True, **params):
7781
multiclass_cnt = None
7882
else:
7983
if target_cols is None:
80-
target_cols = [item for item in data.columns if item.startswith(target_col_prefix)]
84+
target_cols = [
85+
item for item in data.columns if item.startswith(target_col_prefix)
86+
]
8187
elif isinstance(target_cols, str):
8288
target_cols = target_cols.split(',')
8389
elif isinstance(target_cols, list):
8490
pass
8591
else:
86-
raise ValueError('Unknown target_cols type: {}'.format(type(target_cols)))
87-
92+
raise ValueError(
93+
'Unknown target_cols type: {}'.format(type(target_cols))
94+
)
95+
8896
if is_train:
8997
if anomaly_clean:
90-
data = self.anomaly_clean(data, task, target_cols)
98+
data = self.anomaly_clean(data, task, target_cols)
9199
if task == 'multiclass':
92100
multiclass_cnt = int(data[target_cols].max() + 1)
93101
else:
94102
for col in target_cols:
95103
if col not in data.columns or data[col].isnull().any():
96104
data[col] = -1.0
97-
105+
98106
targets = data[target_cols].values.tolist()
99107
num_classes = len(target_cols)
100-
108+
101109
dd = {
102110
'raw_data': data,
103111
'raw_target': targets,
104112
'num_classes': num_classes,
105113
'target_cols': target_cols,
106-
'multiclass_cnt': multiclass_cnt if task == 'multiclass' and is_train else None
114+
'multiclass_cnt': (
115+
multiclass_cnt if task == 'multiclass' and is_train else None
116+
),
107117
}
108118
if smiles_col in data.columns:
109-
mask = data[smiles_col].apply(lambda smi: self.check_smiles(smi, is_train, smi_strict))
110-
data = data[mask]
119+
mask = data[smiles_col].apply(
120+
lambda smi: self.check_smiles(smi, is_train, smi_strict)
121+
)
122+
data = data[mask]
111123
dd['smiles'] = data[smiles_col].tolist()
112124
dd['scaffolds'] = data[smiles_col].map(self.smi2scaffold).tolist()
113125
else:
@@ -127,7 +139,7 @@ def read_data(self, data=None, is_train=True, **params):
127139

128140
return dd
129141

130-
def check_smiles(self,smi, is_train, smi_strict):
142+
def check_smiles(self, smi, is_train, smi_strict):
131143
"""
132144
Validates a SMILES string and decides whether it should be included based on training mode and strictness.
133145
@@ -144,9 +156,9 @@ def check_smiles(self,smi, is_train, smi_strict):
144156
return False
145157
else:
146158
raise ValueError(f'SMILES rule is illegal: {smi}')
147-
return True
148-
149-
def smi2scaffold(self,smi):
159+
return True
160+
161+
def smi2scaffold(self, smi):
150162
"""
151163
Converts a SMILES string to its corresponding scaffold.
152164
@@ -155,10 +167,12 @@ def smi2scaffold(self,smi):
155167
:return: (str) The scaffold of the SMILES string, or the original SMILES if conversion fails.
156168
"""
157169
try:
158-
return MurckoScaffold.MurckoScaffoldSmiles(smiles=smi, includeChirality=True)
170+
return MurckoScaffold.MurckoScaffoldSmiles(
171+
smiles=smi, includeChirality=True
172+
)
159173
except:
160174
return smi
161-
175+
162176
def anomaly_clean(self, data, task, target_cols):
163177
"""
164178
Performs anomaly cleaning on the data based on the specified task.
@@ -170,13 +184,18 @@ def anomaly_clean(self, data, task, target_cols):
170184
:return: (DataFrame) The cleaned dataset.
171185
:raises ValueError: If the provided task is not recognized.
172186
"""
173-
if task in ['classification', 'multiclass', 'multilabel_classification', 'multilabel_regression']:
187+
if task in [
188+
'classification',
189+
'multiclass',
190+
'multilabel_classification',
191+
'multilabel_regression',
192+
]:
174193
return data
175194
if task == 'regression':
176195
return self.anomaly_clean_regression(data, target_cols)
177196
else:
178197
raise ValueError('Unknown task: {}'.format(task))
179-
198+
180199
def anomaly_clean_regression(self, data, target_cols):
181200
"""
182201
Performs anomaly cleaning specifically for regression tasks using a 3-sigma threshold.
@@ -189,6 +208,11 @@ def anomaly_clean_regression(self, data, target_cols):
189208
sz = data.shape[0]
190209
target_col = target_cols[0]
191210
_mean, _std = data[target_col].mean(), data[target_col].std()
192-
data = data[(data[target_col] > _mean - 3 * _std) & (data[target_col] < _mean + 3 * _std)]
193-
logger.info('Anomaly clean with 3 sigma threshold: {} -> {}'.format(sz, data.shape[0]))
211+
data = data[
212+
(data[target_col] > _mean - 3 * _std)
213+
& (data[target_col] < _mean + 3 * _std)
214+
]
215+
logger.info(
216+
'Anomaly clean with 3 sigma threshold: {} -> {}'.format(sz, data.shape[0])
217+
)
194218
return data

0 commit comments

Comments
 (0)