Skip to content

Commit af111e1

Browse files
authored
Merge pull request #88 from aqlaboratory/feature/update-valset-rebase
Feature/update valset rebase
2 parents b9d336b + 42f362d commit af111e1

File tree

8 files changed

+497
-365
lines changed

8 files changed

+497
-365
lines changed

openfold3/core/data/pipelines/preprocessing/caches/pdb_val.py

Lines changed: 232 additions & 181 deletions
Large diffs are not rendered by default.

openfold3/core/data/primitives/caches/clustering.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,9 @@
2626
get_all_cache_chains,
2727
logger,
2828
)
29-
from openfold3.core.data.primitives.caches.format import ClusteredDatasetCache
29+
from openfold3.core.data.primitives.caches.format import (
30+
ClusteredDatasetCache,
31+
)
3032
from openfold3.core.data.resources.residues import MoleculeType
3133

3234

@@ -236,6 +238,10 @@ def add_cluster_data(
236238
# Increment cluster size
237239
cluster_id_to_size[interface_cluster_id] += 1
238240

241+
# TODO: debugging-only, remove
242+
if "UNKNOWN" in cluster_id_to_size:
243+
logger.warning(f"Cluster ID 'UNKNOWN' has size {cluster_id_to_size['UNKNOWN']}")
244+
239245
# Add cluster sizes
240246
if add_sizes:
241247
for metadata in structure_cache.values():

openfold3/core/data/primitives/caches/filtering.py

Lines changed: 147 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
"""All operations for processing and manipulating metadata and training caches."""
1616

1717
import functools
18+
import itertools
1819
import logging
1920
import random
2021
from collections import defaultdict
@@ -1157,9 +1158,6 @@ def subsample_chains_by_type(
11571158
Follows AF3 SI 5.8 Monomer Selection Step 4). The function subsamples specific
11581159
chains and deletes all other chains from the cache.
11591160
1160-
Note that proteins are sampled as unique cluster representatives, which is not
1161-
directly stated in the SI but seems logical given that chains are preclustered.
1162-
11631161
Args:
11641162
dataset_cache (ClusteredDatasetCache):
11651163
The cache to subsample.
@@ -1180,48 +1178,45 @@ def subsample_chains_by_type(
11801178
if random_seed is not None:
11811179
random.seed(random_seed)
11821180

1183-
# Store the chain data points grouped by cluster
1184-
chain_type_to_clusters = {
1185-
MoleculeType.PROTEIN: defaultdict(list),
1186-
MoleculeType.DNA: defaultdict(list),
1187-
MoleculeType.RNA: defaultdict(list),
1181+
# Store the chain data points grouped by molecule type
1182+
chain_type_to_datapoints = {
1183+
MoleculeType.PROTEIN: [],
1184+
MoleculeType.DNA: [],
1185+
MoleculeType.RNA: [],
11881186
}
11891187
chain_type_to_n_samples = {
11901188
MoleculeType.PROTEIN: n_protein,
11911189
MoleculeType.DNA: n_dna,
11921190
MoleculeType.RNA: n_rna,
11931191
}
11941192

1195-
# Collect all chain data points of the specified types grouped by cluster
1193+
# Sort the chain datapoints into their respective molecule type lists
11961194
for pdb_id, structure_data in dataset_cache.structure_data.items():
11971195
for chain_id, chain_data in structure_data.chains.items():
11981196
chain_type = chain_data.molecule_type
11991197

1200-
if chain_type not in chain_type_to_clusters:
1198+
if chain_type not in chain_type_to_datapoints:
12011199
continue
12021200

1203-
chain_type_to_clusters[chain_type][chain_data.cluster_id].append(
1201+
chain_type_to_datapoints[chain_type].append(
12041202
ChainDataPoint(pdb_id, chain_id)
12051203
)
12061204

1205+
# Subsample the chains, taking N per type except if the count is set to None in
1206+
# which case all samples are taken
12071207
keep_chain_datapoints = set()
12081208

1209-
# Subsample the chains, taking one per cluster except if the count is set to None in
1210-
# which case all samples are taken
1211-
for chain_type, clusters in chain_type_to_clusters.items():
1209+
for chain_type, datapoints in chain_type_to_datapoints.items():
12121210
n_samples = chain_type_to_n_samples[chain_type]
12131211

12141212
# Take every single datapoint if n_samples is None
12151213
if n_samples is None:
1216-
for chain_datapoints in clusters.values():
1217-
keep_chain_datapoints.update(chain_datapoints)
1214+
keep_chain_datapoints.update(datapoints)
12181215

1219-
# Otherwise, take 1 sample from n_samples clusters
1216+
# Otherwise, take N random samples
12201217
else:
1221-
sampled_clusters = random.sample(list(clusters.keys()), n_samples)
1222-
1223-
for cluster_id in sampled_clusters:
1224-
keep_chain_datapoints.add(random.choice(clusters[cluster_id]))
1218+
sampled_datapoints = random.sample(datapoints, n_samples)
1219+
keep_chain_datapoints.update(sampled_datapoints)
12251220

12261221
# Remove everything outside of the selected chains
12271222
filter_cache_to_specified_chains(dataset_cache, keep_chain_datapoints)
@@ -1676,70 +1671,150 @@ def filter_chains_by_metric_eligibility(
16761671
return structure_data
16771672

16781673

1674+
def select_one_per_cluster(
1675+
datapoints: list[ChainDataPoint | InterfaceDataPoint],
1676+
cache: ValidationDatasetCache,
1677+
random_seed: int | None = None,
1678+
) -> list[ChainDataPoint] | list[InterfaceDataPoint]:
1679+
"""Selects one random datapoint per cluster from the provided list.
1680+
1681+
Args:
1682+
datapoints: List of chain or interface datapoints.
1683+
cache: The cache to look up cluster IDs from.
1684+
random_seed: Random seed for reproducibility.
1685+
1686+
Returns:
1687+
Filtered list with one datapoint per cluster.
1688+
"""
1689+
if random_seed is not None:
1690+
random.seed(random_seed)
1691+
1692+
# Group by cluster_id
1693+
cluster_to_datapoints = defaultdict(list)
1694+
1695+
for dp in datapoints:
1696+
if isinstance(dp, ChainDataPoint):
1697+
cluster_id = cache.structure_data[dp.pdb_id].chains[dp.chain_id].cluster_id
1698+
else: # InterfaceDataPoint
1699+
cluster_id = (
1700+
cache.structure_data[dp.pdb_id].interfaces[dp.interface_id].cluster_id
1701+
)
1702+
cluster_to_datapoints[cluster_id].append(dp)
1703+
1704+
# Pick one per cluster
1705+
return [random.choice(dps) for dps in cluster_to_datapoints.values()]
1706+
1707+
16791708
def select_final_validation_data(
1680-
unfiltered_cache: ValidationDatasetCache,
1681-
monomer_structure_data: dict[str, ValidationDatasetStructureData],
1682-
multimer_structure_data: dict[str, ValidationDatasetStructureData],
1709+
val_dataset_cache: ValidationDatasetCache,
1710+
selected_chains: list[ChainDataPoint],
1711+
selected_interfaces: list[InterfaceDataPoint],
1712+
random_seed: int | None = None,
16831713
) -> None:
1684-
"""Selects the final targets and marks chains/interfaces to score on.
1714+
"""Subsets cache and marks cluster representatives with priority.
16851715
1686-
This will create the final validation dataset cache by subsetting the unfiltered
1687-
cache only to the relevant PDB-IDs, and then turning on the use_metrics flag only
1688-
for select chains and interfaces coming out of the multimer and monomer sets. Note
1689-
that we are not scoring validation metrics on all low-homology chains and interfaces
1690-
of each target in the final validation set, but only those that are part of the
1691-
selected monomer and multimer sets.
1716+
1. Subsets cache to PDB IDs from selected chains/interfaces
1717+
2. Marks selected chains/interfaces as cluster representatives
1718+
3. For remaining clusters, picks additional representatives from
1719+
metric-eligible chains/interfaces in the subsetted PDBs
1720+
4. Enables use_metrics for all representatives
16921721
16931722
Args:
1694-
unfiltered_cache: ValClusteredDatasetCache
1695-
Preliminary validation dataset cache corresponding to the full proto
1696-
validation set, after the initial time- and token-based filtering.
1697-
monomer_structure_data: dict[str, ValClusteredDatasetStructureData]
1698-
The monomer set of SI 5.8, containing the subsampled low-homology polymer
1699-
chains and metric-eligible low-homology interfaces.
1700-
multimer_structure_data: dict[str, ValClusteredDatasetStructureData]
1701-
The multimer set of SI 5.8, containing subsampled low-homology interfaces
1702-
and their constituent chains.
1703-
1723+
val_dataset_cache: The full validation dataset cache.
1724+
selected_chains: Chain datapoints from monomer selection.
1725+
selected_interfaces: Interface datapoints from multimer selection.
1726+
random_seed: Random seed for reproducibility when picking additional
1727+
representatives.
17041728
17051729
Returns:
1706-
None, the filtered_structure_data is updated in-place.
1730+
None, modifies val_dataset_cache in place.
17071731
"""
1708-
# First subset the unfiltered cache to only the relevant PDB-IDs
1709-
relevant_pdb_ids = set(monomer_structure_data.keys()) | set(
1710-
multimer_structure_data.keys()
1711-
)
1712-
structure_data = unfiltered_cache.structure_data
1713-
structure_data = {pdb_id: structure_data[pdb_id] for pdb_id in relevant_pdb_ids}
1714-
1715-
for pdb_id, structure_data_entry in structure_data.items():
1716-
# Go through the monomer and multimer sets sequentially
1717-
for set_name, set_structure_data in zip(
1718-
("monomer", "multimer"),
1719-
(
1720-
monomer_structure_data,
1721-
multimer_structure_data,
1722-
),
1723-
strict=True,
1724-
):
1725-
if pdb_id not in set_structure_data:
1726-
continue
1732+
if random_seed is not None:
1733+
random.seed(random_seed)
1734+
1735+
# Step 1: Collect PDB IDs and subset cache
1736+
monomer_pdb_ids = {dp.pdb_id for dp in selected_chains}
1737+
multimer_pdb_ids = {dp.pdb_id for dp in selected_interfaces}
1738+
selected_pdb_ids = monomer_pdb_ids | multimer_pdb_ids
1739+
1740+
val_dataset_cache.structure_data = {
1741+
pdb_id: data
1742+
for pdb_id, data in val_dataset_cache.structure_data.items()
1743+
if pdb_id in selected_pdb_ids
1744+
}
1745+
1746+
# Step 2: Track which subset each PDB came from for logging purposes
1747+
pdb_id_to_source = {
1748+
pdb_id: "monomer" if pdb_id in monomer_pdb_ids else "multimer"
1749+
for pdb_id in selected_pdb_ids
1750+
}
1751+
1752+
# Step 3: Turn on metrics for selected chains/interfaces and track which clusters
1753+
# are already represented
1754+
represented_clusters: set[str] = set()
1755+
1756+
# Iterate through all chains and interfaces
1757+
for dp in itertools.chain(selected_chains, selected_interfaces):
1758+
pdb_id = dp.pdb_id
1759+
structure_data = val_dataset_cache.structure_data[pdb_id]
1760+
1761+
if isinstance(dp, ChainDataPoint):
1762+
dp_data = structure_data.chains[dp.chain_id]
1763+
else:
1764+
dp_data = structure_data.interfaces[dp.interface_id]
1765+
1766+
# Turn on metrics and mark the cluster as represented
1767+
dp_data.use_metrics = True
1768+
represented_clusters.add(dp_data.cluster_id)
1769+
1770+
# Store which subset the PDB came from
1771+
dp_data.source_subset = pdb_id_to_source[pdb_id]
17271772

1728-
# Activate metrics for all chains in the monomer/multimer sets
1729-
for chain_id in set_structure_data[pdb_id].chains:
1730-
structure_data_entry.chains[chain_id].use_metrics = True
1773+
# Step 4: Collect chains/interfaces in the PDBs that are metric-eligible and not yet
1774+
# "represented"
1775+
non_represented_dps: list[ChainDataPoint | InterfaceDataPoint] = []
1776+
1777+
for pdb_id, structure_data in val_dataset_cache.structure_data.items():
1778+
for dp_id, metadata in itertools.chain(
1779+
structure_data.chains.items(), structure_data.interfaces.items()
1780+
):
1781+
if (
1782+
metadata.metric_eligible
1783+
and metadata.cluster_id not in represented_clusters
1784+
):
1785+
if isinstance(metadata, ValidationDatasetChainData):
1786+
non_represented_dps.append(
1787+
ChainDataPoint(
1788+
pdb_id=pdb_id,
1789+
chain_id=dp_id,
1790+
)
1791+
)
1792+
else:
1793+
non_represented_dps.append(
1794+
InterfaceDataPoint(
1795+
pdb_id=pdb_id,
1796+
interface_id=dp_id,
1797+
)
1798+
)
17311799

1732-
# Add this for logging purposes
1733-
structure_data_entry.chains[chain_id].source_subset = set_name
1800+
# Step 5: Pick random representatives from non-represented clusters
1801+
subsampled_dps = select_one_per_cluster(
1802+
non_represented_dps, val_dataset_cache, random_seed=random_seed
1803+
)
17341804

1735-
# Activate metrics for all interfaces in the monomer/multimer sets
1736-
for interface_id in set_structure_data[pdb_id].interfaces:
1737-
structure_data_entry.interfaces[interface_id].use_metrics = True
1805+
# Step 6: Mark selected representatives to use metrics as well
1806+
for dp in subsampled_dps:
1807+
pdb_id = dp.pdb_id
1808+
structure_data = val_dataset_cache.structure_data[pdb_id]
17381809

1739-
# Add this for logging purposes
1740-
structure_data_entry.interfaces[interface_id].source_subset = set_name
1810+
if isinstance(dp, ChainDataPoint):
1811+
dp_data = structure_data.chains[dp.chain_id]
1812+
else:
1813+
dp_data = structure_data.interfaces[dp.interface_id]
17411814

1742-
unfiltered_cache.structure_data = structure_data
1815+
# Turn on metrics and store source of original PDB-ID
1816+
dp_data.use_metrics = True
1817+
dp_data.source_subset = pdb_id_to_source[pdb_id]
17431818

17441819

17451820
def filter_only_ligand_ligand_metrics(

0 commit comments

Comments (0)