|
60 | 60 | "import jax.numpy as jnp\n", |
61 | 61 | "from jax.example_libraries import optimizers\n", |
62 | 62 | "import jax\n", |
63 | | - "import rdkit, rdkit.Chem\n", |
| 63 | + "import rdkit, rdkit.Chem, rdkit.Chem.Draw\n", |
64 | 64 | "from rdkit.Chem.Scaffolds import MurckoScaffold\n", |
65 | 65 | "import dmol" |
66 | 66 | ] |
|
1643 | 1643 | "\n", |
1644 | 1644 | "For molecular problems, a random train/test split can still leak structural similarity between train and test molecules. A **scaffold split** groups molecules by their core scaffold (often the Bemis-Murcko scaffold{cite}`bemis1996properties`) and puts whole scaffold groups into either train or test. This usually gives a harder but more realistic estimate of performance on novel chemistry. Scaffold splits were popularized as a standard benchmark by the MoleculeNet suite{cite}`wu2018moleculenet`.\n", |
1645 | 1645 | "\n", |
1646 | | - "The quick demo below builds Murcko scaffolds from SMILES, places the most common scaffolds into the test set until we reach about 20% of molecules, and compares that error with a random split of the same size.\n" |
| 1646 | + "The standard approach places the **rarest** scaffolds into the test set, since those represent the most novel chemistry. The demo below builds Murcko scaffolds from SMILES, accumulates the least-common scaffolds into the test set until we reach about 20% of molecules, and compares the resulting test error with that of a random split of the same size.\n" |
1647 | 1647 | ] |
1648 | 1648 | }, |
1649 | 1649 | { |
|
1655 | 1655 | "smiles_col = next(\n", |
1656 | 1656 | " (c for c in [\"SMILES\", \"smiles\", \"CanonicalSMILES\"] if c in soldata.columns), None\n", |
1657 | 1657 | ")\n", |
1658 | | - "if smiles_col is None:\n", |
1659 | | - " raise ValueError(\"No SMILES column found for scaffold split demo.\")\n", |
1660 | 1658 | "\n", |
1661 | 1659 | "\n", |
1662 | 1660 | "def murcko_scaffold(smiles):\n", |
|
1674 | 1672 | "test_target = int(0.2 * len(scaffold_data))\n", |
1675 | 1673 | "test_scaffolds = set()\n", |
1676 | 1674 | "running = 0\n", |
1677 | | - "for scaffold, count in scaffold_counts.items():\n", |
| 1675 | + "# iterate from rarest to most common\n", |
| 1676 | + "for scaffold, count in scaffold_counts.iloc[::-1].items():\n", |
1678 | 1677 | " if running >= test_target:\n", |
1679 | 1678 | " break\n", |
1680 | 1679 | " test_scaffolds.add(scaffold)\n", |
|
1683 | 1682 | "test = scaffold_data[scaffold_data[\"Scaffold\"].isin(test_scaffolds)].copy()\n", |
1684 | 1683 | "train = scaffold_data[~scaffold_data[\"Scaffold\"].isin(test_scaffolds)].copy()\n", |
1685 | 1684 | "\n", |
1686 | | - "for frame in [train, test]:\n", |
1687 | | - " frame[feature_names] = (frame[feature_names] - train[feature_names].mean()) / train[\n", |
1688 | | - " feature_names\n", |
1689 | | - " ].std()\n", |
| 1685 | + "train_mean = train[feature_names].mean()\n", |
| 1686 | + "train_std = train[feature_names].std()\n", |
| 1687 | + "train[feature_names] = (train[feature_names] - train_mean) / train_std\n", |
| 1688 | + "test[feature_names] = (test[feature_names] - train_mean) / train_std\n", |
1690 | 1689 | "\n", |
1691 | 1690 | "x, y = train[feature_names].values, train[\"Solubility\"].values\n", |
1692 | 1691 | "test_x, test_y = test[feature_names].values, test[\"Solubility\"].values\n", |
|
1701 | 1700 | "\n", |
1702 | 1701 | "random_train = scaffold_data.iloc[rand_train_idx].copy()\n", |
1703 | 1702 | "random_test = scaffold_data.iloc[rand_test_idx].copy()\n", |
1704 | | - "for frame in [random_train, random_test]:\n", |
1705 | | - " frame[feature_names] = (\n", |
1706 | | - " frame[feature_names] - random_train[feature_names].mean()\n", |
1707 | | - " ) / random_train[feature_names].std()\n", |
| 1703 | + "rand_mean = random_train[feature_names].mean()\n", |
| 1704 | + "rand_std = random_train[feature_names].std()\n", |
| 1705 | + "random_train[feature_names] = (random_train[feature_names] - rand_mean) / rand_std\n", |
| 1706 | + "random_test[feature_names] = (random_test[feature_names] - rand_mean) / rand_std\n", |
1708 | 1707 | "\n", |
1709 | 1708 | "x, y = random_train[feature_names].values, random_train[\"Solubility\"].values\n", |
1710 | 1709 | "test_x, test_y = random_test[feature_names].values, random_test[\"Solubility\"].values\n", |
|
1714 | 1713 | "\n", |
1715 | 1714 | "print(f\"Scaffold split test MSE: {scaffold_mse:.2f}\")\n", |
1716 | 1715 | "print(f\"Random split test MSE: {random_mse:.2f}\")\n", |
1717 | | - "print(f\"Unique scaffolds used for test: {len(test_scaffolds)}\")" |
| 1716 | + "print(f\"Unique scaffolds in test: {len(test_scaffolds)}\")\n", |
| 1717 | + "print(f\"Test molecules: {len(test)}, Train molecules: {len(train)}\")" |
| 1718 | + ] |
| 1719 | + }, |
| 1720 | + { |
| 1721 | + "cell_type": "code", |
| 1722 | + "execution_count": null, |
| 1723 | + "metadata": {}, |
| 1724 | + "outputs": [], |
| 1725 | + "source": [ |
| 1726 | + "# show some common train scaffolds and rare test scaffolds\n", |
| 1727 | + "train_scaffolds = set(scaffold_counts.index) - test_scaffolds\n", |
| 1728 | + "train_scaffold_counts = scaffold_counts[scaffold_counts.index.isin(train_scaffolds)]\n", |
| 1729 | + "test_scaffold_counts = scaffold_counts[scaffold_counts.index.isin(test_scaffolds)]\n", |
| 1730 | + "\n", |
| 1731 | + "common_train = train_scaffold_counts.head(3)\n", |
| 1732 | + "rare_test = test_scaffold_counts.sort_values(ascending=False).head(3)\n", |
| 1733 | + "\n", |
| 1734 | + "scaffold_smiles = list(common_train.index) + list(rare_test.index)\n", |
| 1735 | + "scaffold_mols = [rdkit.Chem.MolFromSmiles(s) for s in scaffold_smiles]\n", |
| 1736 | + "legends = [f\"Train (n={c})\" for c in common_train.values] + [\n", |
| 1737 | + " f\"Test (n={c})\" for c in rare_test.values\n", |
| 1738 | + "]\n", |
| 1739 | + "\n", |
| 1740 | + "# filter out any scaffolds that failed to parse\n", |
| 1741 | + "valid = [(m, l) for m, l in zip(scaffold_mols, legends) if m is not None]\n", |
| 1742 | + "scaffold_mols, legends = zip(*valid) if valid else ([], [])\n", |
| 1743 | + "\n", |
| 1744 | + "print(\"Top train scaffolds (most common) vs test scaffolds:\")\n", |
| 1745 | + "rdkit.Chem.Draw.MolsToGridImage(\n", |
| 1746 | + " scaffold_mols, molsPerRow=3, subImgSize=(250, 250), legends=list(legends)\n", |
| 1747 | + ")" |
| 1748 | + ] |
| 1749 | + }, |
| 1750 | + { |
| 1751 | + "cell_type": "code", |
| 1752 | + "execution_count": null, |
| 1753 | + "metadata": {}, |
| 1754 | + "outputs": [], |
| 1755 | + "source": [ |
| 1756 | + "# show example molecules from train and test splits\n", |
| 1757 | + "train_examples = train.sample(3, random_state=42)\n", |
| 1758 | + "test_examples = test.sample(3, random_state=42)\n", |
| 1759 | + "\n", |
| 1760 | + "example_smiles = list(train_examples[smiles_col]) + list(test_examples[smiles_col])\n", |
| 1761 | + "example_labels = [\"Train\"] * len(train_examples) + [\"Test\"] * len(test_examples)\n", |
| 1762 | + "example_pairs = [\n", |
| 1763 | + " (rdkit.Chem.MolFromSmiles(s), l)\n", |
| 1764 | + " for s, l in zip(example_smiles, example_labels)\n", |
| 1765 | + " if rdkit.Chem.MolFromSmiles(s) is not None\n", |
| 1766 | + "]\n", |
| 1767 | + "example_mols, example_legends = zip(*example_pairs) if example_pairs else ([], [])\n", |
| 1768 | + "\n", |
| 1769 | + "print(\"Example molecules from each split:\")\n", |
| 1770 | + "rdkit.Chem.Draw.MolsToGridImage(\n", |
| 1771 | + " example_mols, molsPerRow=3, subImgSize=(250, 250), legends=list(example_legends)\n", |
| 1772 | + ")" |
1718 | 1773 | ] |
1719 | 1774 | }, |
1720 | 1775 | { |
|
0 commit comments