Skip to content

Commit fd227bb

Browse files
committed
update tests
1 parent ebade48 commit fd227bb

File tree

3 files changed

+166
-179
lines changed

3 files changed

+166
-179
lines changed

onnxruntime/python/tools/transformers/fusion_group_norm.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ def fuse(self, add_node, input_name_to_nodes: Dict, output_name_to_node: Dict):
8484
instance_norm_scale = self.model.get_constant_value(instance_norm.input[1])
8585
if instance_norm_scale is None or len(instance_norm_scale.shape) != 1:
8686
return
87+
num_groups = int(instance_norm_scale.shape[0])
8788

8889
instance_norm_bias = self.model.get_constant_value(instance_norm.input[2])
8990
if instance_norm_bias is None or instance_norm_scale.shape != instance_norm_scale.shape:
@@ -156,7 +157,8 @@ def fuse(self, add_node, input_name_to_nodes: Dict, output_name_to_node: Dict):
156157
)
157158

158159
new_node.attribute.extend(instance_norm.attribute)
159-
new_node.attribute.extend([helper.make_attribute("groups", 32)])
160+
161+
new_node.attribute.extend([helper.make_attribute("groups", num_groups)])
160162
new_node.attribute.extend([helper.make_attribute("activation", 1 if has_swish_activation else 0)])
161163

162164
if not self.channels_last:

onnxruntime/python/tools/transformers/onnx_model_clip.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ def get_fused_operator_statistics(self):
2727
"Gelu",
2828
"LayerNormalization",
2929
"QuickGelu",
30+
"BiasGelu",
3031
"SkipLayerNormalization",
3132
]
3233
for op in ops:

0 commit comments

Comments
 (0)