Skip to content

Commit 57c03a3

Browse files
authored
Merge pull request #13 from deepmodeling/pretrain
This PR introduces lightweight pretraining functionality to Unimol_tools and adds Hydra-based command-line interfaces. It includes a complete pretraining pipeline with masking, loss functions, metrics aggregation, and distributed training support. The PR also updates the CLI tools with configuration management and adds pretrained model path parameters. Adds comprehensive pretraining infrastructure including model architecture, loss functions, dataset handling, and trainer implementation Integrates Hydra configuration management for command-line tools Introduces support for custom pretrained model and dictionary paths across training, representation, and prediction workflows
2 parents 382e111 + 9408949 commit 57c03a3

34 files changed

Lines changed: 3262 additions & 59 deletions

README.md

Lines changed: 144 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@ Unimol_tools is a easy-use wrappers for property prediction,representation and d
1919

2020
## Install
2121
- pytorch is required, please install pytorch according to your environment. if you are using cuda, please install pytorch with cuda. More details can be found at https://pytorch.org/get-started/locally/
22-
- currently, rdkit needs with numpy<2.0.0, please install rdkit with numpy<2.0.0.
2322

2423
### Option 1: Installing from PyPi (Recommended, for stable version)
2524

@@ -38,13 +37,13 @@ pip install huggingface_hub
3837
### Option 2: Installing from source (for latest version)
3938

4039
```python
41-
## Dependencies installation
42-
pip install -r requirements.txt
43-
4440
## Clone repository
4541
git clone https://github.com/deepmodeling/unimol_tools.git
4642
cd unimol_tools
4743

44+
## Dependencies installation
45+
pip install -r requirements.txt
46+
4847
## Install
4948
python setup.py install
5049
```
@@ -53,6 +52,10 @@ python setup.py install
5352

