Commit 382e111

Merge pull request #11 from darkkoo/main
Add tests and fix bugs
2 parents d1dc74a + 8922635 commit 382e111

33 files changed

Lines changed: 1015 additions & 89 deletions

README.md

Lines changed: 2 additions & 2 deletions
@@ -53,13 +53,13 @@ python setup.py install

 The UniMol pretrained models can be found at [dptech/Uni-Mol-Models](https://huggingface.co/dptech/Uni-Mol-Models/tree/main).

-If the download is slow, you can use other mirrors, such as:
+If the download is slow, you can use a mirror, such as:

 ```bash
 export HF_ENDPOINT=https://hf-mirror.com
 ```

-Setting the `HF_ENDPOINT` environment variable specifies the mirror address for the Hugging Face Hub to use when downloading models.
+By default `unimol_tools` first tries the official Hugging Face endpoint. If that fails and `HF_ENDPOINT` is not set, it automatically retries using `https://hf-mirror.com`. Set `HF_ENDPOINT` yourself if you want to explicitly choose a mirror or the official site.

 ### Modify the default directory for weights

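The new README text describes an order in which download endpoints are tried. A minimal sketch of that order follows. `candidate_endpoints` is a hypothetical helper, not part of `unimol_tools`, and the official endpoint URL is an assumption. It also assumes that an explicitly set `HF_ENDPOINT` is the only endpoint tried, which the README wording suggests but does not state outright.

```python
import os

OFFICIAL_ENDPOINT = "https://huggingface.co"  # assumed URL of the official endpoint
MIRROR_ENDPOINT = "https://hf-mirror.com"     # mirror named in the README


def candidate_endpoints(environ=None):
    """Return the endpoints to try, in order (hypothetical helper)."""
    environ = os.environ if environ is None else environ
    explicit = environ.get("HF_ENDPOINT")
    if explicit:
        # The user chose an endpoint explicitly: try only that one.
        return [explicit]
    # Nothing set: try the official site first, then fall back to the mirror.
    return [OFFICIAL_ENDPOINT, MIRROR_ENDPOINT]


print(candidate_endpoints({}))
print(candidate_endpoints({"HF_ENDPOINT": "https://hf-mirror.com"}))
```

Passing the environment as a plain dictionary keeps the sketch easy to exercise without touching the real process environment.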
docs/source/installation.md

Lines changed: 1 addition & 1 deletion
@@ -42,7 +42,7 @@ If the download is slow, you can use other mirrors, such as:
 export HF_ENDPOINT=https://hf-mirror.com
 ```

-Setting the `HF_ENDPOINT` environment variable specifies the mirror address for the Hugging Face Hub to use when downloading models.
+By default `unimol_tools` first tries the official Hugging Face endpoint. If that fails and `HF_ENDPOINT` is not set, it automatically retries using `https://hf-mirror.com`. Set `HF_ENDPOINT` to use a specific endpoint.

 ## Bohrium notebook

docs/source/quickstart.md

Lines changed: 6 additions & 6 deletions
@@ -187,11 +187,11 @@ export MASTER_PORT='19198'
 Currently unimol_tools supports five types of fine-tuning tasks: `classification`, `regression`, `multiclass`, `multilabel_classification`, `multilabel_regression`.

 The datasets used in the examples are all open source and available, including
-- Ames mutagenicity. The dataset includes 6512 compounds and corresponding binary labels from Ames Mutagenicity results.
-- ESOL (delaney) is a standard regression dataset containing structures and water solubility data for 1128 compounds.
-- Tox21 Data Challenge 2014 is designed to help scientists understand the potential of the chemicals and compounds being tested through the Toxicology in the 21st Century initiative to disrupt biological pathways in ways that may result in toxic effects, which includes 12 date sets. The official web site is https://tripod.nih.gov/tox21/challenge/
-- Solvation free energy (FreeSolv). SMILES are provided.
-- Vector-QM24 (VQM24) dataset. Quantum chemistry dataset of ~836 thousand small organic and inorganic molecules.
+- Ames mutagenicity. The dataset includes 6512 compounds and corresponding binary labels from Ames Mutagenicity results. The dataset is available at https://weilab.math.msu.edu/DataLibrary/2D/.
+- ESOL (delaney) is a standard regression dataset containing structures and water solubility data for 1128 compounds. The dataset is available at https://weilab.math.msu.edu/DataLibrary/2D/ and https://huggingface.co/datasets/HR-machine/ESol.
+- Tox21 Data Challenge 2014 is designed to help scientists understand the potential of the chemicals and compounds being tested through the Toxicology in the 21st Century initiative to disrupt biological pathways in ways that may result in toxic effects, which includes 12 data sets. The official web site is https://tripod.nih.gov/tox21/challenge/. The datasets are available at https://moleculenet.org/datasets-1 and https://www.kaggle.com/datasets/maksiamiogan/tox21-dataset.
+- Solvation free energy (FreeSolv). SMILES are provided. The dataset is available at https://weilab.math.msu.edu/DataLibrary/2D/.
+- Vector-QM24 (VQM24) dataset. Quantum chemistry dataset of ~836 thousand small organic and inorganic molecules. The dataset is available at https://zenodo.org/records/15442257.

 ### Example of classification
 You can use a dictionary as input. The default smiles column name is **'SMILES'** and the target column name is **'target'**. You can also customize it with `smiles_col` and `target_cols`.
@@ -411,7 +411,7 @@ predictor = MolPredict(load_model='./exp')
 pred = predictor.predict(test_df_dict['smiles'])
 ```

-It also supports directly using the sdf file path as input.
+It also supports directly using the sdf file path as input. The following example reads the file in advance so that missing values can be preprocessed.

 ```python
 from unimol_tools import MolTrain, MolPredict

docs/source/weight.rst

Lines changed: 1 addition & 1 deletion
@@ -22,7 +22,7 @@ If the download is slow, you can use other mirrors, such as:

 export HF_ENDPOINT=https://hf-mirror.com

-Setting the ``HF_ENDPOINT`` environment variable specifies the mirror address for the Hugging Face Hub to use when downloading models.
+By default ``unimol_tools`` first tries the official Hugging Face endpoint. If that fails and ``HF_ENDPOINT`` is not set, it automatically retries with ``https://hf-mirror.com``. Set the variable yourself to choose a specific endpoint.

 `unimol_tools.weights.weight_hub.py <https://github.com/deepmodeling/unimol_tools/blob/main/unimol_tools/weights/weighthub.py>`_ control the logger.

requirements.txt

Lines changed: 5 additions & 5 deletions
@@ -1,9 +1,9 @@
-numpy==1.22.4
-pandas==1.4.0
-scikit-learn==1.5.0
-torch
+numpy>=2.0.0
+pandas>=2.2.2
+scikit-learn>=1.5.0
+torch>=2.4.0
 joblib
-rdkit
+rdkit>=2024.3.4
 pyyaml
 addict
 tqdm

setup.py

Lines changed: 6 additions & 6 deletions
@@ -23,18 +23,18 @@
         ],
     ),
     install_requires=[
-        "numpy<2.0.0,>=1.22.4",
-        "pandas<2.0.0",
-        "torch",
+        "numpy<2.3.0,>=2.0.0",
+        "pandas>=2.2.2",
+        "torch>=2.4.0",
         "joblib",
-        "rdkit",
+        "rdkit>=2024.3.4",
         "pyyaml",
         "addict",
-        "scikit-learn",
+        "scikit-learn>=1.5.0",
         "numba",
         "tqdm",
     ],
-    python_requires=">=3.6",
+    python_requires=">=3.9",
     include_package_data=True,
     classifiers=[
         "Development Status :: 5 - Production/Stable",

tests/conftest.py

Lines changed: 27 additions & 0 deletions
@@ -0,0 +1,27 @@
+import os
+import pytest
+
+@pytest.fixture(scope="session", autouse=True)
+def set_unimol_weight_dir(tmp_path_factory):
+    """Ensure UNIMOL_WEIGHT_DIR is set to a temporary directory for tests."""
+    weight_dir = tmp_path_factory.mktemp("weights")
+    original = os.environ.get("UNIMOL_WEIGHT_DIR")
+    os.environ["UNIMOL_WEIGHT_DIR"] = str(weight_dir)
+    yield
+    if original is None:
+        os.environ.pop("UNIMOL_WEIGHT_DIR", None)
+    else:
+        os.environ["UNIMOL_WEIGHT_DIR"] = original
+
+
+def pytest_addoption(parser):
+    parser.addoption("--run-network", action="store_true", help="run tests that need network")
+
+
+def pytest_collection_modifyitems(config, items):
+    if config.getoption("--run-network"):
+        return
+    skip_marker = pytest.mark.skip(reason="need --run-network to run")
+    for item in items:
+        if "network" in item.keywords:
+            item.add_marker(skip_marker)

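The collection hook in `tests/conftest.py` skips every test carrying the `network` marker unless `--run-network` is passed. The decision it makes can be modeled without pytest. `partition_tests` below is a hypothetical stand-in for the hook written for illustration only, not pytest API.

```python
def partition_tests(tests, run_network):
    """Split tests into (run, skipped) the way the conftest hook would.

    `tests` maps a test name to the set of marker names attached to it.
    """
    run, skipped = [], []
    for name, markers in tests.items():
        if "network" in markers and not run_network:
            # Mirrors item.add_marker(skip_marker) in the hook.
            skipped.append(name)
        else:
            run.append(name)
    return run, skipped


tests = {
    "test_classification_train_predict": {"network"},
    "test_read_data_from_smiles_list": set(),
}
print(partition_tests(tests, run_network=False))
print(partition_tests(tests, run_network=True))
```

With the real hook, the two cases correspond to running `pytest` and `pytest --run-network`. Registering the `network` marker in the pytest configuration would avoid an unknown-marker warning. Whether this commit does so is not visible in the hunks shown here.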
tests/test_classification.py

Lines changed: 58 additions & 0 deletions
@@ -0,0 +1,58 @@
+import os
+import zipfile
+import pandas as pd
+import pytest
+from utils_net import download_for_test
+
+from unimol_tools import MolTrain, MolPredict
+
+DATA_URL = 'https://weilab.math.msu.edu/DataLibrary/2D/Downloads/Ames_smi.zip'
+
+
+@pytest.mark.network
+def test_classification_train_predict(tmp_path):
+    # ensure any pretrained weights are written to a temporary directory
+    os.environ.setdefault('UNIMOL_WEIGHT_DIR', str(tmp_path / 'weights'))
+    zip_path = tmp_path / 'Ames_smi.zip'
+    download_for_test(
+        DATA_URL,
+        zip_path,
+        timeout=(5, 60),
+        max_retries=5,
+        backoff_factor=0.5,
+        allow_resume=True,
+        skip_on_failure=True,
+    )
+    with zipfile.ZipFile(zip_path, 'r') as zf:
+        zf.extractall(tmp_path)
+    csv_path = tmp_path / 'Ames.csv'
+    if not csv_path.exists():
+        pytest.skip('Dataset missing after extraction')
+    df = pd.read_csv(csv_path)
+    df = df.drop(columns=['CAS_NO']).rename(columns={'Activity': 'target'})
+    # take 100 samples for testing
+    df = df.sample(n=100, random_state=42)
+    train_df = df.sample(frac=0.8, random_state=42)
+    test_df = df.drop(train_df.index)
+    train_data = train_df.to_dict(orient='list')
+    test_smiles = test_df['Canonical_Smiles'].tolist()
+
+    exp_dir = tmp_path / 'exp'
+    clf = MolTrain(
+        task='classification',
+        data_type='molecule',
+        epochs=1,
+        batch_size=2,
+        kfold=2,
+        metrics='auc',
+        smiles_col='Canonical_Smiles',
+        save_path=str(exp_dir),
+    )
+    try:
+        clf.fit(train_data)
+    except Exception as e:
+        pytest.skip(f"Training failed: {e}")
+
+    predictor = MolPredict(load_model=str(exp_dir))
+    preds = predictor.predict(test_smiles)
+    assert len(preds) == len(test_smiles)

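The test passes `max_retries=5` and `backoff_factor=0.5` to `download_for_test`, whose implementation in `utils_net` is not part of the hunks shown here. As a rough illustration of what such parameters usually control, here is a sketch of retrying with exponential backoff. `retry_with_backoff` and its delay schedule (`backoff_factor * 2**attempt`) are assumptions made for this example, not the behavior of the actual helper.

```python
import time


def retry_with_backoff(fetch, max_retries=5, backoff_factor=0.5, sleep=time.sleep):
    """Call `fetch()` until it succeeds, waiting longer after each failure.

    Hypothetical helper: the wait after failed attempt n (0-based) is
    backoff_factor * 2**n seconds.
    """
    last_error = None
    for attempt in range(max_retries + 1):
        try:
            return fetch()
        except OSError as exc:  # many network errors derive from OSError
            last_error = exc
            if attempt < max_retries:
                sleep(backoff_factor * (2 ** attempt))
    # Every attempt failed: surface the last error to the caller.
    raise last_error


# Usage: a fake download that fails twice before succeeding.
attempts = []

def flaky_download():
    attempts.append(1)
    if len(attempts) < 3:
        raise ConnectionError("temporary failure")
    return "done"

print(retry_with_backoff(flaky_download, sleep=lambda seconds: None))
```

Injecting `sleep` as a parameter lets the example run without real delays. A test helper could also call `pytest.skip` instead of re-raising, which is presumably what `skip_on_failure=True` requests.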
tests/test_conformer.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
import numpy as np
2+
from unimol_tools.data.conformer import (
3+
inner_coords,
4+
coords2unimol,
5+
inner_smi2coords,
6+
create_mol_from_atoms_and_coords,
7+
)
8+
from unimol_tools.data.dictionary import Dictionary
9+
10+
11+
def test_inner_coords_and_coords2unimol():
12+
atoms = ['C', 'H', 'O']
13+
coords = [[0, 0, 0], [0, 0, 1], [1, 0, 0]]
14+
no_h_atoms, no_h_coords = inner_coords(atoms, coords, remove_hs=True)
15+
assert 'H' not in no_h_atoms
16+
d = Dictionary()
17+
for a in ['C', 'O']:
18+
if a not in d:
19+
d.add_symbol(a)
20+
feat = coords2unimol(no_h_atoms, no_h_coords, d)
21+
assert feat['src_tokens'].dtype == int
22+
assert feat['src_coord'].shape[1] == 3
23+
24+
25+
def test_inner_smi2coords_returns_mol():
26+
mol = inner_smi2coords('CC', return_mol=True)
27+
from rdkit.Chem import Mol
28+
29+
assert isinstance(mol, Mol)
30+
31+
32+
def test_create_mol_from_atoms_and_coords():
33+
atoms = ['C', 'O']
34+
coords = [[0, 0, 0], [1, 0, 0]]
35+
mol = create_mol_from_atoms_and_coords(atoms, coords)
36+
from rdkit.Chem import Mol
37+
38+
assert isinstance(mol, Mol)
39+
assert mol.GetNumAtoms() == 2

tests/test_datareader.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
import pandas as pd
2+
import numpy as np
3+
import pytest
4+
from unimol_tools.data.datareader import MolDataReader
5+
6+
7+
def test_read_data_from_smiles_list():
8+
smiles = ["CCO", "C"]
9+
reader = MolDataReader()
10+
result = reader.read_data(smiles)
11+
assert result["smiles"] == smiles
12+
assert len(result["scaffolds"]) == len(smiles)
13+
assert result["raw_data"].shape[0] == len(smiles)
14+
15+
16+
def test_check_smiles_behavior():
17+
reader = MolDataReader()
18+
# invalid SMILES should return False during training when not strict
19+
assert reader.check_smiles("invalid", is_train=True, smi_strict=False) is False
20+
# invalid SMILES should raise in strict mode
21+
with pytest.raises(ValueError):
22+
reader.check_smiles("invalid", is_train=True, smi_strict=True)
23+
24+
25+
def test_convert_numeric_columns():
26+
from rdkit import Chem
27+
df = pd.DataFrame({
28+
"ROMol": [Chem.MolFromSmiles("CCO")],
29+
"num": ["1"],
30+
"alpha": ["a"],
31+
})
32+
reader = MolDataReader()
33+
out = reader._convert_numeric_columns(df.copy())
34+
assert pd.api.types.is_numeric_dtype(out["num"])
35+
assert not pd.api.types.is_numeric_dtype(out["alpha"])
36+
assert out["ROMol"].iloc[0] == df["ROMol"].iloc[0]
37+
38+
39+
def test_anomaly_clean_regression():
40+
df = pd.DataFrame({
41+
"SMILES": ["C"] * 11,
42+
"TARGET": [1] * 10 + [100],
43+
})
44+
reader = MolDataReader()
45+
cleaned = reader.anomaly_clean_regression(df, ["TARGET"])
46+
assert 100 not in cleaned["TARGET"].values
47+
assert len(cleaned) == 10

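`test_anomaly_clean_regression` expects the value 100 to be dropped from ten 1s plus one 100. That outcome is what a three-sigma filter would produce: the mean is 10, the sample standard deviation is about 29.85, and the upper bound of about 99.55 lies just below 100. The sketch below reproduces the arithmetic with the standard library. It assumes a three-sigma rule, which matches the test data but is not confirmed by the hunks shown here, and `three_sigma_filter` is a hypothetical name.

```python
from statistics import mean, stdev


def three_sigma_filter(values):
    """Keep values strictly within mean +/- 3 sample standard deviations."""
    mu = mean(values)
    sigma = stdev(values)  # sample standard deviation (n - 1 denominator)
    lower, upper = mu - 3 * sigma, mu + 3 * sigma
    return [v for v in values if lower < v < upper]


targets = [1] * 10 + [100]
kept = three_sigma_filter(targets)
print(len(kept), 100 in kept)  # prints "10 False"
```

The margin is narrow, so a noticeably smaller outlier in the same position would survive the filter, which may be why the test uses 100.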