Commit 85d7050

Merge pull request cvg#46 from fabio-sim/fabio/feat/cpu-optim
feat: CPU compatibility (cvg#46)
2 parents b1007b3 + bcb594c

File tree: 2 files changed (+109, −3 lines)

lightglue_onnx/optim/fusion_attention_lightglue.py

Lines changed: 93 additions & 2 deletions
```diff
@@ -211,6 +211,7 @@ def create_self_attention_node(
     ) -> NodeProto:
         # all_inputs are (B, N, S, H)
         if self.enable_packed_qkv:
+            # Implement Stack via Unsqueeze+Concat
            unsqueeze_q_node_name = self.model.create_node_name("Unsqueeze")
            unsqueeze_k_node_name = self.model.create_node_name("Unsqueeze")
            unsqueeze_v_node_name = self.model.create_node_name("Unsqueeze")
```
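The new comment is worth unpacking: standard ONNX has no Stack operator, so stacking Q/K/V along a new axis is expressed as an Unsqueeze on each tensor followed by a Concat. A minimal NumPy sketch of the equivalence (shapes and the stacking axis here are illustrative, not taken from the fusion code):

```python
import numpy as np

# Illustrative shapes: B=1 batch, N=4 heads, S=10 tokens, H=16 head dim
q, k, v = (np.random.rand(1, 4, 10, 16).astype(np.float32) for _ in range(3))

# Stack emulated as Unsqueeze (np.expand_dims) + Concat (np.concatenate)
stacked = np.concatenate([np.expand_dims(t, axis=0) for t in (q, k, v)], axis=0)

assert np.array_equal(stacked, np.stack([q, k, v], axis=0))
print(stacked.shape)  # (3, 1, 4, 10, 16)
```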
```diff
@@ -297,8 +298,98 @@ def create_self_attention_node(
             )
 
             return attention_node
-        else: # Not packed
-            raise NotImplementedError("Unpacked QKV self-attention not implemented.")
+        else: # Not packed. (CPU-compatible)
+            # Transpose nodes: (B, N, S, H) -> (B, S, N, H)
+            transpose_q_node_name = self.model.create_node_name("Transpose")
+            transpose_k_node_name = self.model.create_node_name("Transpose")
+            transpose_v_node_name = self.model.create_node_name("Transpose")
+            transpose_q_node = helper.make_node(
+                "Transpose",
+                inputs=[matmul_q.output[0]],
+                outputs=[transpose_q_node_name + "_out"],
+                name=transpose_q_node_name,
+                perm=[0, 2, 1, 3],
+            )
+            self.node_name_to_graph_name[transpose_q_node.name] = self.this_graph_name
+            transpose_k_node = helper.make_node(
+                "Transpose",
+                inputs=[matmul_k.output[0]],
+                outputs=[transpose_k_node_name + "_out"],
+                name=transpose_k_node_name,
+                perm=[0, 2, 1, 3],
+            )
+            self.node_name_to_graph_name[transpose_k_node.name] = self.this_graph_name
+            transpose_v_node = helper.make_node(
+                "Transpose",
+                inputs=[matmul_v.output[0]],
+                outputs=[transpose_v_node_name + "_out"],
+                name=transpose_v_node_name,
+                perm=[0, 2, 1, 3],
+            )
+            self.node_name_to_graph_name[transpose_v_node.name] = self.this_graph_name
+
+            # Reshape nodes: (B, S, N, H) -> (B, S, NH)
+            reshape_q_node_name = self.model.create_node_name("Reshape")
+            reshape_k_node_name = self.model.create_node_name("Reshape")
+            reshape_v_node_name = self.model.create_node_name("Reshape")
+            for n in (reshape_q_node_name, reshape_k_node_name, reshape_v_node_name):
+                self.add_initializer(
+                    name=n + "_shape",
+                    data_type=TensorProto.INT64,
+                    dims=[3],
+                    vals=[0, 0, hidden_size],
+                    raw=False,
+                )
+            reshape_q_node = helper.make_node(
+                "Reshape",
+                inputs=[transpose_q_node_name + "_out", reshape_q_node_name + "_shape"],
+                outputs=[reshape_q_node_name + "_out"],
+                name=reshape_q_node_name,
+            )
+            self.node_name_to_graph_name[reshape_q_node.name] = self.this_graph_name
+            reshape_k_node = helper.make_node(
+                "Reshape",
+                inputs=[transpose_k_node_name + "_out", reshape_k_node_name + "_shape"],
+                outputs=[reshape_k_node_name + "_out"],
+                name=reshape_k_node_name,
+            )
+            self.node_name_to_graph_name[reshape_k_node.name] = self.this_graph_name
+            reshape_v_node = helper.make_node(
+                "Reshape",
+                inputs=[transpose_v_node_name + "_out", reshape_v_node_name + "_shape"],
+                outputs=[reshape_v_node_name + "_out"],
+                name=reshape_v_node_name,
+            )
+            self.node_name_to_graph_name[reshape_v_node.name] = self.this_graph_name
+
+            self.nodes_to_add.extend(
+                [
+                    transpose_q_node,
+                    transpose_k_node,
+                    transpose_v_node,
+                    reshape_q_node,
+                    reshape_k_node,
+                    reshape_v_node,
+                ]
+            )
+
+            attention_inputs = [
+                reshape_q_node_name + "_out",
+                reshape_k_node_name + "_out",
+                reshape_v_node_name + "_out",
+            ]
+
+            attention_node_name = self.model.create_node_name("MultiHeadAttention")
+            attention_node = helper.make_node(
+                "MultiHeadAttention",
+                inputs=attention_inputs,
+                outputs=[output],
+                name=attention_node_name,
+                domain="com.microsoft",
+                num_heads=num_heads,
+            )
+
+            return attention_node
 
     def create_cross_attention_node(
         self,
```
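In plain terms, the unpacked branch is only a tensor-layout reshuffle: each projected Q/K/V tensor arrives as (B, N, S, H) and must reach the (B, S, N*H) layout that the com.microsoft MultiHeadAttention contrib op takes for separate query/key/value inputs. A NumPy sketch of what each Transpose + Reshape pair computes (shape values are illustrative):

```python
import numpy as np

B, N, S, H = 1, 4, 10, 16      # illustrative: batch, heads, sequence, head dim
hidden_size = N * H
q = np.random.rand(B, N, S, H).astype(np.float32)

# Transpose with perm=[0, 2, 1, 3]: (B, N, S, H) -> (B, S, N, H)
q_t = q.transpose(0, 2, 1, 3)

# Reshape with shape [0, 0, hidden_size]: in ONNX Reshape, a 0 entry means
# "copy this dim from the input", so (B, S, N, H) -> (B, S, N*H)
q_mha = q_t.reshape(B, S, hidden_size)
assert q_mha.shape == (1, 10, 64)
```

Presumably the point of this fallback is that MultiHeadAttention with separate Q/K/V inputs has a CPU kernel in ONNX Runtime, whereas the packed-QKV subgraph built in the `if` branch targets GPU execution.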

optimize.py

Lines changed: 16 additions & 1 deletion
```diff
@@ -5,6 +5,7 @@
 
 from onnx import load_model, save_model
 from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference
+from onnxruntime.transformers.fusion_options import FusionOptions
 
 from lightglue_onnx.optim.onnx_model_lightglue import LightGlueOnnxModel
 
```
```diff
@@ -19,6 +20,9 @@ def parse_args() -> argparse.Namespace:
     parser.add_argument(
         "-o", "--output", type=str, help="Path to output fused LightGlue ONNX model."
     )
+    parser.add_argument(
+        "--cpu", action="store_true", help="Whether to optimize for CPU."
+    )
 
     return parser.parse_args()
 
```
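For readers less familiar with argparse, `action="store_true"` turns `--cpu` into a boolean switch that defaults to False, so existing invocations are unaffected. A self-contained sketch of that behavior (this parser is standalone, not the script's):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--cpu", action="store_true", help="Whether to optimize for CPU.")

assert parser.parse_args([]).cpu is False        # flag omitted -> False
assert parser.parse_args(["--cpu"]).cpu is True  # flag present -> True
```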
```diff
@@ -28,12 +32,19 @@ def parse_args() -> argparse.Namespace:
 lightglue = load_model(args.input)
 optimizer = LightGlueOnnxModel(lightglue, NUM_HEADS, HIDDEN_SIZE)
 
-optimizer.optimize()
+options = None
+if args.cpu:
+    options = FusionOptions("unet")
+    options.enable_packed_qkv = False
+
+optimizer.optimize(options)
 optimizer.get_fused_operator_statistics()
 
 output_path = args.output
 if output_path is None:
     output_path = args.input.replace(".onnx", "_fused.onnx")
+if args.cpu:
+    output_path = output_path.replace(".onnx", "_cpu.onnx")
 
 optimizer.save_model_to_file(output_path)
 
```
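Two things happen here: `FusionOptions("unet")` borrows the UNet fusion preset from onnxruntime.transformers, and clearing `enable_packed_qkv` is what routes `create_self_attention_node` into the new unpacked branch shown above. The default output name also records the CPU variant; traced concretely (paths are illustrative):

```python
# Illustrative walk-through of the default naming above (no -o/--output given)
input_path = "weights/lightglue.onnx"                     # hypothetical input
output_path = input_path.replace(".onnx", "_fused.onnx")  # weights/lightglue_fused.onnx
output_path = output_path.replace(".onnx", "_cpu.onnx")   # with --cpu: weights/lightglue_fused_cpu.onnx
```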
```diff
@@ -42,6 +53,10 @@ def parse_args() -> argparse.Namespace:
     output_path,
 )
 
+if args.cpu:
+    print("CPU does not support fp16. Skipping..")
+    exit()
+
 optimizer.convert_float_to_float16(
     keep_io_types=True,
 )
```
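With `--cpu` set, the script exits before the float16 conversion, so the saved model stays in float32. A quick sanity check (the path is illustrative, following the naming sketch above) that the result loads under the CPU execution provider:

```python
import onnxruntime as ort

# Illustrative path from the default --cpu naming
session = ort.InferenceSession(
    "weights/lightglue_fused_cpu.onnx",
    providers=["CPUExecutionProvider"],
)
print([inp.name for inp in session.get_inputs()])
```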
