1515"""All operations for processing and manipulating metadata and training caches."""
1616
1717import functools
18+ import itertools
1819import logging
1920import random
2021from collections import defaultdict
@@ -1157,9 +1158,6 @@ def subsample_chains_by_type(
11571158 Follows AF3 SI 5.8 Monomer Selection Step 4). The function subsamples specific
11581159 chains and deletes all other chains from the cache.
11591160
1160- Note that proteins are sampled as unique cluster representatives, which is not
1161- directly stated in the SI but seems logical given that chains are preclustered.
1162-
11631161 Args:
11641162 dataset_cache (ClusteredDatasetCache):
11651163 The cache to subsample.
@@ -1180,48 +1178,45 @@ def subsample_chains_by_type(
11801178 if random_seed is not None :
11811179 random .seed (random_seed )
11821180
1183- # Store the chain data points grouped by cluster
1184- chain_type_to_clusters = {
1185- MoleculeType .PROTEIN : defaultdict ( list ) ,
1186- MoleculeType .DNA : defaultdict ( list ) ,
1187- MoleculeType .RNA : defaultdict ( list ) ,
1181+ # Store the chain data points grouped by molecule type
1182+ chain_type_to_datapoints = {
1183+ MoleculeType .PROTEIN : [] ,
1184+ MoleculeType .DNA : [] ,
1185+ MoleculeType .RNA : [] ,
11881186 }
11891187 chain_type_to_n_samples = {
11901188 MoleculeType .PROTEIN : n_protein ,
11911189 MoleculeType .DNA : n_dna ,
11921190 MoleculeType .RNA : n_rna ,
11931191 }
11941192
1195- # Collect all chain data points of the specified types grouped by cluster
1193+ # Sort the chain datapoints into their respective molecule type lists
11961194 for pdb_id , structure_data in dataset_cache .structure_data .items ():
11971195 for chain_id , chain_data in structure_data .chains .items ():
11981196 chain_type = chain_data .molecule_type
11991197
1200- if chain_type not in chain_type_to_clusters :
1198+ if chain_type not in chain_type_to_datapoints :
12011199 continue
12021200
1203- chain_type_to_clusters [chain_type ][ chain_data . cluster_id ].append (
1201+ chain_type_to_datapoints [chain_type ].append (
12041202 ChainDataPoint (pdb_id , chain_id )
12051203 )
12061204
1205+ # Subsample the chains, taking N per type except if the count is set to None in
1206+ # which case all samples are taken
12071207 keep_chain_datapoints = set ()
12081208
1209- # Subsample the chains, taking one per cluster except if the count is set to None in
1210- # which case all samples are taken
1211- for chain_type , clusters in chain_type_to_clusters .items ():
1209+ for chain_type , datapoints in chain_type_to_datapoints .items ():
12121210 n_samples = chain_type_to_n_samples [chain_type ]
12131211
12141212 # Take every single datapoint if n_samples is None
12151213 if n_samples is None :
1216- for chain_datapoints in clusters .values ():
1217- keep_chain_datapoints .update (chain_datapoints )
1214+ keep_chain_datapoints .update (datapoints )
12181215
1219- # Otherwise, take 1 sample from n_samples clusters
1216+ # Otherwise, take N random samples
12201217 else :
1221- sampled_clusters = random .sample (list (clusters .keys ()), n_samples )
1222-
1223- for cluster_id in sampled_clusters :
1224- keep_chain_datapoints .add (random .choice (clusters [cluster_id ]))
1218+ sampled_datapoints = random .sample (datapoints , n_samples )
1219+ keep_chain_datapoints .update (sampled_datapoints )
12251220
12261221 # Remove everything outside of the selected chains
12271222 filter_cache_to_specified_chains (dataset_cache , keep_chain_datapoints )
@@ -1676,70 +1671,150 @@ def filter_chains_by_metric_eligibility(
16761671 return structure_data
16771672
16781673
def select_one_per_cluster(
    datapoints: list[ChainDataPoint | InterfaceDataPoint],
    cache: ValidationDatasetCache,
    random_seed: int | None = None,
) -> list[ChainDataPoint | InterfaceDataPoint]:
    """Selects one random datapoint per cluster from the provided list.

    Args:
        datapoints:
            List of chain or interface datapoints. Chain and interface datapoints
            may be mixed; the returned list preserves whichever types were given.
        cache:
            The cache to look up cluster IDs from. Chains and interfaces are
            assumed to use disjoint cluster-ID spaces — TODO confirm, otherwise
            a chain and an interface sharing an ID would collapse into one group.
        random_seed:
            Random seed for reproducibility. NOTE: this seeds the *global*
            ``random`` module state, matching the convention used elsewhere in
            this file.

    Returns:
        Filtered list with exactly one randomly chosen datapoint per cluster.
    """
    if random_seed is not None:
        random.seed(random_seed)

    # Group the datapoints by their cluster ID, looked up from the cache. Chains
    # and interfaces store their cluster ID on different sub-objects.
    cluster_to_datapoints: defaultdict[str, list] = defaultdict(list)

    for dp in datapoints:
        if isinstance(dp, ChainDataPoint):
            cluster_id = cache.structure_data[dp.pdb_id].chains[dp.chain_id].cluster_id
        else:  # InterfaceDataPoint
            cluster_id = (
                cache.structure_data[dp.pdb_id].interfaces[dp.interface_id].cluster_id
            )
        cluster_to_datapoints[cluster_id].append(dp)

    # Pick one representative per cluster, uniformly at random.
    return [random.choice(dps) for dps in cluster_to_datapoints.values()]
1707+
def _datapoint_metadata(
    structure_data: ValidationDatasetStructureData,
    dp: ChainDataPoint | InterfaceDataPoint,
):
    """Returns the chain or interface metadata entry for a single datapoint."""
    if isinstance(dp, ChainDataPoint):
        return structure_data.chains[dp.chain_id]
    return structure_data.interfaces[dp.interface_id]


def select_final_validation_data(
    val_dataset_cache: ValidationDatasetCache,
    selected_chains: list[ChainDataPoint],
    selected_interfaces: list[InterfaceDataPoint],
    random_seed: int | None = None,
) -> None:
    """Subsets cache and marks cluster representatives with priority.

    1. Subsets cache to PDB IDs from selected chains/interfaces
    2. Marks selected chains/interfaces as cluster representatives
    3. For remaining clusters, picks additional representatives from
       metric-eligible chains/interfaces in the subsetted PDBs
    4. Enables use_metrics for all representatives

    Args:
        val_dataset_cache:
            The full validation dataset cache. Modified in place.
        selected_chains:
            Chain datapoints from monomer selection.
        selected_interfaces:
            Interface datapoints from multimer selection.
        random_seed:
            Random seed for reproducibility when picking additional
            representatives. NOTE: seeds the global ``random`` state here and is
            also forwarded to ``select_one_per_cluster``, which re-seeds with
            the same value; the extra seeding is redundant but harmless for
            determinism.

    Returns:
        None, modifies val_dataset_cache in place.
    """
    if random_seed is not None:
        random.seed(random_seed)

    # Step 1: Collect PDB IDs and subset the cache to only the selected entries.
    monomer_pdb_ids = {dp.pdb_id for dp in selected_chains}
    multimer_pdb_ids = {dp.pdb_id for dp in selected_interfaces}
    selected_pdb_ids = monomer_pdb_ids | multimer_pdb_ids

    val_dataset_cache.structure_data = {
        pdb_id: data
        for pdb_id, data in val_dataset_cache.structure_data.items()
        if pdb_id in selected_pdb_ids
    }

    # Step 2: Track which subset each PDB came from for logging purposes.
    # NOTE(review): a PDB present in both sets is labelled "monomer" — the
    # tie-break is arbitrary here; confirm this is the intended priority.
    pdb_id_to_source = {
        pdb_id: "monomer" if pdb_id in monomer_pdb_ids else "multimer"
        for pdb_id in selected_pdb_ids
    }

    # Step 3: Turn on metrics for selected chains/interfaces and track which
    # clusters are already represented.
    represented_clusters: set[str] = set()

    for dp in itertools.chain(selected_chains, selected_interfaces):
        dp_data = _datapoint_metadata(val_dataset_cache.structure_data[dp.pdb_id], dp)

        # Turn on metrics and mark the cluster as represented
        dp_data.use_metrics = True
        represented_clusters.add(dp_data.cluster_id)

        # Store which subset the PDB came from
        dp_data.source_subset = pdb_id_to_source[dp.pdb_id]

    # Step 4: Collect chains/interfaces in the kept PDBs that are
    # metric-eligible but whose cluster has no representative yet.
    non_represented_dps: list[ChainDataPoint | InterfaceDataPoint] = []

    for pdb_id, structure_data in val_dataset_cache.structure_data.items():
        for dp_id, metadata in itertools.chain(
            structure_data.chains.items(), structure_data.interfaces.items()
        ):
            if (
                metadata.metric_eligible
                and metadata.cluster_id not in represented_clusters
            ):
                # Chain IDs and interface IDs come from different dicts, so the
                # metadata type tells us which datapoint wrapper to build.
                if isinstance(metadata, ValidationDatasetChainData):
                    non_represented_dps.append(
                        ChainDataPoint(pdb_id=pdb_id, chain_id=dp_id)
                    )
                else:
                    non_represented_dps.append(
                        InterfaceDataPoint(pdb_id=pdb_id, interface_id=dp_id)
                    )

    # Step 5: Pick random representatives from the non-represented clusters.
    subsampled_dps = select_one_per_cluster(
        non_represented_dps, val_dataset_cache, random_seed=random_seed
    )

    # Step 6: Mark the additional representatives to use metrics as well.
    for dp in subsampled_dps:
        dp_data = _datapoint_metadata(val_dataset_cache.structure_data[dp.pdb_id], dp)

        # Turn on metrics and store source of original PDB-ID
        dp_data.use_metrics = True
        dp_data.source_subset = pdb_id_to_source[dp.pdb_id]
17441819
17451820def filter_only_ligand_ligand_metrics (
0 commit comments