We read every piece of feedback and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 1812fa3 · commit 73a5fff · Copy full SHA for 73a5fff
primus/backends/torchtitan/models/llama3/model/model.py
@@ -5,12 +5,13 @@
5
###############################################################################
6
7
import torch
8
+from torch.nn.attention.flex_attention import BlockMask
9
from torchtitan.models.llama3.model.model import Attention as TTAttention
10
from torchtitan.models.llama3.model.model import apply_rotary_emb
-from torch.nn.attention.flex_attention import BlockMask
11
12
# Type alias for flex-attention masks: either a single BlockMask or a
# str-keyed mapping of BlockMasks (presumably one mask per named attention
# variant — TODO confirm key semantics against callers).
AttentionMasksType = dict[str, BlockMask] | BlockMask
13
14
+
15
class Attention(TTAttention):
16
def forward(
17
self,
0 commit comments