Commit ca02836

Merge pull request #45 from uw-ipd/april23_updates
April23 updates
2 parents 9ebd142 + e4cb3d7 commit ca02836

29 files changed

Lines changed: 1856 additions & 1365 deletions

README.md

Lines changed: 18 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,11 @@
11
# RF2NA
22
GitHub repo for RoseTTAFold2 with nucleic acids
33

4+
**New: April 13, 2023 v0.2**
5+
* Updated weights (https://files.ipd.uw.edu/dimaio/RF2NA_apr23.tgz) for better prediction of homodimer:DNA interactions and better DNA-specific sequence recognition
6+
* Bugfixes in MSA generation pipeline
7+
* Support for paired protein/RNA MSAs
8+
49
## Installation
510

611
1. Clone the package
@@ -25,9 +30,9 @@ python setup.py install
2530
3. Download pre-trained weights under network directory
2631
```
2732
cd network
28-
wget https://files.ipd.uw.edu/dimaio/RF2NA_sep22.tgz
29-
tar xvfz RF2NA_sep22.tgz
30-
ls weights/ # it should contain a 800mb weights file
33+
wget https://files.ipd.uw.edu/dimaio/RF2NA_apr23.tgz
34+
tar xvfz RF2NA_apr23.tgz
35+
ls weights/ # it should contain a 1.1GB weights file
3136
cd ..
3237
```
3338

@@ -62,7 +67,7 @@ wget ftp://ftp.ebi.ac.uk/pub/databases/RNAcentral/current_release/rfam/rfam_anno
6267
wget ftp://ftp.ebi.ac.uk/pub/databases/RNAcentral/current_release/id_mapping/id_mapping.tsv.gz
6368
wget ftp://ftp.ebi.ac.uk/pub/databases/RNAcentral/current_release/sequences/rnacentral_species_specific_ids.fasta.gz
6469
../input_prep/reprocess_rnac.pl id_mapping.tsv.gz rfam_annotations.tsv.gz # ~8 minutes
65-
gunzip -c rnacentral_species_specific_ids.fasta.gz | makeblastdb -in - -dbtype nucl -out rnacentral.fasta -title "RNACentral"
70+
gunzip -c rnacentral_species_specific_ids.fasta.gz | makeblastdb -in - -dbtype nucl -parse_seqids -out rnacentral.fasta -title "RNACentral"
6671
6772
# nt [151G]
6873
update_blastdb.pl --decompress nt
@@ -73,9 +78,15 @@ cd ..
7378
```
7479
conda activate RF2NA
7580
cd example
76-
../run_RF2NA.sh t000_ protein.fa R:RNA.fa
81+
# run Protein/RNA prediction
82+
../run_RF2NA.sh rna_pred rna_binding_protein.fa R:RNA.fa
83+
# run Protein/dsDNA prediction
84+
../run_RF2NA.sh dna_pred dna_binding_protein.fa D:DNA.fa
85+
7786
```
78-
The first argument to the script is the output folder; remaining arguments are fasta files for individual chains in the structure. Use the tags `P:xxx.fa` `R:xxx.fa` `D:xxx.fa` to specify protein, RNA, DNA respectively (default is protein). Each chain is a separate file (e.g., for double-stranded DNA, both strands need to be provided as separate fasta files). Outputs are written to the folder `t000_`.
87+
The first argument to the script is the output folder; remaining arguments are fasta files for individual chains in the structure. Use the tags `P:xxx.fa` `R:xxx.fa` `D:xxx.fa` `S:xxx.fa` and `PR:xxx.fa` to specify protein, RNA, dsDNA, ssDNA, and paired protein/RNA respectively (default is protein).
88+
89+
Each chain is a separate file; 'D' will automatically generate a complementary DNA strand to the input strand. Outputs are written to the folders `dna_pred` and `rna_pred`.
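The tag convention described above can be sketched in a few lines of Python. This is only an illustration: `parse_chain_arg` and `reverse_complement` are hypothetical helper names, not functions from `run_RF2NA.sh`, and the assumption that the 'D' partner strand is the plain reverse complement of the input strand is ours.

```python
# Hypothetical helpers; they illustrate the tag convention only and are not
# taken from run_RF2NA.sh.

# Tag -> molecule type, as listed in the README text.
CHAIN_TAGS = {
    "P": "protein",
    "R": "RNA",
    "D": "dsDNA",
    "S": "ssDNA",
    "PR": "paired protein/RNA",
}


def parse_chain_arg(arg):
    """Split 'TAG:file.fa' into (molecule type, path); untagged means protein."""
    tag, sep, path = arg.partition(":")
    if sep and tag in CHAIN_TAGS:
        return CHAIN_TAGS[tag], path
    return CHAIN_TAGS["P"], arg


def reverse_complement(seq):
    """Return the Watson-Crick partner strand read 5'->3' (assumed 'D' behaviour)."""
    pairs = {"A": "T", "T": "A", "G": "C", "C": "G"}
    return "".join(pairs[base] for base in reversed(seq.upper()))


if __name__ == "__main__":
    for arg in ["dna_binding_protein.fa", "D:DNA.fa", "R:RNA.fa"]:
        print(parse_chain_arg(arg))
    print(reverse_complement("GATTACA"))
```

With this reading, `D:DNA.fa` supplies one strand and the second strand is derived from it, whereas `S:xxx.fa` would leave the strand unpaired.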
7990

8091
## Expected outputs
81-
You will get a prediction with estimated per-residue LDDT in the B-factor column (model_00.pdb)
92+
You will get a prediction with estimated per-residue LDDT in the B-factor column (`models/model_00.pdb`)
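Since the per-residue LDDT estimate is stored in the B-factor column, it can be read back with the fixed-column PDB layout (B-factor in columns 61-66). The sketch below is a minimal example under stated assumptions: `read_bfactors` and `make_atom_line` are hypothetical names, the choice of the `CA` atom as the per-residue representative applies to protein chains only, and nothing here is taken from the RF2NA code.

```python
def make_atom_line(serial, name, resname, chain, resseq, bfactor):
    """Build a fixed-column PDB ATOM record (coordinates set to zero) for the demo."""
    return (
        "ATOM  %5d %-4s %3s %1s%4d    %8.3f%8.3f%8.3f%6.2f%6.2f"
        % (serial, name, resname, chain, resseq, 0.0, 0.0, 0.0, 1.0, bfactor)
    )


def read_bfactors(pdb_lines, atom_name="CA"):
    """Return [(chain, residue number, B-factor)] for atoms called atom_name."""
    records = []
    for line in pdb_lines:
        if not line.startswith("ATOM"):
            continue
        if line[12:16].strip() != atom_name:  # atom name: columns 13-16
            continue
        chain = line[21]                      # chain ID: column 22
        resnum = int(line[22:26])             # residue number: columns 23-26
        bfactor = float(line[60:66])          # B-factor: columns 61-66
        records.append((chain, resnum, bfactor))
    return records


if __name__ == "__main__":
    demo = [
        make_atom_line(1, "CA", "MET", "A", 1, 0.87),
        make_atom_line(2, "CB", "MET", "A", 1, 0.87),
    ]
    print(read_bfactors(demo))
```

For a real prediction, the lines would come from `open("models/model_00.pdb")`; nucleic acid chains have no `CA` atom, so a different `atom_name` (for example `P`) would be needed there.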

RF2na-linux.yml

Lines changed: 12 additions & 8 deletions
@@ -1,20 +1,24 @@
11
name: RF2NA
22
channels:
3+
- pytorch
4+
- nvidia
35
- defaults
6+
- conda-forge
47
dependencies:
5-
- python=3.8
6-
- pytorch::pytorch
8+
- python=3.10
9+
- pip
10+
- pytorch
711
- requests
8-
- conda-forge::psutil
9-
- conda-forge::cudatoolkit=11.3
10-
- conda-forge::tqdm
11-
- dglteam::dgl-cuda11.3
12+
- pytorch-cuda=11.7
13+
- dglteam/label/cu117::dgl
14+
- pyg::pyg
1215
- bioconda::mafft
1316
- bioconda::hhsuite
1417
- bioconda::blast
1518
- bioconda::hmmer>=3.3
1619
- bioconda::infernal
1720
- bioconda::cd-hit
1821
- bioconda::csblast
19-
- biocore::psipred=4.01
20-
- biocore::blast-legacy=2.2.26
22+
- pip:
23+
- psutil
24+
- tqdm

SE3Transformer/se3_transformer/model/layers/attention.py

Lines changed: 23 additions & 6 deletions
@@ -78,11 +78,10 @@ def forward(
7878

7979
with nvtx_range('attention dot product + softmax'):
8080
# Compute attention weights (softmax of inner product between key and query)
81-
with torch.cuda.amp.autocast(False):
82-
edge_weights = dgl.ops.e_dot_v(graph, key, query).squeeze(-1)
83-
edge_weights /= np.sqrt(self.key_fiber.num_features)
84-
edge_weights = edge_softmax(graph, edge_weights)
85-
edge_weights = edge_weights[..., None, None]
81+
edge_weights = dgl.ops.e_dot_v(graph, key, query).squeeze(-1)
82+
edge_weights /= np.sqrt(self.key_fiber.num_features)
83+
edge_weights = edge_softmax(graph, edge_weights)
84+
edge_weights = edge_weights[..., None, None]
8685

8786
with nvtx_range('weighted sum'):
8887
if isinstance(value, Tensor):
@@ -158,6 +157,11 @@ def forward(
158157
basis: Dict[str, Tensor]
159158
):
160159
with nvtx_range('AttentionBlockSE3'):
160+
#print ('AttentionBlockSE3 node_features',[torch.sum(torch.isnan(v)) for v in node_features.values()])
161+
#print ('AttentionBlockSE3 edge_features',[torch.sum(torch.isnan(v)) for v in edge_features.values()])
162+
#print ('AttentionBlockSE3 node_features',[torch.max(torch.abs(v)) for v in node_features.values()])
163+
#print ('AttentionBlockSE3 edge_features',[torch.max(torch.abs(v)) for v in edge_features.values()])
164+
161165
with nvtx_range('keys / values'):
162166
fused_key_value = self.to_key_value(node_features, edge_features, graph, basis)
163167
key, value = self._get_key_value_from_fused(fused_key_value)
@@ -166,9 +170,22 @@ def forward(
166170
with torch.cuda.amp.autocast(False):
167171
query = self.to_query(node_features)
168172

173+
#if (type(value) is dict):
174+
# print ('AttentionBlockSE3 value',[torch.sum(torch.isnan(v)) for v in value.values()])
175+
#else:
176+
# print ('AttentionBlockSE3 value',[torch.sum(torch.isnan(value))])
177+
#if (type(key) is dict):
178+
# print ('AttentionBlockSE3 key',[torch.sum(torch.isnan(k)) for k in key.values()])
179+
#else:
180+
# print ('AttentionBlockSE3 key',[torch.sum(torch.isnan(key))])
181+
#print ('AttentionBlockSE3 query',[torch.sum(torch.isnan(q)) for q in query.values()])
169182
z = self.attention(value, key, query, graph)
183+
#print ('AttentionBlockSE3 b',[torch.sum(torch.isnan(zi)) for zi in z.values()])
170184
z_concat = aggregate_residual(node_features, z, 'cat')
171-
return self.project(z_concat)
185+
#print ('AttentionBlockSE3 c',[torch.sum(torch.isnan(zi)) for zi in z_concat.values()] )
186+
output = self.project(z_concat)
187+
#print ('AttentionBlockSE3 d',[torch.sum(torch.isnan(o)) for o in output.values()] )
188+
return output
172189

173190
def _get_key_value_from_fused(self, fused_key_value):
174191
# Extract keys and queries features from fused features
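The attention-weight computation touched by the first hunk of this file (which now runs without the `torch.cuda.amp.autocast(False)` wrapper) is a scaled dot product followed by a softmax over edges. A dependency-free sketch of that idea follows; it assumes a single attention head, that each edge's key is paired with the query of its destination node, and that the softmax normalizes over the incoming edges of each destination node. The function name `edge_attention_weights` is hypothetical.

```python
import math
from collections import defaultdict


def edge_attention_weights(edges, key, query, num_features):
    """edges: list of (src, dst); key: one vector per edge; query: one vector per node."""
    # Scaled dot product between each edge's key and its destination node's query.
    logits = [
        sum(k * q for k, q in zip(key[e], query[dst])) / math.sqrt(num_features)
        for e, (_, dst) in enumerate(edges)
    ]

    # Group edge indices by destination node.
    by_dst = defaultdict(list)
    for e, (_, dst) in enumerate(edges):
        by_dst[dst].append(e)

    # Numerically stable softmax within each group of incoming edges.
    weights = [0.0] * len(edges)
    for edge_ids in by_dst.values():
        peak = max(logits[e] for e in edge_ids)
        exps = {e: math.exp(logits[e] - peak) for e in edge_ids}
        total = sum(exps.values())
        for e in edge_ids:
            weights[e] = exps[e] / total
    return weights


if __name__ == "__main__":
    edges = [(0, 2), (1, 2), (2, 0)]
    key = [[1.0, 0.0], [1.0, 0.0], [0.0, 1.0]]
    query = [[1.0, 1.0], [0.0, 0.0], [2.0, 0.0]]
    print(edge_attention_weights(edges, key, query, num_features=2))
```

Subtracting the per-group maximum before exponentiating is a standard guard against overflow; it is the kind of numerical detail that matters more once the computation is allowed to run in reduced precision.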

SE3Transformer/se3_transformer/model/layers/convolution.py

Lines changed: 10 additions & 0 deletions
@@ -320,6 +320,9 @@ def forward(
320320
out = {}
321321
in_features = []
322322

323+
#print ('ConvSE3 node_feats',[torch.sum(torch.isnan(v)) for v in node_feats.values()])
324+
#print ('ConvSE3 edge_feats',[torch.sum(torch.isnan(v)) for v in edge_feats.values()])
325+
323326
# Fetch all input features from edge and node features
324327
for degree_in in self.fiber_in.degrees:
325328
src_node_features = node_feats[str(degree_in)][src]
@@ -358,6 +361,11 @@ def forward(
358361
basis.get(dict_key, None))
359362
out[str(degree_out)] = out_feature
360363

364+
#if (type(out) is dict):
365+
# print ('ConvSE3 out',[torch.sum(torch.isnan(v)) for v in out.values()])
366+
#else:
367+
# print ('ConvSE3 out',[torch.sum(torch.isnan(out))])
368+
361369
for degree_out in self.fiber_out.degrees:
362370
if self.self_interaction and str(degree_out) in self.to_kernel_self:
363371
with nvtx_range(f'self interaction'):
@@ -369,7 +377,9 @@ def forward(
369377
if self.sum_over_edge:
370378
with nvtx_range(f'pooling'):
371379
if isinstance(out, dict):
380+
#print ('ConvSE3 pre-pool',degree_out,torch.sum(torch.isnan(out[str(degree_out)])), out[str(degree_out)].dtype )
372381
out[str(degree_out)] = dgl.ops.copy_e_sum(graph, out[str(degree_out)])
382+
#print ('ConvSE3 post-pool',degree_out,torch.sum(torch.isnan(out[str(degree_out)])), out[str(degree_out)].dtype )
373383
else:
374384
out = dgl.ops.copy_e_sum(graph, out)
375385
else:

SE3Transformer/se3_transformer/model/layers/norm.py

Lines changed: 2 additions & 0 deletions
@@ -61,6 +61,7 @@ def __init__(self, fiber: Fiber, nonlinearity: nn.Module = nn.ReLU()):
6161
def forward(self, features: Dict[str, Tensor], *args, **kwargs) -> Dict[str, Tensor]:
6262
with nvtx_range('NormSE3'):
6363
output = {}
64+
#print ('NormSE3 features',[torch.sum(torch.isnan(v)) for v in features.values()])
6465
if hasattr(self, 'group_norm'):
6566
# Compute per-degree norms of features
6667
norms = [features[str(d)].norm(dim=-1, keepdim=True).clamp(min=self.NORM_CLAMP)
@@ -79,5 +80,6 @@ def forward(self, features: Dict[str, Tensor], *args, **kwargs) -> Dict[str, Ten
7980
norm = feat.norm(dim=-1, keepdim=True).clamp(min=self.NORM_CLAMP)
8081
new_norm = self.nonlinearity(self.layer_norms[degree](norm.squeeze(-1)).unsqueeze(-1))
8182
output[degree] = new_norm * feat / norm
83+
#print ('NormSE3 output',[torch.sum(torch.isnan(v)) for v in output.values()])
8284

8385
return output

SE3Transformer/se3_transformer/model/transformer.py

Lines changed: 11 additions & 2 deletions
@@ -76,7 +76,7 @@ def __init__(self,
7676
use_layer_norm: bool = True,
7777
tensor_cores: bool = False,
7878
low_memory: bool = False,
79-
populate_edge: bool = True,
79+
populate_edge: Optional[Literal['lin', 'arcsin', 'log', 'zero']] = 'lin',
8080
sum_over_edge: bool = True,
8181
**kwargs):
8282
"""
@@ -168,8 +168,17 @@ def forward(self, graph: DGLGraph, node_feats: Dict[str, Tensor],
168168
basis = update_basis_with_fused(basis, self.max_degree, use_pad_trick=self.tensor_cores and not self.low_memory,
169169
fully_fused=self.tensor_cores and not self.low_memory)
170170

171-
if self.populate_edge:
171+
if self.populate_edge=='lin':
172172
edge_feats = get_populated_edge_features(graph.edata['rel_pos'], edge_feats)
173+
elif self.populate_edge=='arcsin':
174+
r = graph.edata['rel_pos'].norm(dim=-1, keepdim=True)
175+
r = torch.maximum(r, torch.zeros_like(r) + 4.0) - 4.0
176+
r = torch.arcsinh(r)/3.0
177+
edge_feats['0'] = torch.cat([edge_feats['0'], r[..., None]], dim=1)
178+
elif self.populate_edge=='log':
179+
# fd - replace with log(1+x)
180+
r = torch.log( 1 + graph.edata['rel_pos'].norm(dim=-1, keepdim=True) )
181+
edge_feats['0'] = torch.cat([edge_feats['0'], r[..., None]], dim=1)
173182
else:
174183
edge_feats['0'] = torch.cat((edge_feats['0'], torch.zeros_like(edge_feats['0'][:,:1,:])), dim=1)
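The new `populate_edge` modes differ only in how the inter-node distance is transformed before being appended to the degree-0 edge features. The scalar sketch below mirrors the branches in this hunk using only the standard library. The function name `radial_edge_feature` is hypothetical, and the `'lin'` branch assumes that `get_populated_edge_features` appends the raw distance, which this diff does not show.

```python
import math


def radial_edge_feature(distance, mode):
    """Scalar appended to the degree-0 edge features for one edge (illustrative)."""
    if mode == "lin":
        # Assumption: the existing helper appends the untransformed distance.
        return distance
    if mode == "arcsin":
        # Distances up to 4.0 map to 0; beyond that, arcsinh compresses growth.
        return math.asinh(max(distance, 4.0) - 4.0) / 3.0
    if mode == "log":
        # log(1 + r), as in the 'log' branch.
        return math.log1p(distance)
    # Any other value (e.g. 'zero' or None) falls through to a zero feature.
    return 0.0


if __name__ == "__main__":
    for mode in ("lin", "arcsin", "log", "zero"):
        print(mode, [round(radial_edge_feature(r, mode), 3) for r in (2.0, 8.0, 32.0)])
```

Note that the mode is named `'arcsin'` but the transform applied is the inverse hyperbolic sine (`torch.arcsinh`), which is defined for all non-negative distances.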
175184

example/dna_binding_protein.fa

Lines changed: 2 additions & 0 deletions
@@ -0,0 +1,2 @@
1+
> ANTENNAPEDIA HOMEODOMAIN|Drosophila melanogaster (7227)
2+
MERKRGRQTYTRYQTLELEKEFHFNRYLTRRRRIEIAHALSLTERQIKIWFQNRRMKWKKEN
