
Commit 7f9c58a

atiorh and davidfindlay1 committed
WWDC23 update
Co-authored-by: davidfindlay1 <[email protected]>
1 parent e3875a5 commit 7f9c58a

16 files changed: +790 -278 lines changed

README.md

Lines changed: 153 additions & 128 deletions
Large diffs are not rendered by default.
assets/float16_gpu_readmereel.png

Binary image: 2.16 MB → 2.2 MB
Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
- __version__ = "0.1.0"
+ __version__ = "1.0.0"
Lines changed: 146 additions & 0 deletions
@@ -0,0 +1,146 @@
import logging

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

import torch


def split_einsum(q, k, v, mask, heads, dim_head):
    """ Attention Implementation backing AttentionImplementations.SPLIT_EINSUM

    - Implements https://machinelearning.apple.com/research/neural-engine-transformers
    - Recommended for ANE
    - Marginally slower on GPU
    """
    mh_q = [
        q[:, head_idx * dim_head:(head_idx + 1) *
          dim_head, :, :] for head_idx in range(heads)
    ]  # (bs, dim_head, 1, max_seq_length) * heads

    k = k.transpose(1, 3)
    mh_k = [
        k[:, :, :,
          head_idx * dim_head:(head_idx + 1) * dim_head]
        for head_idx in range(heads)
    ]  # (bs, max_seq_length, 1, dim_head) * heads

    mh_v = [
        v[:, head_idx * dim_head:(head_idx + 1) *
          dim_head, :, :] for head_idx in range(heads)
    ]  # (bs, dim_head, 1, max_seq_length) * heads

    attn_weights = [
        torch.einsum("bchq,bkhc->bkhq", [qi, ki]) * (dim_head**-0.5)
        for qi, ki in zip(mh_q, mh_k)
    ]  # (bs, max_seq_length, 1, max_seq_length) * heads

    if mask is not None:
        for head_idx in range(heads):
            attn_weights[head_idx] = attn_weights[head_idx] + mask

    attn_weights = [
        aw.softmax(dim=1) for aw in attn_weights
    ]  # (bs, max_seq_length, 1, max_seq_length) * heads
    attn = [
        torch.einsum("bkhq,bchk->bchq", wi, vi)
        for wi, vi in zip(attn_weights, mh_v)
    ]  # (bs, dim_head, 1, max_seq_length) * heads

    attn = torch.cat(attn, dim=1)  # (bs, dim, 1, max_seq_length)
    return attn


CHUNK_SIZE = 512


def split_einsum_v2(q, k, v, mask, heads, dim_head):
    """ Attention Implementation backing AttentionImplementations.SPLIT_EINSUM_V2

    - Implements https://machinelearning.apple.com/research/neural-engine-transformers
    - Recommended for ANE
    - Marginally slower on GPU
    - Chunks the query sequence to avoid large intermediate tensors and improves ANE performance
    """
    query_seq_length = q.size(3)
    num_chunks = query_seq_length // CHUNK_SIZE

    if num_chunks == 0:
        logger.info(
            "AttentionImplementations.SPLIT_EINSUM_V2: query sequence too short to chunk "
            f"({query_seq_length}<{CHUNK_SIZE}), fall back to AttentionImplementations.SPLIT_EINSUM (safe to ignore)")
        return split_einsum(q, k, v, mask, heads, dim_head)

    logger.info(
        "AttentionImplementations.SPLIT_EINSUM_V2: Splitting query sequence length of "
        f"{query_seq_length} into {num_chunks} chunks")

    mh_q = [
        q[:, head_idx * dim_head:(head_idx + 1) *
          dim_head, :, :] for head_idx in range(heads)
    ]  # (bs, dim_head, 1, max_seq_length) * heads

    # Chunk the query sequence for each head
    mh_q_chunked = [
        [h_q[..., chunk_idx * CHUNK_SIZE:(chunk_idx + 1) * CHUNK_SIZE]
         for chunk_idx in range(num_chunks)]
        for h_q in mh_q
    ]  # ((bs, dim_head, 1, QUERY_SEQ_CHUNK_SIZE) * num_chunks) * heads

    k = k.transpose(1, 3)
    mh_k = [
        k[:, :, :,
          head_idx * dim_head:(head_idx + 1) * dim_head]
        for head_idx in range(heads)
    ]  # (bs, max_seq_length, 1, dim_head) * heads

    mh_v = [
        v[:, head_idx * dim_head:(head_idx + 1) *
          dim_head, :, :] for head_idx in range(heads)
    ]  # (bs, dim_head, 1, max_seq_length) * heads

    attn_weights = [
        [
            torch.einsum("bchq,bkhc->bkhq", [qi_chunk, ki]) * (dim_head**-0.5)
            for qi_chunk in h_q_chunked
        ] for h_q_chunked, ki in zip(mh_q_chunked, mh_k)
    ]  # ((bs, max_seq_length, 1, chunk_size) * num_chunks) * heads

    attn_weights = [
        [aw_chunk.softmax(dim=1) for aw_chunk in aw_chunked]
        for aw_chunked in attn_weights
    ]  # ((bs, max_seq_length, 1, chunk_size) * num_chunks) * heads

    attn = [
        [
            torch.einsum("bkhq,bchk->bchq", wi_chunk, vi)
            for wi_chunk in wi_chunked
        ] for wi_chunked, vi in zip(attn_weights, mh_v)
    ]  # ((bs, dim_head, 1, chunk_size) * num_chunks) * heads

    attn = torch.cat([
        torch.cat(attn_chunked, dim=3) for attn_chunked in attn
    ], dim=1)  # (bs, dim, 1, max_seq_length)

    return attn


def original(q, k, v, mask, heads, dim_head):
    """ Attention Implementation backing AttentionImplementations.ORIGINAL

    - Not recommended for ANE
    - Recommended for GPU
    """
    bs = q.size(0)
    mh_q = q.view(bs, heads, dim_head, -1)
    mh_k = k.view(bs, heads, dim_head, -1)
    mh_v = v.view(bs, heads, dim_head, -1)

    attn_weights = torch.einsum("bhcq,bhck->bhqk", [mh_q, mh_k])
    attn_weights.mul_(dim_head**-0.5)

    if mask is not None:
        attn_weights = attn_weights + mask

    attn_weights = attn_weights.softmax(dim=3)

    attn = torch.einsum("bhqk,bhck->bhcq", [attn_weights, mh_v])
    attn = attn.contiguous().view(bs, heads * dim_head, 1, -1)
    return attn
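
For unmasked inputs all three functions compute the same multi-head attention and should agree numerically. Below is a minimal sanity-check sketch, not part of the commit: it assumes the three functions above are available in the same module, uses illustrative sizes (heads=8, dim_head=64, a query length of 1024 so that SPLIT_EINSUM_V2 actually chunks), and passes mask=None since the chunked variant shown here does not apply a mask.

import torch

def _check_attention_equivalence():
    # Illustrative sizes; query length is a multiple of CHUNK_SIZE so no tail is dropped
    heads, dim_head, seq_len = 8, 64, 1024
    dim = heads * dim_head
    q = torch.randn(1, dim, 1, seq_len)
    k = torch.randn(1, dim, 1, seq_len)
    v = torch.randn(1, dim, 1, seq_len)

    ref = original(q, k, v, None, heads, dim_head)
    out_split = split_einsum(q, k, v, None, heads, dim_head)
    out_split_v2 = split_einsum_v2(q, k, v, None, heads, dim_head)

    # All three paths should match up to float32 accumulation error
    assert torch.allclose(ref, out_split, atol=1e-4)
    assert torch.allclose(ref, out_split_v2, atol=1e-4)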
Lines changed: 60 additions & 0 deletions
@@ -0,0 +1,60 @@
from python_coreml_stable_diffusion.torch2coreml import _compile_coreml_model

import argparse
import coremltools as ct
import numpy as np
import os
import torch
import torch.nn as nn

# TODO: Read these values off of the NLContextualEmbedding API to enforce dimensions and track API versioning
MAX_SEQUENCE_LENGTH = 256
EMBED_DIM = 512
BATCH_SIZE = 1


def main(args):
    # Layer that was trained to map NLContextualEmbedding to your text_encoder.hidden_size dimensionality
    text_encoder_projection = torch.jit.load(args.input_path)

    # Prepare random inputs for tracing the network before conversion
    random_input = torch.randn(BATCH_SIZE, MAX_SEQUENCE_LENGTH, EMBED_DIM)

    # Create a class to bake in the reshape operations required to fit the existing model interface
    class TextEncoderProjection(nn.Module):
        def __init__(self, proj):
            super().__init__()
            self.proj = proj

        def forward(self, x):
            return self.proj(x).transpose(1, 2).unsqueeze(2)  # BSC -> BC1S

    # Trace the torch model
    text_encoder_projection = torch.jit.trace(TextEncoderProjection(text_encoder_projection), (random_input,))

    # Convert the model to Core ML
    mlpackage_path = os.path.join(args.output_dir, "MultilingualTextEncoderProjection.mlpackage")
    ct.convert(
        text_encoder_projection,
        inputs=[ct.TensorType('nlcontextualembeddings_output', shape=(1, MAX_SEQUENCE_LENGTH, EMBED_DIM), dtype=np.float32)],
        outputs=[ct.TensorType('encoder_hidden_states', dtype=np.float32)],
        minimum_deployment_target=ct.target.macOS14,  # NLContextualEmbedding minimum availability build
        convert_to='mlprogram',
    ).save(mlpackage_path)

    # Compile the model and save it under the specified directory
    _compile_coreml_model(mlpackage_path, args.output_dir, final_name="MultilingualTextEncoderProjection")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--input-path",
        help="Path to the torchscript file that contains the projection layer",
    )
    parser.add_argument(
        "--output-dir",
        help="Output directory in which the Core ML model should be saved",
    )
    args = parser.parse_args()

    main(args)
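
The script consumes a TorchScript file holding the projection layer that maps NLContextualEmbedding outputs (EMBED_DIM = 512) to the text encoder's hidden size. A hypothetical sketch of exporting such a layer is shown below; the hidden size of 768 and the output filename are placeholders rather than values from this commit, and the trained weights would come from your own training run.

import torch
import torch.nn as nn

EMBED_DIM = 512              # NLContextualEmbedding output dimensionality (matches the constant above)
TEXT_ENCODER_HIDDEN = 768    # placeholder: replace with your text_encoder.hidden_size

# Projection layer; in practice, load the weights you trained for this mapping
projection = nn.Linear(EMBED_DIM, TEXT_ENCODER_HIDDEN)
# projection.load_state_dict(torch.load("projection_weights.pt"))  # hypothetical checkpoint

scripted = torch.jit.script(projection)
scripted.save("text_encoder_projection.pt")  # pass this path to --input-path above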

0 commit comments

Comments
 (0)