5453
The UniMol pretrained models can be found at [dptech/Uni-Mol-Models](https://huggingface.co/dptech/Uni-Mol-Models/tree/main).
5554

55+
If ``pretrained_model_path`` or ``pretrained_dict_path`` are left as ``None`` the
56+
toolkit will automatically download the corresponding files from this
57+
Hugging Face repository at runtime.
58+
5659
If the download is slow, you can use a mirror, such as:
5760

5861
```bash
@@ -70,6 +73,7 @@ export UNIMOL_WEIGHT_DIR=/path/to/your/weights/dir/
7073
```
7174

7275
## News
76+
- 2025-09-22: Lightweight pre-training tools are now available in Unimol_tools!
7377
- 2025-05-26: Unimol_tools is now independent from the Uni-Mol repository!
7478
- 2025-03-28: Unimol_tools now support Distributed Data Parallel (DDP)!
7579
- 2024-11-22: Unimol V2 has been added to Unimol_tools!
@@ -82,15 +86,40 @@ export UNIMOL_WEIGHT_DIR=/path/to/your/weights/dir/
8286
### Molecule property prediction
8387
```python
8488
from unimol_tools import MolTrain, MolPredict
85-
clf = MolTrain(task='classification',
86-
data_type='molecule',
87-
epochs=10,
88-
batch_size=16,
89-
metrics='auc',
90-
)
89+
clf = MolTrain(
90+
task='classification',
91+
data_type='molecule',
92+
epochs=10,
93+
batch_size=16,
94+
metrics='auc',
95+
# pretrained weights are downloaded automatically when left as ``None``
96+
# pretrained_model_path='/path/to/checkpoint.ckpt',
97+
# pretrained_dict_path='/path/to/dict.txt',
98+
)
9199
clf.fit(data = train_data)
92-
# currently support data with smiles based csv/txt file, and
93-
# custom dict of {'atoms':[['C','C'],['C','H','O']], 'coordinates':[coordinates_1,coordinates_2]}
100+
# currently support data with smiles based csv/txt file, and sdf file with mol,
101+
# and custom dict of {'atoms':[['C','C'],['C','H','O']], 'coordinates':[coordinates_1,coordinates_2]}
102+
103+
# The dict format can refer to the following format, or be obtained from sdf,
104+
# which can also be directly input into the model.
105+
train_sdf = PandasTools.LoadSDF('exp/unimol_conformers_train.sdf')
106+
train_dict = {
107+
'atoms': [list(atom.GetSymbol() for atom in mol.GetAtoms()) for mol in train_sdf['ROMol']],
108+
# atoms[0]: ['C', 'C', 'O', 'C', 'O', 'C', ...]
109+
'coordinates': [mol.GetConformers()[0].GetPositions() for mol in train_sdf['ROMol']],
110+
# coordinates[0]: array([[ 6.6462, -1.8268, 1.9275],
111+
# [ 6.1552, -1.9367, 0.4873],
112+
# [ 5.1832, -0.8757, 0.3007],
113+
# [ 5.4651, -0.0272, -0.7266],
114+
# [ 4.8586, -0.0844, -1.7917],
115+
# [ 6.5362, 0.9767, -0.3742],
116+
# ...,])
117+
'TARGET': train_sdf['TARGET'].tolist()
118+
# TARGET: [0, 1, 0, 0, 1, 0, ...]
119+
}
120+
# clf.fit(data = train_sdf)
121+
# clf.fit(data = train_dict)
122+
94123

95124
clf = MolPredict(load_model='../exp')
96125
res = clf.predict(data = test_data)
@@ -99,8 +128,14 @@ res = clf.predict(data = test_data)
99128
```python
100129
import numpy as np
101130
from unimol_tools import UniMolRepr
102-
# single smiles unimol representation
103-
clf = UniMolRepr(data_type='molecule', remove_hs=False)
131+
# single SMILES UniMol representation. If no paths are provided the
132+
# pretrained model and dictionary are fetched from Hugging Face.
133+
clf = UniMolRepr(
134+
data_type='molecule',
135+
remove_hs=False,
136+
# pretrained_model_path='/path/to/checkpoint.ckpt',
137+
# pretrained_dict_path='/path/to/dict.txt',
138+
)
104139
smiles = 'c1ccc(cc1)C2=NCC(=O)Nc3c2cc(cc3)[N+](=O)[O]'
105140
smiles_list = [smiles]
106141
unimol_repr = clf.get_repr(smiles_list, return_atomic_reprs=True)
@@ -110,6 +145,101 @@ print(np.array(unimol_repr['cls_repr']).shape)
110145
print(np.array(unimol_repr['atomic_reprs']).shape)
111146
```
112147

148+
### Command-line utilities
149+
150+
Hydra-powered entry points make training, prediction, and representation
151+
available from the command line. Key-value pairs override options from the
152+
YAML files in `unimol_tools/config`.
153+
154+
#### Training
155+
```bash
156+
python -m unimol_tools.cli.run_train \
157+
train_path=train.csv \
158+
task=regression \
159+
save_path=./exp \
160+
smiles_col=smiles \
161+
target_cols=[target1] \
162+
epochs=10 \
163+
learning_rate=1e-4 \
164+
batch_size=16 \
165+
kfold=5
166+
```
167+
168+
#### Prediction
169+
```bash
170+
python -m unimol_tools.cli.run_predict load_model=./exp data_path=test.csv
171+
```
172+
173+
#### Representation
174+
```bash
175+
python -m unimol_tools.cli.run_repr data_path=test.csv smiles_col=smiles
176+
```
177+
178+
### Molecule pretraining
179+
180+
`unimol_tools` provides a command-line utility for pretraining Uni-Mol models on
181+
your own dataset. The script uses
182+
[Hydra](https://hydra.cc/) so configuration values can be overridden at the
183+
command line. Two common invocation examples are shown below: one for LMDB data
184+
and one for a CSV of SMILES strings.
185+
186+
#### LMDB dataset
187+
188+
```bash
189+
export TORCH_NCCL_ASYNC_ERROR_HANDLING=1
190+
export HYDRA_FULL_ERROR=1
191+
export OMP_NUM_THREADS=1
192+
193+
torchrun --standalone --nproc_per_node=NUM_GPUS \
194+
-m unimol_tools.cli.run_pretrain \
195+
dataset.train_path=train.lmdb \
196+
dataset.valid_path=valid.lmdb \
197+
dataset.data_type=lmdb \
198+
dataset.dict_path=dict.txt \
199+
training.total_steps=1000000 \
200+
training.batch_size=16 \
201+
training.update_freq=1
202+
```
203+
204+
`dataset.dict_path` is optional. The effective batch size is
205+
`n_gpu * training.batch_size * training.update_freq`.
206+
207+
#### CSV dataset
208+
209+
```bash
210+
export TORCH_NCCL_ASYNC_ERROR_HANDLING=1
211+
export HYDRA_FULL_ERROR=1
212+
export OMP_NUM_THREADS=1
213+
214+
torchrun --standalone --nproc_per_node=NUM_GPUS \
215+
-m unimol_tools.cli.run_pretrain \
216+
dataset.train_path=train.csv \
217+
dataset.valid_path=valid.csv \
218+
dataset.data_type=csv \
219+
dataset.smiles_column=smiles \
220+
training.total_steps=1000000 \
221+
training.batch_size=16 \
222+
training.update_freq=1
223+
```
224+
225+
For multi-node training, specify additional arguments, for example:
226+
227+
```bash
228+
export TORCH_NCCL_ASYNC_ERROR_HANDLING=1
229+
export HYDRA_FULL_ERROR=1
230+
export OMP_NUM_THREADS=1
231+
232+
torchrun --nnodes=2 --nproc_per_node=8 --node_rank=0 \
233+
--master_addr=<master-ip> --master_port=<port> \
234+
-m unimol_tools.cli.run_pretrain ...
235+
```
236+
237+
All available options are defined in
238+
[`pretrain_config.py`](unimol_tools/pretrain/pretrain_config.py), and checkpoints
239+
along with the dictionary are saved to the run directory. When GPU memory is
240+
limited, increase `training.update_freq` to accumulate gradients while keeping
241+
the effective batch size `n_gpu * training.batch_size * training.update_freq`.
242+
113243
## Credits
114244
We thanks all contributors from the community for their suggestions, bug reports and chemistry advices. Currently unimol-tools is maintained by Yaning Cui, Xiaohong Ji, Zhifeng Gao from DP Technology and AI for Science Insitution, Beijing.
115245

docs/source/quickstart.md

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,36 @@ pred = clf.fit(data = train_data)
3939
clf = MolPredict(load_model='../exp')
4040
res = clf.predict(data = test_data)
4141
```
42+
43+
### Command-line utilities
44+
45+
Training, prediction, and representation can also be launched from the
46+
command line by overriding options in the YAML config files.
47+
48+
#### Training
49+
```bash
50+
python -m unimol_tools.cli.run_train \
51+
train_path=train.csv \
52+
task=regression \
53+
save_path=./exp \
54+
smiles_col=smiles \
55+
target_cols=[target1] \
56+
epochs=10 \
57+
learning_rate=1e-4 \
58+
batch_size=16 \
59+
kfold=5
60+
```
61+
62+
#### Prediction
63+
```bash
64+
python -m unimol_tools.cli.run_predict load_model=./exp data_path=test.csv
65+
```
66+
67+
#### Representation
68+
```bash
69+
python -m unimol_tools.cli.run_repr data_path=test.csv smiles_col=smiles
70+
```
71+
4272
## Uni-Mol molecule and atoms level representation
4373

4474
Uni-Mol representation can easily be achieved as follow.
@@ -60,6 +90,51 @@ print(np.array(unimol_repr['cls_repr']).shape)
6090
# atomic level repr, align with rdkit mol.GetAtoms()
6191
print(np.array(unimol_repr['atomic_reprs']).shape)
6292
```
93+
## Molecule pretraining
94+
95+
Uni-Mol can be pretrained from scratch using the ``run_pretrain`` utility. The
96+
script is driven by Hydra, so configuration options are supplied on the command
97+
line. The examples below demonstrate common setups for LMDB and CSV inputs.
98+
99+
### LMDB dataset
100+
101+
```bash
102+
torchrun --standalone --nproc_per_node=NUM_GPUS \
103+
-m unimol_tools.cli.run_pretrain \
104+
dataset.train_path=train.lmdb \
105+
dataset.valid_path=valid.lmdb \
106+
dataset.data_type=lmdb \
107+
dataset.dict_path=dict.txt \
108+
training.total_steps=10000 \
109+
training.batch_size=16 \
110+
training.update_freq=1
111+
```
112+
113+
`dataset.dict_path` is optional. The effective batch size is
114+
`n_gpu * training.batch_size * training.update_freq`.
115+
116+
### CSV dataset
117+
118+
```bash
119+
torchrun --standalone --nproc_per_node=NUM_GPUS \
120+
-m unimol_tools.cli.run_pretrain \
121+
dataset.train_path=train.csv \
122+
dataset.valid_path=valid.csv \
123+
dataset.data_type=csv \
124+
dataset.smiles_column=smiles \
125+
training.total_steps=10000 \
126+
training.batch_size=16 \
127+
training.update_freq=1
128+
```
129+
130+
To scale across multiple machines, include the appropriate `torchrun`
131+
arguments, e.g. `--nnodes`, `--node_rank`, `--master_addr` and
132+
`--master_port`.
133+
134+
Checkpoints and the dictionary are written to the output directory. When GPU
135+
memory is limited, increase `training.update_freq` to accumulate gradients while
136+
keeping the effective batch size `n_gpu * training.batch_size * training.update_freq`.
137+
63138
## Continue training (Re-train)
64139

65140
```python

requirements.txt

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
numpy<2.3.0
12
numpy>=2.0.0
23
pandas>=2.2.2
34
scikit-learn>=1.5.0
@@ -6,4 +7,8 @@ joblib
67
rdkit>=2024.3.4
78
pyyaml
89
addict
9-
tqdm
10+
tqdm
11+
hydra-core
12+
omegaconf
13+
tensorboard
14+
lmdb

setup.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
setup(
77
name="unimol_tools",
8-
version="0.1.4.post1",
8+
version="0.1.5",
99
description=(
1010
"unimol_tools is a Python package for property prediction with Uni-Mol in molecule, materials and protein."
1111
),
@@ -33,6 +33,10 @@
3333
"scikit-learn>=1.5.0",
3434
"numba",
3535
"tqdm",
36+
"hydra-core",
37+
"omegaconf",
38+
"tensorboard",
39+
"lmdb",
3640
],
3741
python_requires=">=3.9",
3842
include_package_data=True,

tests/test_multilabel_classification.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from unimol_tools import MolTrain, MolPredict
1010

1111
CSV_URL = 'https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/tox21.csv.gz'
12-
SDF_URL = 'https://tripod.nih.gov/tox21/challenge/download?id=tox21_10k_data_allsdf&sec='
12+
SDF_URL = 'https://tripod.nih.gov/tox21/challenge/download?id=tox21_10k_data_allsdf'
1313

1414

1515
@pytest.mark.network

tests/test_representation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
VQM24_URL = 'https://zenodo.org/records/15442257/files/DMC.npz?download=1'
1212
TOX21_CSV_URL = 'https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/tox21.csv.gz'
13-
TOX21_SDF_URL = 'https://tripod.nih.gov/tox21/challenge/download?id=tox21_10k_data_allsdf&sec='
13+
TOX21_SDF_URL = 'https://tripod.nih.gov/tox21/challenge/download?id=tox21_10k_data_allsdf'
1414

1515

1616
@pytest.mark.network

unimol_tools/cli/__init__.py

Whitespace-only changes.

unimol_tools/cli/run_predict.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
import hydra
2+
from omegaconf import DictConfig
3+
4+
from ..predict import MolPredict
5+
6+
7+
@hydra.main(version_base=None, config_path="../config", config_name="predict_config")
8+
def main(cfg: DictConfig):
9+
data_path = cfg.get("data_path")
10+
predictor = MolPredict(cfg=cfg)
11+
predictor.predict(data=data_path, save_path=cfg.get("save_path"))
12+
13+
14+
if __name__ == "__main__":
15+
main()

0 commit comments

Comments
 (0)