Skip to content

Commit 4c32766

Browse files
whiteadclaude
andcommitted
Fix scaffold split: use rarest-first ordering, fix normalization, add visualizations
- Iterate from rarest to most common scaffolds so test set contains novel chemistry - Fix normalization bug where in-place loop corrupted train stats before normalizing test - Add cells showing train vs test scaffolds and example molecules - Add rdkit.Chem.Draw import for molecule grid visualization Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 540f9c2 commit 4c32766

File tree

1 file changed

+69
-14
lines changed

1 file changed

+69
-14
lines changed

ml/regression.ipynb

Lines changed: 69 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@
6060
"import jax.numpy as jnp\n",
6161
"from jax.example_libraries import optimizers\n",
6262
"import jax\n",
63-
"import rdkit, rdkit.Chem\n",
63+
"import rdkit, rdkit.Chem, rdkit.Chem.Draw\n",
6464
"from rdkit.Chem.Scaffolds import MurckoScaffold\n",
6565
"import dmol"
6666
]
@@ -1643,7 +1643,7 @@
16431643
"\n",
16441644
"For molecular problems, a random train/test split can still leak structural similarity between train and test molecules. A **scaffold split** groups molecules by their core scaffold (often the Bemis-Murcko scaffold{cite}`bemis1996properties`) and puts whole scaffold groups into either train or test. This usually gives a harder but more realistic estimate of performance on novel chemistry. Scaffold splits were popularized as a standard benchmark by the MoleculeNet suite{cite}`wu2018moleculenet`.\n",
16451645
"\n",
1646-
"The quick demo below builds Murcko scaffolds from SMILES, places the most common scaffolds into the test set until we reach about 20% of molecules, and compares that error with a random split of the same size.\n"
1646+
"The standard approach places the **rarest** scaffolds into the test set, since those represent the most novel chemistry. The demo below builds Murcko scaffolds from SMILES, accumulates the least-common scaffolds into the test set until we reach about 20% of molecules, and compares that error with a random split of the same size.\n"
16471647
]
16481648
},
16491649
{
@@ -1655,8 +1655,6 @@
16551655
"smiles_col = next(\n",
16561656
" (c for c in [\"SMILES\", \"smiles\", \"CanonicalSMILES\"] if c in soldata.columns), None\n",
16571657
")\n",
1658-
"if smiles_col is None:\n",
1659-
" raise ValueError(\"No SMILES column found for scaffold split demo.\")\n",
16601658
"\n",
16611659
"\n",
16621660
"def murcko_scaffold(smiles):\n",
@@ -1674,7 +1672,8 @@
16741672
"test_target = int(0.2 * len(scaffold_data))\n",
16751673
"test_scaffolds = set()\n",
16761674
"running = 0\n",
1677-
"for scaffold, count in scaffold_counts.items():\n",
1675+
"# iterate from rarest to most common\n",
1676+
"for scaffold, count in scaffold_counts.iloc[::-1].items():\n",
16781677
" if running >= test_target:\n",
16791678
" break\n",
16801679
" test_scaffolds.add(scaffold)\n",
@@ -1683,10 +1682,10 @@
16831682
"test = scaffold_data[scaffold_data[\"Scaffold\"].isin(test_scaffolds)].copy()\n",
16841683
"train = scaffold_data[~scaffold_data[\"Scaffold\"].isin(test_scaffolds)].copy()\n",
16851684
"\n",
1686-
"for frame in [train, test]:\n",
1687-
" frame[feature_names] = (frame[feature_names] - train[feature_names].mean()) / train[\n",
1688-
" feature_names\n",
1689-
" ].std()\n",
1685+
"train_mean = train[feature_names].mean()\n",
1686+
"train_std = train[feature_names].std()\n",
1687+
"train[feature_names] = (train[feature_names] - train_mean) / train_std\n",
1688+
"test[feature_names] = (test[feature_names] - train_mean) / train_std\n",
16901689
"\n",
16911690
"x, y = train[feature_names].values, train[\"Solubility\"].values\n",
16921691
"test_x, test_y = test[feature_names].values, test[\"Solubility\"].values\n",
@@ -1701,10 +1700,10 @@
17011700
"\n",
17021701
"random_train = scaffold_data.iloc[rand_train_idx].copy()\n",
17031702
"random_test = scaffold_data.iloc[rand_test_idx].copy()\n",
1704-
"for frame in [random_train, random_test]:\n",
1705-
" frame[feature_names] = (\n",
1706-
" frame[feature_names] - random_train[feature_names].mean()\n",
1707-
" ) / random_train[feature_names].std()\n",
1703+
"rand_mean = random_train[feature_names].mean()\n",
1704+
"rand_std = random_train[feature_names].std()\n",
1705+
"random_train[feature_names] = (random_train[feature_names] - rand_mean) / rand_std\n",
1706+
"random_test[feature_names] = (random_test[feature_names] - rand_mean) / rand_std\n",
17081707
"\n",
17091708
"x, y = random_train[feature_names].values, random_train[\"Solubility\"].values\n",
17101709
"test_x, test_y = random_test[feature_names].values, random_test[\"Solubility\"].values\n",
@@ -1714,7 +1713,63 @@
17141713
"\n",
17151714
"print(f\"Scaffold split test MSE: {scaffold_mse:.2f}\")\n",
17161715
"print(f\"Random split test MSE: {random_mse:.2f}\")\n",
1717-
"print(f\"Unique scaffolds used for test: {len(test_scaffolds)}\")"
1716+
"print(f\"Unique scaffolds in test: {len(test_scaffolds)}\")\n",
1717+
"print(f\"Test molecules: {len(test)}, Train molecules: {len(train)}\")"
1718+
]
1719+
},
1720+
{
1721+
"cell_type": "code",
1722+
"execution_count": null,
1723+
"metadata": {},
1724+
"outputs": [],
1725+
"source": [
1726+
"# show some common train scaffolds and rare test scaffolds\n",
1727+
"train_scaffolds = set(scaffold_counts.index) - test_scaffolds\n",
1728+
"train_scaffold_counts = scaffold_counts[scaffold_counts.index.isin(train_scaffolds)]\n",
1729+
"test_scaffold_counts = scaffold_counts[scaffold_counts.index.isin(test_scaffolds)]\n",
1730+
"\n",
1731+
"common_train = train_scaffold_counts.head(3)\n",
1732+
"rare_test = test_scaffold_counts.sort_values(ascending=False).head(3)\n",
1733+
"\n",
1734+
"scaffold_smiles = list(common_train.index) + list(rare_test.index)\n",
1735+
"scaffold_mols = [rdkit.Chem.MolFromSmiles(s) for s in scaffold_smiles]\n",
1736+
"legends = [f\"Train (n={c})\" for c in common_train.values] + [\n",
1737+
" f\"Test (n={c})\" for c in rare_test.values\n",
1738+
"]\n",
1739+
"\n",
1740+
"# filter out any scaffolds that failed to parse\n",
1741+
"valid = [(m, l) for m, l in zip(scaffold_mols, legends) if m is not None]\n",
1742+
"scaffold_mols, legends = zip(*valid) if valid else ([], [])\n",
1743+
"\n",
1744+
"print(\"Top train scaffolds (most common) vs test scaffolds:\")\n",
1745+
"rdkit.Chem.Draw.MolsToGridImage(\n",
1746+
" scaffold_mols, molsPerRow=3, subImgSize=(250, 250), legends=list(legends)\n",
1747+
")"
1748+
]
1749+
},
1750+
{
1751+
"cell_type": "code",
1752+
"execution_count": null,
1753+
"metadata": {},
1754+
"outputs": [],
1755+
"source": [
1756+
"# show example molecules from train and test splits\n",
1757+
"train_examples = train.sample(3, random_state=42)\n",
1758+
"test_examples = test.sample(3, random_state=42)\n",
1759+
"\n",
1760+
"example_smiles = list(train_examples[smiles_col]) + list(test_examples[smiles_col])\n",
1761+
"example_labels = [\"Train\"] * len(train_examples) + [\"Test\"] * len(test_examples)\n",
1762+
"example_pairs = [\n",
1763+
" (rdkit.Chem.MolFromSmiles(s), l)\n",
1764+
" for s, l in zip(example_smiles, example_labels)\n",
1765+
" if rdkit.Chem.MolFromSmiles(s) is not None\n",
1766+
"]\n",
1767+
"example_mols, example_legends = zip(*example_pairs) if example_pairs else ([], [])\n",
1768+
"\n",
1769+
"print(\"Example molecules from each split:\")\n",
1770+
"rdkit.Chem.Draw.MolsToGridImage(\n",
1771+
" example_mols, molsPerRow=3, subImgSize=(250, 250), legends=list(example_legends)\n",
1772+
")"
17181773
]
17191774
},
17201775
{

0 commit comments

Comments
 (0)