Lowering aten.native_group_norm.default #556

Open · wants to merge 5 commits into main
54 changes: 54 additions & 0 deletions tests/lowering/normalization/test_group_norm.py
@@ -0,0 +1,54 @@
import torch
import torch_ttnn
from torch_ttnn.passes.lowering import target_wrappers
import pytest
import ttnn

from tests.utils import assert_with_pcc


class GroupNormModule(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, input, num_groups, weight=None, bias=None, eps=1e-05):
return torch.nn.functional.group_norm(input, num_groups, weight, bias, eps)


@pytest.mark.parametrize(
"input_shape, num_groups, is_lowered",
[
[(1, 320, 32, 32), 32, True],
[(1, 1280, 16, 16), 32, True],
[(2, 320, 64, 64), 32, True],
[(1, 1280, 1, 512), 32, True],
# These two cases appear in stable diffusion v2 but fail the accuracy check
pytest.param((1, 1280, 8, 8), 32, True, marks=pytest.mark.xfail(reason="see #555")),
pytest.param((1, 2560, 8, 8), 32, True, marks=pytest.mark.xfail(reason="see #555")),
# These four cases appear in retinanet_resnet50_fpn_v2 and raise a RuntimeError
pytest.param((1, 256, 50, 68), 32, True, marks=pytest.mark.xfail(reason="see #555")),
pytest.param((1, 256, 25, 34), 32, True, marks=pytest.mark.xfail(reason="see #555")),
pytest.param((1, 256, 13, 17), 32, True, marks=pytest.mark.xfail(reason="see #555")),
pytest.param((1, 256, 7, 9), 32, True, marks=pytest.mark.xfail(reason="see #555")),
[(1, 32, 2, 2), 32, False],
[(2, 2, 2, 2), 2, False],
],
)
def test_group_norm(device, input_shape, num_groups, is_lowered):
m = GroupNormModule()
input = torch.rand(input_shape, dtype=torch.bfloat16)
weight = torch.ones(input_shape[1], dtype=torch.bfloat16)
bias = torch.rand(input_shape[1], dtype=torch.bfloat16)
result_before = m.forward(input, num_groups, weight, bias)
option = torch_ttnn.TorchTtnnOption(device=device)
# The compilation is lazy, so we need to run forward once to trigger the compilation
m = torch.compile(m, backend=torch_ttnn.backend, options=option)
result_after = m.forward(input, num_groups, weight, bias)
# option._out_fx_graphs[0].print_tabular()

# Check that the graph has been rewritten and contains ttnn ops
if is_lowered:
nodes = list(option._out_fx_graphs[0].nodes)
assert [node.target for node in nodes].count(target_wrappers.group_norm) == 1
# Check inference result
assert_with_pcc(result_before, result_after, 0.9997)
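
For reference, assert_with_pcc comes from tests.utils in this repository. The sketch below is only an assumption about what such a check does (a Pearson-correlation comparison against the 0.9997 threshold used above); it is not the actual helper.

import torch

def pcc(expected: torch.Tensor, actual: torch.Tensor) -> float:
    # Flatten both tensors and compute the Pearson correlation coefficient
    # between the eager (reference) and compiled outputs.
    e = expected.flatten().to(torch.float32)
    a = actual.flatten().to(torch.float32)
    return torch.corrcoef(torch.stack([e, a]))[0, 1].item()

def assert_with_pcc_sketch(expected, actual, threshold=0.9997):
    # Hypothetical stand-in for tests.utils.assert_with_pcc.
    value = pcc(expected, actual)
    assert value >= threshold, f"PCC {value:.6f} is below threshold {threshold}"
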
1 change: 1 addition & 0 deletions torch_ttnn/backend.py
@@ -62,6 +62,7 @@ def register_ttnn_objects(option: TorchTtnnOption):

torch.fx.graph._register_custom_builtin("ttnn_uint32", "", ttnn.uint32)
torch.fx.graph._register_custom_builtin("ttnn_bfloat16", "", ttnn.bfloat16)
torch.fx.graph._register_custom_builtin("ttnn_bfloat8_b", "", ttnn.DataType.BFLOAT8_B)

torch.fx.graph._register_custom_builtin(
"ttnn_DRAM_MEMORY_CONFIG",
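
For context, _register_custom_builtin is the private torch.fx hook already used throughout this function: it makes a short global name resolvable in the Python source that GraphModule.recompile() emits. The reason this PR needs ttnn_bfloat8_b, namely that the data-move pass below stages the group_norm input mask as bfloat8_b, is an inference from the rest of the diff rather than something stated here; a minimal sketch of the registration is:

import torch
import ttnn

# Make "ttnn_bfloat8_b" resolvable as a global in generated graph code, so a
# from_torch(..., dtype=ttnn_bfloat8_b, ...) node added by the data-move pass
# can execute without an extra import in the emitted source.
torch.fx.graph._register_custom_builtin("ttnn_bfloat8_b", "", ttnn.DataType.BFLOAT8_B)
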
74 changes: 46 additions & 28 deletions torch_ttnn/passes/lowering/add_data_move_pass.py
@@ -6,6 +6,9 @@
TtnnDevice,
TtnnBfloat16,
TtnnUint32,
TtnnBfloat8_B,
TtnnDramMemoryConfig,
TtnnL1MemoryConfig,
HasValidPageSize,
get_dtype,
)
@@ -136,10 +139,10 @@ def is_function_call(node) -> bool:
target_wrappers.pack_to_tuple,
target_wrappers.move_to_host,
target_wrappers.conv2d,
target_wrappers.group_norm,
]

TTNN_NORM_OPS = [
ttnn.group_norm,
ttnn.layer_norm,
]

@@ -280,6 +283,7 @@ class AlignSpecFromTorch:
device: Union[None, Type[TtnnDevice], Literal["host"]]
layout: Union[None, Type[TtnnTileLayout], Type[TtnnRowMajorLayout]]
dtype: Union[None, Type[TtnnBfloat16], Type[TtnnUint32]]
mem_config: Union[None, Type[TtnnDramMemoryConfig], Type[TtnnL1MemoryConfig]]

@dataclass(unsafe_hash=True)
class AlignSpecToTorch:
@@ -292,6 +296,7 @@ class AlignSpecInTtnn:
device: Union[None, Type[TtnnDevice], Literal["host"]]
layout: Union[None, Type[TtnnTileLayout], Type[TtnnRowMajorLayout]]
dtype: Union[None, Type[TtnnBfloat16], Type[TtnnUint32]]
mem_config: Union[None, Type[TtnnDramMemoryConfig], Type[TtnnL1MemoryConfig]]

def _align_for_special_layout(self, node, spec, input_site, input_site_type: InputSiteType):
if is_target_a_user_of_curr_node(node, ttnn.embedding) and (
@@ -308,6 +313,36 @@ def _align_for_special_layout(self, node, spec, input_site, input_site_type: InputSiteType):
spec.device = TtnnDevice
return spec

def _align_for_group_norm(self, node, spec, input_site, input_site_type: InputSiteType):
if node.target != target_wrappers.group_norm:
return spec
# input tensor
if input_site_type == self.InputSiteType.ARGS and input_site == 0:
spec.device = TtnnDevice
# With TtnnTileLayout, some shapes such as (1, 320, 32, 32) fail the accuracy check, see issue #555
spec.layout = TtnnRowMajorLayout
spec.dtype = TtnnBfloat16
spec.mem_config = TtnnDramMemoryConfig
# input mask
if input_site_type == self.InputSiteType.KWARGS and input_site == "input_mask":
spec.device = TtnnDevice
spec.layout = TtnnTileLayout
spec.dtype = TtnnBfloat8_B
spec.mem_config = TtnnDramMemoryConfig
# gamma
if input_site_type == self.InputSiteType.KWARGS and input_site == "weight":
spec.device = TtnnDevice
spec.layout = TtnnRowMajorLayout
spec.dtype = TtnnBfloat16
spec.mem_config = TtnnDramMemoryConfig
# beta
if input_site_type == self.InputSiteType.KWARGS and input_site == "bias":
spec.device = TtnnDevice
spec.layout = TtnnRowMajorLayout
spec.dtype = TtnnBfloat16
spec.mem_config = TtnnDramMemoryConfig
return spec

def _reset_to_default_layout(self, input_node, spec):
# legalize to the default layout and device
if input_node.target in TTNN_LAYOUT_CHANGE_OPS.union(
@@ -325,18 +360,20 @@ def _get_align_spec(self, node, input_node, input_site, input_site_type: InputSiteType):
def _get_align_spec(self, node, input_node, input_site, input_site_type: InputSiteType):
if is_torch_to_ttnn(input_node, node):
# default set these layout for torch to ttnn
spec = self.AlignSpecFromTorch(input_node, TtnnDevice, TtnnTileLayout, TtnnBfloat16)
spec = self.AlignSpecFromTorch(input_node, TtnnDevice, TtnnTileLayout, TtnnBfloat16, None)
spec = self._align_for_special_layout(node, spec, input_site, input_site_type)
spec = self._align_for_group_norm(node, spec, input_site, input_site_type)
return spec
elif is_ttnn_to_torch(input_node, node):
spec = self.AlignSpecToTorch(input_node, "by_node_meta")
return spec
elif is_ttnn_to_ttnn(input_node, node):
# default do nothing between ttnn to ttnn
spec = self.AlignSpecInTtnn(input_node, None, None, None)
spec = self.AlignSpecInTtnn(input_node, None, None, None, None)
spec = self._reset_to_default_layout(input_node, spec)
spec = self._align_for_special_layout(node, spec, input_site, input_site_type)
if spec.device is None and spec.layout is None and spec.dtype is None:
spec = self._align_for_group_norm(node, spec, input_site, input_site_type)
if spec.device is None and spec.layout is None and spec.dtype is None and spec.mem_config is None:
return None
return spec
return None
@@ -380,6 +417,8 @@ def _create_aligned_node(self, spec):
kwargs = {"layout": TtnnTileLayout(), "device": TtnnDevice()}
if spec.dtype is not None:
kwargs["dtype"] = spec.dtype()
if spec.mem_config is not None:
kwargs["memory_config"] = spec.mem_config()
aligning_nodes.append(g.call_function(ttnn.from_torch, (spec.input_node,), kwargs))
if spec.layout != TtnnTileLayout:
self._change_layout(spec, aligning_nodes)
@@ -399,6 +438,8 @@ def _create_aligned_node(self, spec):
aligning_nodes.append(call_to_torch_with_meta(g, spec.input_node, spec.dtype))
elif isinstance(spec, self.AlignSpecInTtnn):
self._change_layout(spec, aligning_nodes)
if spec.mem_config is not None:
aligning_nodes.append(g.call_function(ttnn.to_memory_config, (aligning_nodes[-1], spec.mem_config())))
return aligning_nodes[-1]

def _connect_aligned_node(self, node, aligned_node, input_site, input_site_type: InputSiteType):
@@ -423,26 +464,6 @@ def _connect_aligned_node(self, node, aligned_node, input_site, input_site_type: InputSiteType):
new_arg[tuple_idx] = aligned_node
node.update_kwarg(key, tuple(new_arg))

def _connect_aligned_node_layer_norm(
self, node, input_node, aligned_node, input_site, input_site_type: InputSiteType
):
# Workaround to output the same layer_norm output
# Before: layer_norm = aten.layer_norm
# getitem = getitem(layer_norm, 0)
# return ((getitem,),)
# After: layer_norm = ttnn.layer_norm
# return (layer_norm,)
# Need to match the tuple in the original return statement
old_args = node.args[0]
if isinstance(old_args, tuple):
new_args = list(old_args)
for idx, old_arg in enumerate(old_args):
if old_arg == input_node:
new_args[idx] = aligned_node
node.update_arg(0, tuple(new_args))
else:
self._connect_aligned_node(node, aligned_node, input_site, input_site_type)

def align(self, node, input_node, input_site, input_site_type: InputSiteType):
# assert input_site_type in ["args", "kwargs", "args_tuple", "kwargs_tuple"]
align_spec = self._get_align_spec(node, input_node, input_site, input_site_type)
@@ -455,10 +476,7 @@ def align(self, node, input_node, input_site, input_site_type: InputSiteType):
with self.graph.inserting_before(node):
aligned_node = self._create_aligned_node(align_spec)
self.aligned_node_dict[align_spec] = aligned_node
if node.target != ttnn.layer_norm:
self._connect_aligned_node(node, aligned_node, input_site, input_site_type)
else:
self._connect_aligned_node_layer_norm(node, input_node, aligned_node, input_site, input_site_type)
self._connect_aligned_node(node, aligned_node, input_site, input_site_type)
return 1


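
Taken together, the specs added in _align_for_group_norm stage all four operands of target_wrappers.group_norm on device in DRAM: the activation, gamma, and beta stay row-major in bfloat16, while the input mask is tiled in bfloat8_b. The helper below is a hand-written illustration of the conversions those specs translate into, assuming the standard ttnn.from_torch keyword arguments; it is not code from this PR and the function name is made up.

import ttnn

def stage_group_norm_operands(device, input_nchw, input_mask, weight, bias):
    # Activation: row-major bfloat16 in DRAM (tile layout trips issue #555).
    input_dev = ttnn.from_torch(
        input_nchw, dtype=ttnn.bfloat16, layout=ttnn.ROW_MAJOR_LAYOUT,
        device=device, memory_config=ttnn.DRAM_MEMORY_CONFIG,
    )
    # Input mask: tiled bfloat8_b in DRAM.
    mask_dev = ttnn.from_torch(
        input_mask, dtype=ttnn.bfloat8_b, layout=ttnn.TILE_LAYOUT,
        device=device, memory_config=ttnn.DRAM_MEMORY_CONFIG,
    )
    # Gamma and beta: row-major bfloat16 in DRAM.
    weight_dev = ttnn.from_torch(
        weight, dtype=ttnn.bfloat16, layout=ttnn.ROW_MAJOR_LAYOUT,
        device=device, memory_config=ttnn.DRAM_MEMORY_CONFIG,
    )
    bias_dev = ttnn.from_torch(
        bias, dtype=ttnn.bfloat16, layout=ttnn.ROW_MAJOR_LAYOUT,
        device=device, memory_config=ttnn.DRAM_MEMORY_CONFIG,
    )
    return input_dev, mask_dev, weight_dev, bias_dev
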
27 changes: 27 additions & 0 deletions torch_ttnn/passes/lowering/target_wrappers.py
@@ -84,3 +84,30 @@ def conv2d(
device=device,
)
return output_tensor


@torch.fx.wrap
def group_norm(
input_tensor, input_mask, weight, bias, num_groups, epsilon, inplace, grid_size_x, grid_size_y, shard_shape
):
grid_coord = ttnn.CoreCoord(grid_size_x - 1, grid_size_y - 1)
shard_grid = ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), grid_coord)})
shard_spec = ttnn.ShardSpec(shard_grid, shard_shape, ttnn.ShardOrientation.COL_MAJOR, False)
sharded_mem_config = ttnn.MemoryConfig(
ttnn.types.TensorMemoryLayout.BLOCK_SHARDED, ttnn.types.BufferType.L1, shard_spec
)
input_tensor_sharded = ttnn.to_memory_config(input_tensor, sharded_mem_config)
output_tensor = ttnn.group_norm(
input_tensor_sharded,
num_groups=num_groups,
epsilon=epsilon,
input_mask=input_mask,
weight=weight,
bias=bias,
memory_config=sharded_mem_config,
core_grid=ttnn.CoreGrid(y=grid_size_y, x=grid_size_x),
inplace=inplace,
)

output_tensor_l1 = ttnn.to_memory_config(output_tensor, ttnn.L1_MEMORY_CONFIG)
return output_tensor_l1
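
A hypothetical invocation of the wrapper, reusing the staged tensors from the earlier sketch, is shown below. The grid size and shard shape are placeholders; in the PR they are computed by the lowering pass, which is outside this diff, so treat this purely as a shape-of-the-API example.

# Placeholder values: grid_size_x/grid_size_y describe the compute core grid and
# shard_shape the per-core block; the real numbers come from the lowering pass.
output = group_norm(
    input_dev,              # activation in DRAM, row-major, bfloat16
    mask_dev,               # input mask, tiled, bfloat8_b
    weight_dev,             # gamma
    bias_dev,               # beta
    num_groups=32,
    epsilon=1e-05,
    inplace=False,
    grid_size_x=8,          # placeholder core-grid width
    grid_size_y=8,          # placeholder core-grid height
    shard_shape=(128, 40),  # placeholder (height, width) per core
)
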
77 changes: 77 additions & 0 deletions torch_ttnn/passes/lowering/to_tt_guard_autogen.py
@@ -126,6 +126,81 @@
["Tensor<[1, 16, 5, 5]> self = ?", "Tensor<[]> other = ?"],
["Tensor<[1, 16, 1, 6]> self = ?", "Tensor<[]> other = ?"],
]

aten_native_group_norm_default_blocklist = [
[
"Tensor<[1, 1280, 8, 8]> input = ?",
"Optional[Tensor]<[1280]> weight = ?",
"Optional[Tensor]<[1280]> bias = ?",
"int N = 1",
"int C = 1280",
"int HxW = 64",
"int group = 32",
"float eps = 1e-05",
],
[
"Tensor<[1, 1280, 8, 8]> input = ?",
"Optional[Tensor]<[1280]> weight = ?",
"Optional[Tensor]<[1280]> bias = ?",
"int N = 1",
"int C = 1280",
"int HxW = 64",
"int group = 32",
"float eps = 1e-06",
],
[
"Tensor<[1, 2560, 8, 8]> input = ?",
"Optional[Tensor]<[2560]> weight = ?",
"Optional[Tensor]<[2560]> bias = ?",
"int N = 1",
"int C = 2560",
"int HxW = 64",
"int group = 32",
"float eps = 1e-05",
],
[
"Tensor<[1, 256, 50, 68]> input = ?",
"Optional[Tensor]<[256]> weight = ?",
"Optional[Tensor]<[256]> bias = ?",
"int N = 1",
"int C = 256",
"int HxW = 3400",
"int group = 32",
"float eps = 1e-05",
],
[
"Tensor<[1, 256, 25, 34]> input = ?",
"Optional[Tensor]<[256]> weight = ?",
"Optional[Tensor]<[256]> bias = ?",
"int N = 1",
"int C = 256",
"int HxW = 850",
"int group = 32",
"float eps = 1e-05",
],
[
"Tensor<[1, 256, 13, 17]> input = ?",
"Optional[Tensor]<[256]> weight = ?",
"Optional[Tensor]<[256]> bias = ?",
"int N = 1",
"int C = 256",
"int HxW = 221",
"int group = 32",
"float eps = 1e-05",
],
[
"Tensor<[1, 256, 7, 9]> input = ?",
"Optional[Tensor]<[256]> weight = ?",
"Optional[Tensor]<[256]> bias = ?",
"int N = 1",
"int C = 256",
"int HxW = 63",
"int group = 32",
"float eps = 1e-05",
],
]


aten_native_layer_norm_default_blocklist = [
[
"Tensor<[1, 9, 4096]> input = ?",
@@ -1400,6 +1475,7 @@ def guard_aten(blocklist, node):
),
torch.ops.aten.zeros_like.default: partial(guard_aten, aten_zeros_like_default_blocklist),
torch.ops.aten.div.Tensor: partial(guard_aten, aten_div_Tensor_blocklist),
torch.ops.aten.native_group_norm.default: partial(guard_aten, aten_native_group_norm_default_blocklist),
torch.ops.aten.native_layer_norm.default: partial(guard_aten, aten_native_layer_norm_default_blocklist),
torch.ops.aten.exp.default: partial(guard_aten, aten_exp_default_blocklist),
torch.ops.aten.split.Tensor: partial(guard_aten, aten_split_Tensor_blocklist),
@@ -1429,6 +1505,7 @@ def guard_aten(blocklist, node):
"torch.ops.aten.zeros_like.default",
"torch.ops.aten.div.Tensor",
"torch.ops.aten.mul.Tensor",
"torch.ops.aten.native_group_norm.default",
"torch.ops.aten.native_layer_norm.default",
"torch.ops.aten.sub.Tensor",
"torch.ops.aten.exp.default",
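
guard_aten itself is defined earlier in this file and is unchanged by the PR. Conceptually, and this is a sketch inferred from how the blocklists above are written rather than the real implementation, it refuses to lower any node whose rendered argument signature appears in the blocklist:

from typing import List

def guard_aten_sketch(blocklist: List[List[str]], arg_signature: List[str]) -> bool:
    # Return True when lowering may proceed and False when the node's rendered
    # "Tensor<[...]> name = ?" signature matches a blocklisted entry.
    return arg_signature not in blocklist

# Example signature for the (1, 2560, 8, 8) stable diffusion case listed above;
# guard_aten_sketch(aten_native_group_norm_default_blocklist, sig) would be False.
sig = [
    "Tensor<[1, 2560, 8, 8]> input = ?",
    "Optional[Tensor]<[2560]> weight = ?",
    "Optional[Tensor]<[2560]> bias = ?",
    "int N = 1",
    "int C = 2560",
    "int HxW = 64",
    "int group = 32",
    "float eps = 1e-05",
]
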