Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 40 additions & 25 deletions openfold/model/template.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
# limitations under the License.
from functools import partial
import math
import sys
from typing import Optional, List

import torch
Expand Down Expand Up @@ -519,6 +518,9 @@ def embed_templates_offload(
offloads the large template pair tensor to CPU. Slower but more frugal
with GPU memory than the original. Useful for long-sequence inference.
"""
template_embedder = model.template_embedder
template_config = template_embedder.config

# Embed the templates one at a time (with a poor man's vmap)
pair_embeds_cpu = []
n = z.shape[-2]
Expand All @@ -533,15 +535,15 @@ def embed_templates_offload(
# [*, N, N, C_t]
t = build_template_pair_feat(
single_template_feats,
use_unit_vector=model.config.template.use_unit_vector,
inf=model.config.template.inf,
eps=model.config.template.eps,
**model.config.template.distogram,
use_unit_vector=template_config.use_unit_vector,
inf=template_config.inf,
eps=template_config.eps,
**template_config.distogram,
).to(z.dtype)
t = model.template_pair_embedder(t)
t = template_embedder.template_pair_embedder(t)

# [*, 1, N, N, C_z]
t = model.template_pair_stack(
t = template_embedder.template_pair_stack(
t.unsqueeze(templ_dim),
pair_mask.unsqueeze(-3).to(dtype=z.dtype),
chunk_size=model.globals.chunk_size,
Expand All @@ -553,8 +555,6 @@ def embed_templates_offload(
_mask_trans=model.config._mask_trans,
)

assert (sys.getrefcount(t) == 2)

pair_embeds_cpu.append(t.cpu())

del t
Expand All @@ -568,7 +568,7 @@ def embed_templates_offload(
]
pair_chunk = torch.cat(pair_chunks, dim=templ_dim).to(device=z.device)
z_chunk = z[..., i: i + template_chunk_size, :, :]
att_chunk = model.template_pointwise_att(
att_chunk = template_embedder.template_pointwise_att(
pair_chunk,
z_chunk,
template_mask=batch["template_mask"].to(dtype=z.dtype),
Expand All @@ -579,19 +579,25 @@ def embed_templates_offload(

del pair_chunks

t_mask = torch.sum(batch["template_mask"], dim=-1) > 0
t_mask = t_mask.reshape(
*t_mask.shape,
*([1] * (len(t.shape) - len(t_mask.shape))),
)

if inplace_safe:
t = t * (torch.sum(batch["template_mask"], dim=-1) > 0)
t *= t_mask
else:
t *= (torch.sum(batch["template_mask"], dim=-1) > 0)
t = t * t_mask

ret = {}
if model.config.template.embed_angles:
if template_config.embed_angles:
template_angle_feat = build_template_angle_feat(
batch,
)

# [*, N, C_m]
a = model.template_single_embedder(template_angle_feat)
a = template_embedder.template_single_embedder(template_angle_feat)

ret["template_single_embedding"] = a

Expand Down Expand Up @@ -636,6 +642,9 @@ def embed_templates_average(
embedding, while its low memory footprint allows the number of templates
to scale almost indefinitely.
"""
template_embedder = model.template_embedder
template_config = template_embedder.config

# Embed the templates one at a time (with a poor man's vmap)
n = z.shape[-2]
n_templ = batch["template_aatype"].shape[templ_dim]
Expand All @@ -654,15 +663,15 @@ def slice_template_tensor(t):
# [*, N, N, C_t]
t = build_template_pair_feat(
template_feats,
use_unit_vector=model.config.template.use_unit_vector,
inf=model.config.template.inf,
eps=model.config.template.eps,
**model.config.template.distogram,
use_unit_vector=template_config.use_unit_vector,
inf=template_config.inf,
eps=template_config.eps,
**template_config.distogram,
).to(z.dtype)

# [*, S_t, N, N, C_z]
t = model.template_pair_embedder(t)
t = model.template_pair_stack(
t = template_embedder.template_pair_embedder(t)
t = template_embedder.template_pair_stack(
t,
pair_mask.unsqueeze(-3).to(dtype=z.dtype),
chunk_size=model.globals.chunk_size,
Expand All @@ -674,7 +683,7 @@ def slice_template_tensor(t):
_mask_trans=model.config._mask_trans,
)

t = model.template_pointwise_att(
t = template_embedder.template_pointwise_att(
t,
z,
template_mask=template_feats["template_mask"].to(dtype=z.dtype),
Expand All @@ -694,19 +703,25 @@ def slice_template_tensor(t):

del t

t_mask = torch.sum(batch["template_mask"], dim=-1) > 0
t_mask = t_mask.reshape(
*t_mask.shape,
*([1] * (len(out_tensor.shape) - len(t_mask.shape))),
)

if inplace_safe:
out_tensor *= (torch.sum(batch["template_mask"], dim=-1) > 0)
out_tensor *= t_mask
else:
out_tensor = out_tensor * (torch.sum(batch["template_mask"], dim=-1) > 0)
out_tensor = out_tensor * t_mask

ret = {}
if model.config.template.embed_angles:
if template_config.embed_angles:
template_angle_feat = build_template_angle_feat(
batch,
)

# [*, N, C_m]
a = model.template_single_embedder(template_angle_feat)
a = template_embedder.template_single_embedder(template_angle_feat)

ret["template_single_embedding"] = a

Expand Down
19 changes: 10 additions & 9 deletions openfold/utils/trace_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,9 @@ def trace_model_(model, sample_input):
model.extra_msa_stack.blocks = extra_msa_blocks[:1]

if(model.template_config.enabled):
template_pair_stack_blocks = model.template_pair_stack.blocks
model.template_pair_stack.blocks = template_pair_stack_blocks[:1]
template_pair_stack = model.template_embedder.template_pair_stack
template_pair_stack_blocks = template_pair_stack.blocks
template_pair_stack.blocks = template_pair_stack_blocks[:1]

single_recycling_iter_input = tensor_tree_map(
lambda t: t[..., :1], sample_input,
Expand All @@ -109,7 +110,7 @@ def trace_model_(model, sample_input):
del evoformer_blocks, extra_msa_blocks

if(model.template_config.enabled):
model.template_pair_stack.blocks = template_pair_stack_blocks
template_pair_stack.blocks = template_pair_stack_blocks
del template_pair_stack_blocks

def get_tuned_chunk_size(module):
Expand All @@ -132,9 +133,9 @@ def get_tuned_chunk_size(module):

if(model.template_config.enabled):
template_pair_stack_chunk_size = model.globals.chunk_size
if(model.template_pair_stack.chunk_size_tuner is not None):
if(template_pair_stack.chunk_size_tuner is not None):
template_pair_stack_chunk_size = get_tuned_chunk_size(
model.template_pair_stack
template_pair_stack
)

def trace_block(block, block_inputs):
Expand Down Expand Up @@ -405,7 +406,7 @@ def verify_arg_order(fn, arg_list):
# )),
# ]
# verify_arg_order(
# model.template_pair_stack.blocks[0].forward,
# model.template_embedder.template_pair_stack.blocks[0].forward,
# template_pair_stack_arg_tuples
# )
# template_pair_stack_args = [
Expand All @@ -414,12 +415,12 @@ def verify_arg_order(fn, arg_list):
#
# with torch.no_grad():
# traced_template_pair_stack = []
# for b in model.template_pair_stack.blocks:
# for b in model.template_embedder.template_pair_stack.blocks:
# traced_block = trace_block(b, template_pair_stack_args)
# traced_template_pair_stack.append(traced_block)
#
# del model.template_pair_stack.blocks
# model.template_pair_stack.blocks = traced_template_pair_stack
# del model.template_embedder.template_pair_stack.blocks
# model.template_embedder.template_pair_stack.blocks = traced_template_pair_stack

# We need to do another dry run after tracing to allow the model to reach
# top speeds. Why, I don't know.
Expand Down
Loading