-
Notifications
You must be signed in to change notification settings - Fork 261
Expand file tree
/
Copy pathchai1.py
More file actions
1053 lines (929 loc) · 37.5 KB
/
chai1.py
File metadata and controls
1053 lines (929 loc) · 37.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright (c) 2024 Chai Discovery, Inc.
# Licensed under the Apache License, Version 2.0.
# See the LICENSE file for details.
import itertools
import logging
import math
from collections import Counter
from contextlib import contextmanager
from dataclasses import dataclass
from pathlib import Path
from typing import Generator, Sequence
import numpy as np
import torch
from einops import einsum, rearrange, repeat
from torch import Tensor
from tqdm import tqdm
from chai_lab.data.collate.collate import Collate
from chai_lab.data.collate.utils import AVAILABLE_MODEL_SIZES
from chai_lab.data.dataset.all_atom_feature_context import (
MAX_MSA_DEPTH,
MAX_NUM_TEMPLATES,
AllAtomFeatureContext,
)
from chai_lab.data.dataset.constraints.restraint_context import (
RestraintContext,
load_manual_restraints_for_chai1,
)
from chai_lab.data.dataset.embeddings.embedding_context import EmbeddingContext
from chai_lab.data.dataset.embeddings.esm import get_esm_embedding_context
from chai_lab.data.dataset.inference_dataset import load_chains_from_raw, read_inputs
from chai_lab.data.dataset.msas.colabfold import generate_colabfold_msas
from chai_lab.data.dataset.msas.load import get_msa_contexts
from chai_lab.data.dataset.msas.msa_context import MSAContext
from chai_lab.data.dataset.msas.utils import (
subsample_and_reorder_msa_feats_n_mask,
)
from chai_lab.data.dataset.structure.all_atom_structure_context import (
AllAtomStructureContext,
)
from chai_lab.data.dataset.structure.bond_utils import (
get_atom_covalent_bond_pairs_from_constraints,
)
from chai_lab.data.dataset.templates.context import (
TemplateContext,
get_template_context,
)
from chai_lab.data.features.feature_factory import FeatureFactory
from chai_lab.data.features.feature_type import FeatureType
from chai_lab.data.features.generators.atom_element import AtomElementOneHot
from chai_lab.data.features.generators.atom_name import AtomNameOneHot
from chai_lab.data.features.generators.base import EncodingType
from chai_lab.data.features.generators.blocked_atom_pair_distances import (
BlockedAtomPairDistances,
BlockedAtomPairDistogram,
)
from chai_lab.data.features.generators.docking import DockingRestraintGenerator
from chai_lab.data.features.generators.esm_generator import ESMEmbeddings
from chai_lab.data.features.generators.identity import Identity
from chai_lab.data.features.generators.is_cropped_chain import ChainIsCropped
from chai_lab.data.features.generators.missing_chain_contact import MissingChainContact
from chai_lab.data.features.generators.msa import (
IsPairedMSAGenerator,
MSADataSourceGenerator,
MSADeletionMeanGenerator,
MSADeletionValueGenerator,
MSAFeatureGenerator,
MSAHasDeletionGenerator,
MSAProfileGenerator,
)
from chai_lab.data.features.generators.ref_pos import RefPos
from chai_lab.data.features.generators.relative_chain import RelativeChain
from chai_lab.data.features.generators.relative_entity import RelativeEntity
from chai_lab.data.features.generators.relative_sep import RelativeSequenceSeparation
from chai_lab.data.features.generators.relative_token import RelativeTokenSeparation
from chai_lab.data.features.generators.residue_type import ResidueType
from chai_lab.data.features.generators.structure_metadata import (
IsDistillation,
TokenBFactor,
TokenPLDDT,
)
from chai_lab.data.features.generators.templates import (
TemplateDistogramGenerator,
TemplateMaskGenerator,
TemplateResTypeGenerator,
TemplateUnitVectorGenerator,
)
from chai_lab.data.features.generators.token_bond import TokenBondRestraint
from chai_lab.data.features.generators.token_dist_restraint import (
TokenDistanceRestraint,
)
from chai_lab.data.features.generators.token_pair_pocket_restraint import (
TokenPairPocketRestraint,
)
from chai_lab.data.io.cif_utils import _CHAIN_VOCAB, get_chain_letter, save_to_cif
from chai_lab.data.parsing.restraints import parse_pairwise_table
from chai_lab.data.parsing.structure.entity_type import EntityType
from chai_lab.model.diffusion_schedules import InferenceNoiseSchedule
from chai_lab.model.utils import center_random_augmentation
from chai_lab.ranking.frames import get_frames_and_mask
from chai_lab.ranking.rank import SampleRanking, get_scores, rank
from chai_lab.utils.paths import chai1_component
from chai_lab.utils.plot import plot_msa
from chai_lab.utils.tensor_utils import move_data_to_device, set_seed, und_self
from chai_lab.utils.typing import Float, typecheck
class UnsupportedInputError(RuntimeError):
pass
class ModuleWrapper:
def __init__(self, jit_module):
self.jit_module = jit_module
def forward(
self,
crop_size: int,
*,
return_on_cpu=False,
move_to_device: torch.device | None = None,
**kw,
):
f = getattr(self.jit_module, f"forward_{crop_size}")
if move_to_device is not None:
result = f(**move_data_to_device(kw, device=move_to_device))
else:
result = f(**kw)
if return_on_cpu:
return move_data_to_device(result, device=torch.device("cpu"))
else:
return result
def load_exported(comp_key: str, device: torch.device) -> ModuleWrapper:
torch.jit.set_fusion_strategy([("STATIC", 0), ("DYNAMIC", 0)])
local_path = chai1_component(comp_key)
assert isinstance(device, torch.device)
if device != torch.device("cuda:0"):
# load on cpu first, then move to device
return ModuleWrapper(torch.jit.load(local_path, map_location="cpu").to(device))
else:
# skip loading on CPU.
return ModuleWrapper(torch.jit.load(local_path).to(device))
_component_cache: dict[str, ModuleWrapper] = {}
@contextmanager
def _component_moved_to(
comp_key: str, device: torch.device
) -> Generator[ModuleWrapper, None, None]:
# Transiently moves module to provided device, then moves to CPU.
# Much faster than reloading module from disk.
if comp_key not in _component_cache:
_component_cache[comp_key] = load_exported(comp_key, device)
component = _component_cache[comp_key]
component.jit_module.to(device)
yield component
component.jit_module.to("cpu")
# %%
# Create feature factory
feature_generators = dict(
RelativeSequenceSeparation=RelativeSequenceSeparation(sep_bins=None),
RelativeTokenSeparation=RelativeTokenSeparation(r_max=32),
RelativeEntity=RelativeEntity(),
RelativeChain=RelativeChain(),
ResidueType=ResidueType(num_res_ty=32, key="token_residue_type"),
ESMEmbeddings=ESMEmbeddings(), # TODO: this can probably be the identity
BlockedAtomPairDistogram=BlockedAtomPairDistogram(),
InverseSquaredBlockedAtomPairDistances=BlockedAtomPairDistances(
transform="inverse_squared",
encoding_ty=EncodingType.IDENTITY,
),
AtomRefPos=RefPos(),
AtomRefCharge=Identity(
key="inputs/atom_ref_charge",
ty=FeatureType.ATOM,
dim=1,
can_mask=False,
),
AtomRefMask=Identity(
key="inputs/atom_ref_mask",
ty=FeatureType.ATOM,
dim=1,
can_mask=False,
),
AtomRefElement=AtomElementOneHot(max_atomic_num=128),
AtomNameOneHot=AtomNameOneHot(),
TemplateMask=TemplateMaskGenerator(),
TemplateUnitVector=TemplateUnitVectorGenerator(),
TemplateResType=TemplateResTypeGenerator(),
TemplateDistogram=TemplateDistogramGenerator(),
TokenDistanceRestraint=TokenDistanceRestraint(
include_probability=1.0,
size=0.33,
min_dist=6.0,
max_dist=30.0,
num_rbf_radii=6,
),
DockingConstraintGenerator=DockingRestraintGenerator(
include_probability=0.0,
structure_dropout_prob=0.75,
chain_dropout_prob=0.75,
),
TokenPairPocketRestraint=TokenPairPocketRestraint(
size=0.33,
include_probability=1.0,
min_dist=6.0,
max_dist=20.0,
coord_noise=0.0,
num_rbf_radii=6,
),
MSAProfile=MSAProfileGenerator(),
MSADeletionMean=MSADeletionMeanGenerator(),
IsDistillation=IsDistillation(),
TokenBFactor=TokenBFactor(include_prob=0.0),
TokenPLDDT=TokenPLDDT(include_prob=0.0),
ChainIsCropped=ChainIsCropped(),
MissingChainContact=MissingChainContact(contact_threshold=6.0),
MSAOneHot=MSAFeatureGenerator(),
MSAHasDeletion=MSAHasDeletionGenerator(),
MSADeletionValue=MSADeletionValueGenerator(),
IsPairedMSA=IsPairedMSAGenerator(),
MSADataSource=MSADataSourceGenerator(),
)
feature_factory = FeatureFactory(feature_generators)
# %%
# Config
class DiffusionConfig:
S_churn: float = 80
S_tmin: float = 4e-4
S_tmax: float = 80.0
S_noise: float = 1.003
sigma_data: float = 16.0
second_order: bool = True
# %%
# Input validation
def raise_if_too_many_tokens(n_actual_tokens: int):
if n_actual_tokens > max(AVAILABLE_MODEL_SIZES):
raise UnsupportedInputError(
f"Too many tokens in input: {n_actual_tokens} > {max(AVAILABLE_MODEL_SIZES)}. "
"Please limit the length of the input sequence."
)
def raise_if_too_many_templates(n_actual_templates: int):
if n_actual_templates > MAX_NUM_TEMPLATES:
raise UnsupportedInputError(
f"Too many templates in input: {n_actual_templates} > {MAX_NUM_TEMPLATES}. "
"Please limit the number of templates."
)
def raise_if_msa_too_deep(msa_depth: int):
if msa_depth > MAX_MSA_DEPTH:
raise UnsupportedInputError(
f"MSA too deep: {msa_depth} > {MAX_MSA_DEPTH}. "
"Please limit the MSA depth."
)
# %%
# Inference logic
@typecheck
@dataclass(frozen=True)
class StructureCandidates:
# We provide candidates generated by a model,
# with confidence predictions and ranking scores.
# Predicted structure is a candidate with the highest score.
# locations of CIF files, one file per candidate
cif_paths: list[Path]
# scores for each of candidates + info that was used for scoring.
ranking_data: list[SampleRanking]
# iff MSA search was performed, we also save a plot as PDF
msa_coverage_plot_path: Path | None
# Predicted aligned error(PAE)
pae: Float[Tensor, "candidate num_tokens num_tokens"]
# Predicted distance error (PDE)
pde: Float[Tensor, "candidate num_tokens num_tokens"]
# Predicted local distance difference test (pLDDT)
plddt: Float[Tensor, "candidate num_tokens"]
def __post_init__(self):
assert len(self.cif_paths) == len(self.ranking_data) == self.pae.shape[0]
def sorted(self) -> "StructureCandidates":
"""Sort by aggregate score from most to least confident."""
agg_scores = torch.concatenate([rd.aggregate_score for rd in self.ranking_data])
idx = torch.argsort(agg_scores, descending=True) # Higher scores are better
return StructureCandidates(
cif_paths=[self.cif_paths[i] for i in idx],
ranking_data=[self.ranking_data[i] for i in idx],
msa_coverage_plot_path=self.msa_coverage_plot_path,
pae=self.pae[idx],
pde=self.pde[idx],
plddt=self.plddt[idx],
)
@classmethod
def concat(
cls, candidates: Sequence["StructureCandidates"]
) -> "StructureCandidates":
return cls(
cif_paths=list(
itertools.chain.from_iterable([c.cif_paths for c in candidates])
),
ranking_data=list(
itertools.chain.from_iterable([c.ranking_data for c in candidates])
),
msa_coverage_plot_path=candidates[0].msa_coverage_plot_path,
pae=torch.cat([c.pae for c in candidates]),
pde=torch.cat([c.pde for c in candidates]),
plddt=torch.cat([c.plddt for c in candidates]),
)
def make_all_atom_feature_context(
fasta_file: Path,
*,
output_dir: Path,
entity_name_as_subchain: bool = False,
use_esm_embeddings: bool = True,
use_msa_server: bool = False,
msa_server_url: str = "https://api.colabfold.com",
msa_directory: Path | None = None,
constraint_path: Path | None = None,
use_templates_server: bool = False,
templates_path: Path | None = None,
esm_device: torch.device = torch.device("cpu"),
):
assert not (
use_msa_server and msa_directory
), "Cannot specify both MSA server and directory"
assert not (
use_templates_server and templates_path
), "Cannot specify both templates server and path"
# Prepare inputs
assert fasta_file.exists(), fasta_file
fasta_inputs = read_inputs(fasta_file, length_limit=None)
assert len(fasta_inputs) > 0, "No inputs found in fasta file"
for name, count in Counter([inp.entity_name for inp in fasta_inputs]).items():
if count > 1:
raise UnsupportedInputError(
f"{name=} used more than once in inputs. Each entity must have a unique name"
)
# Load structure context
chains = load_chains_from_raw(
fasta_inputs, entity_name_as_subchain=entity_name_as_subchain
)
del fasta_inputs # Do not reference inputs after creating chains from them
merged_context = AllAtomStructureContext.merge(
[c.structure_context for c in chains]
)
n_actual_tokens = merged_context.num_tokens
raise_if_too_many_tokens(n_actual_tokens)
# Generated and/or load MSAs
protein_sequences = [
chain.entity_data.sequence
for chain in chains
if chain.entity_data.entity_type == EntityType.PROTEIN
]
if use_msa_server:
msa_dir = output_dir / "msas"
msa_dir.mkdir(parents=True, exist_ok=False)
generate_colabfold_msas(
protein_seqs=protein_sequences,
msa_dir=msa_dir,
search_templates=use_templates_server,
msa_server_url=msa_server_url,
)
if use_templates_server and protein_sequences:
# Override templates path with server path
assert templates_path is None
templates_path = msa_dir / "all_chain_templates.m8"
assert templates_path.is_file()
msa_context, msa_profile_context = get_msa_contexts(
chains, msa_directory=msa_dir
)
elif msa_directory is not None:
msa_context, msa_profile_context = get_msa_contexts(
chains, msa_directory=msa_directory
)
else:
msa_context = MSAContext.create_empty(
n_tokens=n_actual_tokens, depth=MAX_MSA_DEPTH
)
msa_profile_context = MSAContext.create_empty(
n_tokens=n_actual_tokens, depth=MAX_MSA_DEPTH
)
assert (
msa_context.num_tokens == merged_context.num_tokens
), f"Discrepant tokens in input and MSA: {merged_context.num_tokens} != {msa_context.num_tokens}"
# Load templates
if templates_path is None:
if protein_sequences:
assert not use_templates_server, "Server should have written a path"
template_context = TemplateContext.empty(
n_tokens=n_actual_tokens,
n_templates=MAX_NUM_TEMPLATES,
)
else:
# NOTE templates m8 file should contain hits with query name matching chain entity_names
# or the hash of the chain sequence. When we query the server, we use the hash of the
# sequence to identify each hit.
template_context = get_template_context(
chains=chains,
use_sequence_hash_for_lookup=use_templates_server,
template_hits_m8=templates_path,
)
# Load ESM embeddings
if use_esm_embeddings:
embedding_context = get_esm_embedding_context(chains, device=esm_device)
else:
embedding_context = EmbeddingContext.empty(n_tokens=n_actual_tokens)
# Constraints
if constraint_path is not None:
# Handles contact and pocket restraints
pairs = parse_pairwise_table(constraint_path)
restraint_context = load_manual_restraints_for_chai1(
chains,
crop_idces=None,
provided_constraints=pairs,
)
# Handle covalent bond restraints
cov_a, cov_b = get_atom_covalent_bond_pairs_from_constraints(
provided_constraints=pairs,
token_residue_index=merged_context.token_residue_index,
token_residue_name=merged_context.token_residue_name,
token_subchain_id=merged_context.subchain_id,
token_asym_id=merged_context.token_asym_id,
atom_token_index=merged_context.atom_token_index,
atom_ref_name=merged_context.atom_ref_name,
)
if cov_a.numel() > 0 and cov_b.numel() > 0:
orig_a, orig_b = merged_context.atom_covalent_bond_indices
if orig_a.numel() == orig_b.numel() == 0:
merged_context.atom_covalent_bond_indices = (cov_a, cov_b)
else:
merged_context.atom_covalent_bond_indices = (
torch.concatenate([orig_a, cov_a]),
torch.concatenate([orig_b, cov_b]),
)
assert (
merged_context.atom_covalent_bond_indices[0].numel()
== merged_context.atom_covalent_bond_indices[1].numel()
> 0
)
else:
restraint_context = RestraintContext.empty()
# Handles leaving atoms for glycan bonds in-place
merged_context.drop_glycan_leaving_atoms_inplace()
# Build final feature context
feature_context = AllAtomFeatureContext(
chains=chains,
structure_context=merged_context,
msa_context=msa_context,
profile_msa_context=msa_profile_context,
template_context=template_context,
embedding_context=embedding_context,
restraint_context=restraint_context,
)
return feature_context
@torch.no_grad()
def run_inference(
fasta_file: Path,
*,
output_dir: Path,
# Configuration for ESM, MSA, constraints, and templates
use_esm_embeddings: bool = True,
use_msa_server: bool = False,
msa_server_url: str = "https://api.colabfold.com",
msa_directory: Path | None = None,
constraint_path: Path | None = None,
use_templates_server: bool = False,
template_hits_path: Path | None = None,
# Parameters controlling how we do inference
recycle_msa_subsample: int = 0,
num_trunk_recycles: int = 3,
num_diffn_timesteps: int = 200,
num_diffn_samples: int = 5,
num_trunk_samples: int = 1,
seed: int | None = None,
device: str | None = None,
low_memory: bool = True,
# IO options
fasta_names_as_cif_chains: bool = False,
) -> StructureCandidates:
assert num_trunk_samples > 0 and num_diffn_samples > 0
if output_dir.exists():
assert not any(
output_dir.iterdir()
), f"Output directory {output_dir} is not empty."
torch_device = torch.device(device if device is not None else "cuda:0")
# NOTE if fastas are cif chain names, we use this to parse chains as well
feature_context = make_all_atom_feature_context(
fasta_file=fasta_file,
output_dir=output_dir,
entity_name_as_subchain=fasta_names_as_cif_chains,
use_esm_embeddings=use_esm_embeddings,
use_msa_server=use_msa_server,
msa_server_url=msa_server_url,
msa_directory=msa_directory,
constraint_path=constraint_path,
use_templates_server=use_templates_server,
templates_path=template_hits_path,
esm_device=torch_device,
)
all_candidates: list[StructureCandidates] = []
for trunk_idx in range(num_trunk_samples):
logging.info(f"Trunk sample {trunk_idx + 1}/{num_trunk_samples}")
cand = run_folding_on_context(
feature_context,
output_dir=(
output_dir / f"trunk_{trunk_idx}"
if num_trunk_samples > 1
else output_dir
),
num_trunk_recycles=num_trunk_recycles,
num_diffn_timesteps=num_diffn_timesteps,
num_diffn_samples=num_diffn_samples,
recycle_msa_subsample=recycle_msa_subsample,
seed=seed + trunk_idx if seed is not None else None,
device=torch_device,
low_memory=low_memory,
entity_names_as_chain_names_in_output_cif=fasta_names_as_cif_chains,
)
all_candidates.append(cand)
return StructureCandidates.concat(all_candidates)
def _bin_centers(min_bin: float, max_bin: float, no_bins: int) -> Tensor:
return torch.linspace(min_bin, max_bin, 2 * no_bins + 1)[1::2]
@torch.no_grad()
def run_folding_on_context(
feature_context: AllAtomFeatureContext,
*,
output_dir: Path,
# expose some params for easy tweaking
recycle_msa_subsample: int = 0,
num_trunk_recycles: int = 3,
num_diffn_timesteps: int = 200,
# all diffusion samples come from the same trunk
num_diffn_samples: int = 5,
entity_names_as_chain_names_in_output_cif: bool = False,
seed: int | None = None,
device: torch.device | None = None,
low_memory: bool,
) -> StructureCandidates:
"""
Function for in-depth explorations.
User completely controls folding inputs.
"""
# Set seed
if seed is not None:
set_seed([seed])
if device is None:
device = torch.device("cuda:0")
# Clear memory
torch.cuda.empty_cache()
##
## Validate inputs
##
n_actual_tokens = feature_context.structure_context.num_tokens
raise_if_too_many_tokens(n_actual_tokens)
raise_if_too_many_templates(feature_context.template_context.num_templates)
raise_if_msa_too_deep(feature_context.msa_context.depth)
# NOTE profile MSA used only for statistics; no depth check
feature_context.structure_context.report_bonds()
if entity_names_as_chain_names_in_output_cif:
# Ensure that entity names are unique and are valid chain names
entity_names: list[str] = [
chain.entity_data.entity_name for chain in feature_context.chains
]
assert len(set(entity_names)) == len(
entity_names
), f"Using entity names for cif chains, but got duplicates: {entity_names}"
assert all(e in _CHAIN_VOCAB for e in entity_names), (
"Using entity names for cif chains, but got invalid names "
f"{entity_names}; must be in {_CHAIN_VOCAB}"
)
##
## Prepare batch
##
# Collate inputs into batch
collator = Collate(
feature_factory=feature_factory,
num_key_atoms=128,
num_query_atoms=32,
)
feature_contexts = [feature_context]
batch_size = len(feature_contexts)
batch = collator(feature_contexts)
if not low_memory:
batch = move_data_to_device(batch, device=device)
# Get features and inputs from batch
features = {name: feature for name, feature in batch["features"].items()}
inputs = batch["inputs"]
block_indices_h = inputs["block_atom_pair_q_idces"]
block_indices_w = inputs["block_atom_pair_kv_idces"]
atom_single_mask = inputs["atom_exists_mask"]
atom_token_indices = inputs["atom_token_index"].long()
token_single_mask = inputs["token_exists_mask"]
token_pair_mask = und_self(token_single_mask, "b i, b j -> b i j")
token_reference_atom_index = inputs["token_ref_atom_index"]
atom_within_token_index = inputs["atom_within_token_index"]
msa_mask = inputs["msa_mask"]
template_input_masks = und_self(
inputs["template_mask"], "b t n1, b t n2 -> b t n1 n2"
)
block_atom_pair_mask = inputs["block_atom_pair_mask"]
##
## Load exported models
##
_, _, model_size = msa_mask.shape
assert model_size in AVAILABLE_MODEL_SIZES
##
## Run the features through the feature embedder
##
with _component_moved_to("feature_embedding.pt", device) as feature_embedding:
embedded_features = feature_embedding.forward(
crop_size=model_size,
move_to_device=device,
return_on_cpu=low_memory,
**features,
)
token_single_input_feats = embedded_features["TOKEN"]
token_pair_input_feats, token_pair_structure_input_feats = embedded_features[
"TOKEN_PAIR"
].chunk(2, dim=-1)
atom_single_input_feats, atom_single_structure_input_feats = embedded_features[
"ATOM"
].chunk(2, dim=-1)
block_atom_pair_input_feats, block_atom_pair_structure_input_feats = (
embedded_features["ATOM_PAIR"].chunk(2, dim=-1)
)
template_input_feats = embedded_features["TEMPLATES"]
msa_input_feats = embedded_features["MSA"]
##
## Bond feature generator
## Separate from other feature embeddings due to export limitations
##
bond_ft_gen = TokenBondRestraint()
bond_ft = bond_ft_gen.generate(batch=batch).data
with _component_moved_to("bond_loss_input_proj.pt", device) as bond_loss_input_proj:
trunk_bond_feat, structure_bond_feat = bond_loss_input_proj.forward(
return_on_cpu=low_memory,
move_to_device=device,
crop_size=model_size,
input=bond_ft,
).chunk(2, dim=-1)
token_pair_input_feats += trunk_bond_feat
token_pair_structure_input_feats += structure_bond_feat
##
## Run the inputs through the token input embedder
##
with _component_moved_to("token_embedder.pt", device) as token_input_embedder:
token_input_embedder_outputs: tuple[Tensor, ...] = token_input_embedder.forward(
return_on_cpu=low_memory,
move_to_device=device,
token_single_input_feats=token_single_input_feats,
token_pair_input_feats=token_pair_input_feats,
atom_single_input_feats=atom_single_input_feats,
block_atom_pair_feat=block_atom_pair_input_feats,
block_atom_pair_mask=block_atom_pair_mask,
block_indices_h=block_indices_h,
block_indices_w=block_indices_w,
atom_single_mask=atom_single_mask,
atom_token_indices=atom_token_indices,
crop_size=model_size,
)
token_single_initial_repr, token_single_structure_input, token_pair_initial_repr = (
token_input_embedder_outputs
)
##
## Run the input representations through the trunk
##
# Recycle the representations by feeding the output back into the trunk as input for
# the subsequent recycle
token_single_trunk_repr = token_single_initial_repr
token_pair_trunk_repr = token_pair_initial_repr
for _ in tqdm(range(num_trunk_recycles), desc="Trunk recycles"):
subsampled_msa_input_feats, subsampled_msa_mask = None, None
if recycle_msa_subsample > 0:
subsampled_msa_input_feats, subsampled_msa_mask = (
subsample_and_reorder_msa_feats_n_mask(
msa_input_feats,
msa_mask,
)
)
with _component_moved_to("trunk.pt", device) as trunk:
(token_single_trunk_repr, token_pair_trunk_repr) = trunk.forward(
move_to_device=device,
token_single_trunk_initial_repr=token_single_initial_repr,
token_pair_trunk_initial_repr=token_pair_initial_repr,
token_single_trunk_repr=token_single_trunk_repr, # recycled
token_pair_trunk_repr=token_pair_trunk_repr, # recycled
msa_input_feats=(
subsampled_msa_input_feats
if subsampled_msa_input_feats is not None
else msa_input_feats
),
msa_mask=(
subsampled_msa_mask if subsampled_msa_mask is not None else msa_mask
),
template_input_feats=template_input_feats,
template_input_masks=template_input_masks,
token_single_mask=token_single_mask,
token_pair_mask=token_pair_mask,
crop_size=model_size,
)
# in case trunk fragmented mem too much
torch.cuda.empty_cache()
##
## Denoise the trunk representation by passing it through the diffusion module
##
atom_single_mask = atom_single_mask.to(device)
static_diffusion_inputs = dict(
token_single_initial_repr=token_single_structure_input.float(),
token_pair_initial_repr=token_pair_structure_input_feats.float(),
token_single_trunk_repr=token_single_trunk_repr.float(),
token_pair_trunk_repr=token_pair_trunk_repr.float(),
atom_single_input_feats=atom_single_structure_input_feats.float(),
atom_block_pair_input_feats=block_atom_pair_structure_input_feats.float(),
atom_single_mask=atom_single_mask,
atom_block_pair_mask=block_atom_pair_mask,
token_single_mask=token_single_mask,
block_indices_h=block_indices_h,
block_indices_w=block_indices_w,
atom_token_indices=atom_token_indices,
)
static_diffusion_inputs = move_data_to_device(
static_diffusion_inputs, device=device
)
def _denoise(
diff_mod: ModuleWrapper, atom_pos: Tensor, sigma: Tensor, ds: int
) -> Tensor:
# verified manually that ds dimension can be arbitrary in diff module
atom_noised_coords = rearrange(
atom_pos, "(b ds) ... -> b ds ...", ds=ds
).contiguous()
noise_sigma = repeat(sigma, " -> b ds", b=batch_size, ds=ds)
return diff_mod.forward(
atom_noised_coords=atom_noised_coords.float(),
noise_sigma=noise_sigma.float(),
crop_size=model_size,
**static_diffusion_inputs,
)
inference_noise_schedule = InferenceNoiseSchedule(
s_max=DiffusionConfig.S_tmax,
s_min=4e-4,
p=7.0,
sigma_data=DiffusionConfig.sigma_data,
)
sigmas = inference_noise_schedule.get_schedule(
device=device, num_timesteps=num_diffn_timesteps
)
gammas = torch.where(
(sigmas >= DiffusionConfig.S_tmin) & (sigmas <= DiffusionConfig.S_tmax),
min(DiffusionConfig.S_churn / num_diffn_timesteps, math.sqrt(2) - 1),
0.0,
)
sigmas_and_gammas = list(zip(sigmas[:-1], sigmas[1:], gammas[:-1]))
# Initial atom positions
_, num_atoms = atom_single_mask.shape
atom_pos = sigmas[0] * torch.randn(
batch_size * num_diffn_samples, num_atoms, 3, device=device
)
with _component_moved_to("diffusion_module.pt", device=device) as diffusion_module:
for sigma_curr, sigma_next, gamma_curr in tqdm(
sigmas_and_gammas, desc="Diffusion steps"
):
# Center coords
atom_pos = center_random_augmentation(
atom_pos,
atom_single_mask=repeat(
atom_single_mask,
"b a -> (b ds) a",
ds=num_diffn_samples,
),
)
# Alg 2. lines 4-6
noise = DiffusionConfig.S_noise * torch.randn(
atom_pos.shape, device=atom_pos.device
)
sigma_hat = sigma_curr + gamma_curr * sigma_curr
atom_pos_noise = (sigma_hat**2 - sigma_curr**2).clamp_min(1e-6).sqrt()
atom_pos_hat = atom_pos + noise * atom_pos_noise
# Lines 7-8
denoised_pos = _denoise(
diff_mod=diffusion_module,
atom_pos=atom_pos_hat,
sigma=sigma_hat,
ds=num_diffn_samples,
)
d_i = (atom_pos_hat - denoised_pos) / sigma_hat
atom_pos = atom_pos_hat + (sigma_next - sigma_hat) * d_i
# Lines 9-11
if sigma_next != 0 and DiffusionConfig.second_order: # second order update
denoised_pos = _denoise(
diff_mod=diffusion_module,
atom_pos=atom_pos,
sigma=sigma_next,
ds=num_diffn_samples,
)
d_i_prime = (atom_pos - denoised_pos) / sigma_next
atom_pos = atom_pos + (sigma_next - sigma_hat) * ((d_i_prime + d_i) / 2)
del static_diffusion_inputs
torch.cuda.empty_cache()
##
## Run the confidence model
##
with _component_moved_to("confidence_head.pt", device=device) as confidence_head:
confidence_outputs: list[tuple[Tensor, ...]] = [
confidence_head.forward(
move_to_device=device,
token_single_input_repr=token_single_initial_repr,
token_single_trunk_repr=token_single_trunk_repr,
token_pair_trunk_repr=token_pair_trunk_repr,
token_single_mask=token_single_mask,
atom_single_mask=atom_single_mask,
atom_coords=atom_pos[ds : ds + 1],
token_reference_atom_index=token_reference_atom_index,
atom_token_index=atom_token_indices,
atom_within_token_index=atom_within_token_index,
crop_size=model_size,
)
for ds in range(num_diffn_samples)
]
pae_logits, pde_logits, plddt_logits = [
torch.cat(single_sample, dim=0)
for single_sample in zip(*confidence_outputs, strict=True)
]
assert atom_pos.shape[0] == num_diffn_samples
assert pae_logits.shape[0] == num_diffn_samples
def softmax_einsum_and_cpu(
logits: Tensor, bin_mean: Tensor, pattern: str
) -> Tensor:
# utility to compute score from bin logits
res = einsum(
logits.float().softmax(dim=-1), bin_mean.to(logits.device), pattern
)
return res.to(device="cpu")
token_mask_1d = rearrange(token_single_mask, "1 b -> b")
pae_scores = softmax_einsum_and_cpu(
pae_logits[:, token_mask_1d, :, :][:, :, token_mask_1d, :],
_bin_centers(0.0, 32.0, 64),
"b n1 n2 d, d -> b n1 n2",
)
pde_scores = softmax_einsum_and_cpu(
pde_logits[:, token_mask_1d, :, :][:, :, token_mask_1d, :],
_bin_centers(0.0, 32.0, 64),
"b n1 n2 d, d -> b n1 n2",
)
plddt_scores_atom = softmax_einsum_and_cpu(
plddt_logits,
_bin_centers(0, 1, plddt_logits.shape[-1]),
"b a d, d -> b a",
)
# converting per-atom plddt to per-token
[mask] = atom_single_mask.cpu()
[indices] = atom_token_indices.cpu()
def avg_per_token_1d(x):
n = torch.bincount(indices[mask], weights=x[mask])
d = torch.bincount(indices[mask]).clamp(min=1)
return n / d
plddt_scores = torch.stack([avg_per_token_1d(x) for x in plddt_scores_atom])
##
## Write the outputs
##
# Move data to the CPU so we don't hit GPU memory limits
inputs = move_data_to_device(inputs, torch.device("cpu"))
atom_pos = atom_pos.cpu()
plddt_logits = plddt_logits.cpu()
pae_logits = pae_logits.cpu()
# Plot coverage of tokens by MSA, save plot
output_dir.mkdir(parents=True, exist_ok=True)
if feature_context.msa_context.mask.any():
msa_plot_path = plot_msa(
input_tokens=feature_context.structure_context.token_residue_type,
msa_tokens=feature_context.msa_context.tokens,
out_fname=output_dir / "msa_depth.pdf",
)
else:
msa_plot_path = None
cif_paths: list[Path] = []
ranking_data: list[SampleRanking] = []
for idx in range(num_diffn_samples):
##
## Compute ranking scores
##
_, valid_frames_mask = get_frames_and_mask(
atom_pos[idx : idx + 1],
inputs["token_asym_id"],
inputs["token_residue_index"],
inputs["token_backbone_frame_mask"],
inputs["token_centre_atom_index"],
inputs["token_exists_mask"],
inputs["atom_exists_mask"],
inputs["token_backbone_frame_index"],
inputs["atom_token_index"],
)
ranking_outputs: SampleRanking = rank(
atom_pos[idx : idx + 1],
atom_mask=inputs["atom_exists_mask"],
atom_token_index=inputs["atom_token_index"],
token_exists_mask=inputs["token_exists_mask"],
token_asym_id=inputs["token_asym_id"],