diff --git a/tests/lowering/normalization/test_group_norm.py b/tests/lowering/normalization/test_group_norm.py
new file mode 100644
index 000000000..274f9f525
--- /dev/null
+++ b/tests/lowering/normalization/test_group_norm.py
@@ -0,0 +1,54 @@
+import torch
+import torch_ttnn
+from torch_ttnn.passes.lowering import target_wrappers
+import pytest
+import ttnn
+
+from tests.utils import assert_with_pcc
+
+
+class GroupNormModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, input, num_groups, weight=None, bias=None, eps=1e-05):
+        return torch.nn.functional.group_norm(input, num_groups, weight, bias, eps)
+
+
+@pytest.mark.parametrize(
+    "input_shape, num_groups, is_lowered",
+    [
+        [(1, 320, 32, 32), 32, True],
+        [(1, 1280, 16, 16), 32, True],
+        [(2, 320, 64, 64), 32, True],
+        [(1, 1280, 1, 512), 32, True],
+        # These two cases appeared in stable diffusion v2 and failed the accuracy check
+        pytest.param((1, 1280, 8, 8), 32, True, marks=pytest.mark.xfail(reason="see #555")),
+        pytest.param((1, 2560, 8, 8), 32, True, marks=pytest.mark.xfail(reason="see #555")),
+        # These four cases appeared in retinanet_resnet50_fpn_v2 and raised a RuntimeError
+        pytest.param((1, 256, 50, 68), 32, True, marks=pytest.mark.xfail(reason="see #555")),
+        pytest.param((1, 256, 25, 34), 32, True, marks=pytest.mark.xfail(reason="see #555")),
+        pytest.param((1, 256, 13, 17), 32, True, marks=pytest.mark.xfail(reason="see #555")),
+        pytest.param((1, 256, 7, 9), 32, True, marks=pytest.mark.xfail(reason="see #555")),
+        [(1, 32, 2, 2), 32, False],
+        [(2, 2, 2, 2), 2, False],
+    ],
+)
+def test_group_norm(device, input_shape, num_groups, is_lowered):
+    m = GroupNormModule()
+    input = torch.rand(input_shape, dtype=torch.bfloat16)
+    weight = torch.ones(input_shape[1], dtype=torch.bfloat16)
+    bias = torch.rand(input_shape[1], dtype=torch.bfloat16)
+    result_before = m.forward(input, num_groups, weight, bias)
+    option = torch_ttnn.TorchTtnnOption(device=device)
+    # The compilation is lazy, so we need to run forward once to trigger the compilation
+    m = torch.compile(m, backend=torch_ttnn.backend, options=option)
+    result_after = m.forward(input, num_groups, weight, bias)
+    # option._out_fx_graphs[0].print_tabular()
+
+    # Check that the graph has been rewritten and contains ttnn ops
+    if is_lowered:
+        nodes = list(option._out_fx_graphs[0].nodes)
+        assert [node.target for node in nodes].count(target_wrappers.group_norm) == 1
+    # Check inference result
+    assert_with_pcc(result_before, result_after, 0.9997)
diff --git a/torch_ttnn/backend.py b/torch_ttnn/backend.py
index e5adbdcd9..3ae60ada0 100644
--- a/torch_ttnn/backend.py
+++ b/torch_ttnn/backend.py
@@ -62,6 +62,7 @@ def register_ttnn_objects(option: TorchTtnnOption):
     torch.fx.graph._register_custom_builtin("ttnn_uint32", "", ttnn.uint32)
 
     torch.fx.graph._register_custom_builtin("ttnn_bfloat16", "", ttnn.bfloat16)
+    torch.fx.graph._register_custom_builtin("ttnn_bfloat8_b", "", ttnn.DataType.BFLOAT8_B)
 
     torch.fx.graph._register_custom_builtin(
         "ttnn_DRAM_MEMORY_CONFIG",
diff --git a/torch_ttnn/passes/lowering/add_data_move_pass.py b/torch_ttnn/passes/lowering/add_data_move_pass.py
index 8b238a8a7..5b44b69d9 100644
--- a/torch_ttnn/passes/lowering/add_data_move_pass.py
+++ b/torch_ttnn/passes/lowering/add_data_move_pass.py
@@ -6,6 +6,9 @@
     TtnnDevice,
     TtnnBfloat16,
     TtnnUint32,
+    TtnnBfloat8_B,
+    TtnnDramMemoryConfig,
+    TtnnL1MemoryConfig,
     HasValidPageSize,
     get_dtype,
 )
@@ -136,10 +139,10 @@ def is_function_call(node) -> bool:
     target_wrappers.pack_to_tuple,
     target_wrappers.move_to_host,
     target_wrappers.conv2d,
+    target_wrappers.group_norm,
 ]
 
 
 TTNN_NORM_OPS = [
-    ttnn.group_norm,
     ttnn.layer_norm,
 ]
@@ -280,6 +283,7 @@ class AlignSpecFromTorch:
         device: Union[None, Type[TtnnDevice], Literal["host"]]
         layout: Union[None, Type[TtnnTileLayout], Type[TtnnRowMajorLayout]]
         dtype: Union[None, Type[TtnnBfloat16], Type[TtnnUint32]]
+        mem_config: Union[None, Type[TtnnDramMemoryConfig], Type[TtnnL1MemoryConfig]]
 
     @dataclass(unsafe_hash=True)
     class AlignSpecToTorch:
@@ -292,6 +296,7 @@ class AlignSpecInTtnn:
         device: Union[None, Type[TtnnDevice], Literal["host"]]
         layout: Union[None, Type[TtnnTileLayout], Type[TtnnRowMajorLayout]]
         dtype: Union[None, Type[TtnnBfloat16], Type[TtnnUint32]]
+        mem_config: Union[None, Type[TtnnDramMemoryConfig], Type[TtnnL1MemoryConfig]]
 
     def _align_for_special_layout(self, node, spec, input_site, input_site_type: InputSiteType):
         if is_target_a_user_of_curr_node(node, ttnn.embedding) and (
@@ -308,6 +313,36 @@ def _align_for_special_layout(self, node, spec, input_site, input_site_type: Inp
             spec.device = TtnnDevice
         return spec
 
+    def _align_for_group_norm(self, node, spec, input_site, input_site_type: InputSiteType):
+        if node.target != target_wrappers.group_norm:
+            return spec
+        # input tensor
+        if input_site_type == self.InputSiteType.ARGS and input_site == 0:
+            spec.device = TtnnDevice
+            # With TtnnTileLayout, shapes such as (1, 320, 32, 32) fail the accuracy check, see issue #555
+            spec.layout = TtnnRowMajorLayout
+            spec.dtype = TtnnBfloat16
+            spec.mem_config = TtnnDramMemoryConfig
+        # input mask
+        if input_site_type == self.InputSiteType.KWARGS and input_site == "input_mask":
+            spec.device = TtnnDevice
+            spec.layout = TtnnTileLayout
+            spec.dtype = TtnnBfloat8_B
+            spec.mem_config = TtnnDramMemoryConfig
+        # gamma
+        if input_site_type == self.InputSiteType.KWARGS and input_site == "weight":
+            spec.device = TtnnDevice
+            spec.layout = TtnnRowMajorLayout
+            spec.dtype = TtnnBfloat16
+            spec.mem_config = TtnnDramMemoryConfig
+        # beta
+        if input_site_type == self.InputSiteType.KWARGS and input_site == "bias":
+            spec.device = TtnnDevice
+            spec.layout = TtnnRowMajorLayout
+            spec.dtype = TtnnBfloat16
+            spec.mem_config = TtnnDramMemoryConfig
+        return spec
+
     def _reset_to_default_layout(self, input_node, spec):
         # legalize to the default layout and device
         if input_node.target in TTNN_LAYOUT_CHANGE_OPS.union(
@@ -325,18 +360,20 @@ def _reset_to_default_layout(self, input_node, spec):
     def _get_align_spec(self, node, input_node, input_site, input_site_type: InputSiteType):
         if is_torch_to_ttnn(input_node, node):
             # default set these layout for torch to ttnn
-            spec = self.AlignSpecFromTorch(input_node, TtnnDevice, TtnnTileLayout, TtnnBfloat16)
+            spec = self.AlignSpecFromTorch(input_node, TtnnDevice, TtnnTileLayout, TtnnBfloat16, None)
             spec = self._align_for_special_layout(node, spec, input_site, input_site_type)
+            spec = self._align_for_group_norm(node, spec, input_site, input_site_type)
             return spec
         elif is_ttnn_to_torch(input_node, node):
             spec = self.AlignSpecToTorch(input_node, "by_node_meta")
             return spec
         elif is_ttnn_to_ttnn(input_node, node):
             # default do nothing between ttnn to ttnn
-            spec = self.AlignSpecInTtnn(input_node, None, None, None)
+            spec = self.AlignSpecInTtnn(input_node, None, None, None, None)
             spec = self._reset_to_default_layout(input_node, spec)
             spec = self._align_for_special_layout(node, spec, input_site, input_site_type)
-            if spec.device is None and spec.layout is None and spec.dtype is None:
+            spec = self._align_for_group_norm(node, spec, input_site, input_site_type)
+            if spec.device is None and spec.layout is None and spec.dtype is None and spec.mem_config is None:
                 return None
             return spec
         return None
@@ -380,6 +417,8 @@ def _create_aligned_node(self, spec):
             kwargs = {"layout": TtnnTileLayout(), "device": TtnnDevice()}
             if spec.dtype is not None:
                 kwargs["dtype"] = spec.dtype()
+            if spec.mem_config is not None:
+                kwargs["memory_config"] = spec.mem_config()
             aligning_nodes.append(g.call_function(ttnn.from_torch, (spec.input_node,), kwargs))
             if spec.layout != TtnnTileLayout:
                 self._change_layout(spec, aligning_nodes)
@@ -399,6 +438,8 @@ def _create_aligned_node(self, spec):
             aligning_nodes.append(call_to_torch_with_meta(g, spec.input_node, spec.dtype))
         elif isinstance(spec, self.AlignSpecInTtnn):
             self._change_layout(spec, aligning_nodes)
+            if spec.mem_config is not None:
+                aligning_nodes.append(g.call_function(ttnn.to_memory_config, (aligning_nodes[-1], spec.mem_config())))
         return aligning_nodes[-1]
 
     def _connect_aligned_node(self, node, aligned_node, input_site, input_site_type: InputSiteType):
@@ -423,26 +464,6 @@ def _connect_aligned_node(self, node, aligned_node, input_site, input_site_type:
                     new_arg[tuple_idx] = aligned_node
                 node.update_kwarg(key, tuple(new_arg))
 
-    def _connect_aligned_node_layer_norm(
-        self, node, input_node, aligned_node, input_site, input_site_type: InputSiteType
-    ):
-        # Workaround to output the same layer_norm output
-        # Before: layer_norm = aten.layer_norm
-        #         getitem = getitem(layer_norm, 0)
-        #         return ((getitem,),)
-        # After:  layer_norm = ttnn.layer_norm
-        #         return (layer_norm,)
-        # Need to match the tuple in the original return statement
-        old_args = node.args[0]
-        if isinstance(old_args, tuple):
-            new_args = list(old_args)
-            for idx, old_arg in enumerate(old_args):
-                if old_arg == input_node:
-                    new_args[idx] = aligned_node
-            node.update_arg(0, tuple(new_args))
-        else:
-            self._connect_aligned_node(node, aligned_node, input_site, input_site_type)
-
     def align(self, node, input_node, input_site, input_site_type: InputSiteType):
         # assert input_site_type in ["args", "kwargs", "args_tuple", "kwargs_tuple"]
         align_spec = self._get_align_spec(node, input_node, input_site, input_site_type)
@@ -455,10 +476,7 @@ def align(self, node, input_node, input_site, input_site_type: InputSiteType):
         with self.graph.inserting_before(node):
            aligned_node = self._create_aligned_node(align_spec)
            self.aligned_node_dict[align_spec] = aligned_node
-            if node.target != ttnn.layer_norm:
-                self._connect_aligned_node(node, aligned_node, input_site, input_site_type)
-            else:
-                self._connect_aligned_node_layer_norm(node, input_node, aligned_node, input_site, input_site_type)
+            self._connect_aligned_node(node, aligned_node, input_site, input_site_type)
 
         return 1
 
diff --git a/torch_ttnn/passes/lowering/target_wrappers.py b/torch_ttnn/passes/lowering/target_wrappers.py
index 333c18617..102d48f1f 100644
--- a/torch_ttnn/passes/lowering/target_wrappers.py
+++ b/torch_ttnn/passes/lowering/target_wrappers.py
@@ -84,3 +84,30 @@ def conv2d(
         device=device,
     )
     return output_tensor
+
+
+@torch.fx.wrap
+def group_norm(
+    input_tensor, input_mask, weight, bias, num_groups, epsilon, inplace, grid_size_x, grid_size_y, shard_shape
+):
+    grid_coord = ttnn.CoreCoord(grid_size_x - 1, grid_size_y - 1)
+    shard_grid = ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), grid_coord)})
+    shard_spec = ttnn.ShardSpec(shard_grid, shard_shape, ttnn.ShardOrientation.COL_MAJOR, False)
+    sharded_mem_config = ttnn.MemoryConfig(
+        ttnn.types.TensorMemoryLayout.BLOCK_SHARDED, ttnn.types.BufferType.L1, shard_spec
+    )
+    input_tensor_sharded = ttnn.to_memory_config(input_tensor, sharded_mem_config)
+    output_tensor = ttnn.group_norm(
+        input_tensor_sharded,
+        num_groups=num_groups,
+        epsilon=epsilon,
+        input_mask=input_mask,
+        weight=weight,
+        bias=bias,
+        memory_config=sharded_mem_config,
+        core_grid=ttnn.CoreGrid(y=grid_size_y, x=grid_size_x),
+        inplace=inplace,
+    )
+
+    output_tensor_l1 = ttnn.to_memory_config(output_tensor, ttnn.L1_MEMORY_CONFIG)
+    return output_tensor_l1
diff --git a/torch_ttnn/passes/lowering/to_tt_guard_autogen.py b/torch_ttnn/passes/lowering/to_tt_guard_autogen.py
index ea4c1b4bd..6787bb3a5 100644
--- a/torch_ttnn/passes/lowering/to_tt_guard_autogen.py
+++ b/torch_ttnn/passes/lowering/to_tt_guard_autogen.py
@@ -126,6 +126,81 @@
     ["Tensor<[1, 16, 5, 5]> self = ?", "Tensor<[]> other = ?"],
     ["Tensor<[1, 16, 1, 6]> self = ?", "Tensor<[]> other = ?"],
 ]
+
+aten_native_group_norm_default_blocklist = [
+    [
+        "Tensor<[1, 1280, 8, 8]> input = ?",
+        "Optional[Tensor]<[1280]> weight = ?",
+        "Optional[Tensor]<[1280]> bias = ?",
+        "int N = 1",
+        "int C = 1280",
+        "int HxW = 64",
+        "int group = 32",
+        "float eps = 1e-05",
+    ],
+    [
+        "Tensor<[1, 1280, 8, 8]> input = ?",
+        "Optional[Tensor]<[1280]> weight = ?",
+        "Optional[Tensor]<[1280]> bias = ?",
+        "int N = 1",
+        "int C = 1280",
+        "int HxW = 64",
+        "int group = 32",
+        "float eps = 1e-06",
+    ],
+    [
+        "Tensor<[1, 2560, 8, 8]> input = ?",
+        "Optional[Tensor]<[2560]> weight = ?",
+        "Optional[Tensor]<[2560]> bias = ?",
+        "int N = 1",
+        "int C = 2560",
+        "int HxW = 64",
+        "int group = 32",
+        "float eps = 1e-05",
+    ],
+    [
+        "Tensor<[1, 256, 50, 68]> input = ?",
+        "Optional[Tensor]<[256]> weight = ?",
+        "Optional[Tensor]<[256]> bias = ?",
+        "int N = 1",
+        "int C = 256",
+        "int HxW = 3400",
+        "int group = 32",
+        "float eps = 1e-05",
+    ],
+    [
+        "Tensor<[1, 256, 25, 34]> input = ?",
+        "Optional[Tensor]<[256]> weight = ?",
+        "Optional[Tensor]<[256]> bias = ?",
+        "int N = 1",
+        "int C = 256",
+        "int HxW = 850",
+        "int group = 32",
+        "float eps = 1e-05",
+    ],
+    [
+        "Tensor<[1, 256, 13, 17]> input = ?",
+        "Optional[Tensor]<[256]> weight = ?",
+        "Optional[Tensor]<[256]> bias = ?",
+        "int N = 1",
+        "int C = 256",
+        "int HxW = 221",
+        "int group = 32",
+        "float eps = 1e-05",
+    ],
+    [
+        "Tensor<[1, 256, 7, 9]> input = ?",
+        "Optional[Tensor]<[256]> weight = ?",
+        "Optional[Tensor]<[256]> bias = ?",
+        "int N = 1",
+        "int C = 256",
+        "int HxW = 63",
+        "int group = 32",
+        "float eps = 1e-05",
+    ],
+]
+
+
 aten_native_layer_norm_default_blocklist = [
     [
         "Tensor<[1, 9, 4096]> input = ?",
@@ -1400,6 +1475,7 @@ def guard_aten(blocklist, node):
     ),
     torch.ops.aten.zeros_like.default: partial(guard_aten, aten_zeros_like_default_blocklist),
     torch.ops.aten.div.Tensor: partial(guard_aten, aten_div_Tensor_blocklist),
+    torch.ops.aten.native_group_norm.default: partial(guard_aten, aten_native_group_norm_default_blocklist),
     torch.ops.aten.native_layer_norm.default: partial(guard_aten, aten_native_layer_norm_default_blocklist),
     torch.ops.aten.exp.default: partial(guard_aten, aten_exp_default_blocklist),
     torch.ops.aten.split.Tensor: partial(guard_aten, aten_split_Tensor_blocklist),
@@ -1429,6 +1505,7 @@ def guard_aten(blocklist, node):
     "torch.ops.aten.zeros_like.default",
     "torch.ops.aten.div.Tensor",
     "torch.ops.aten.mul.Tensor",
+    "torch.ops.aten.native_group_norm.default",
     "torch.ops.aten.native_layer_norm.default",
     "torch.ops.aten.sub.Tensor",
     "torch.ops.aten.exp.default",
diff --git a/torch_ttnn/passes/lowering/to_tt_pass.py b/torch_ttnn/passes/lowering/to_tt_pass.py
index a096e1cec..eb622bed0 100644
--- a/torch_ttnn/passes/lowering/to_tt_pass.py
+++ b/torch_ttnn/passes/lowering/to_tt_pass.py
@@ -423,7 +423,7 @@ def inserting_before(self, node):
         return self.g.inserting_before(node)
 
 
-def ReplaceMoreTtManually(gm: torch.fx.GraphModule, use_less_ttnn_op_types: bool) -> torch.fx.GraphModule:
+def ReplaceMoreTtManually(gm: torch.fx.GraphModule, device, use_less_ttnn_op_types: bool) -> torch.fx.GraphModule:
     nodes = list(gm.graph.nodes)
     for node in nodes:
         if not can_lowering_to_ttnn(node):
@@ -491,17 +491,65 @@ def lower_binary_eltwise(fn, args):
             # Add additional logic to choose the appropriate memory_config type: DRAM or L1
             return g.call_function(target_wrappers.clone, args=(args[0],))
 
+        # reference: tt-metal/tests/ttnn/operations/test_group_norm.py::test_group_norm_with_block_sharded_v2_8x8_grid_tile_layout
+        if node.target == torch.ops.aten.native_group_norm.default:
+            if not is_getitem_0_only_user(node):
+                return None
+            arg0_shape = get_shape(args[0])
+            if len(arg0_shape) != 4:
+                return None
+            N, C, H, W = arg0_shape
+            input_tensor = args[0]
+            weight_tensor = args[1]
+            bias_tensor = args[2]
+            num_groups = args[6]
+            epsilon = args[7]
+            inplace = False
+            grid_size_x = device.compute_with_storage_grid_size().x
+            grid_size_y = device.compute_with_storage_grid_size().y
+            shard_shape = N * H * W // grid_size_x, C // grid_size_y
+            # TODO: Add support for shard_shape = 0
+            if shard_shape[0] == 0 or shard_shape[1] == 0:
+                return None
+            # input tensor
+            input_tensor_permute = g.call_function(ttnn.permute, args=(input_tensor, (0, 2, 3, 1)))
+            input_tensor_reshape = g.call_function(ttnn.reshape, args=(input_tensor_permute, (N, 1, W * H, C)))
+            # input mask
+            input_mask_tensor = g.call_function(
+                ttnn.create_group_norm_input_mask, args=(C, num_groups, grid_size_y)
+            )
+            # gamma/beta
+            gamma = g.call_function(ttnn.create_group_norm_weight_bias_rm, args=(weight_tensor, C, grid_size_y))
+            beta = g.call_function(ttnn.create_group_norm_weight_bias_rm, args=(bias_tensor, C, grid_size_y))
+            output_tensor_l1 = g.call_function(
+                target_wrappers.group_norm,
+                args=(input_tensor_reshape,),
+                kwargs={
+                    "input_mask": input_mask_tensor,
+                    "weight": gamma,
+                    "bias": beta,
+                    "num_groups": num_groups,
+                    "epsilon": epsilon,
+                    "inplace": inplace,
+                    "grid_size_x": grid_size_x,
+                    "grid_size_y": grid_size_y,
+                    "shard_shape": shard_shape,
+                },
+            )
+
+            output_tensor_l1_reshape = g.call_function(ttnn.reshape, (output_tensor_l1, (N, H, W, C)))
+            output_tensor_l1_permute = g.call_function(ttnn.permute, (output_tensor_l1_reshape, (0, 3, 1, 2)))
+            return g.call_function(target_wrappers.pack_to_tuple, (output_tensor_l1_permute,))
+
         if node.target == torch.ops.aten.native_layer_norm.default:
+            if not is_getitem_0_only_user(node):
+                return None
             new_node = g.call_function(
                 ttnn.layer_norm,
                 args=(args[0],),
                 kwargs={"epsilon": args[4], "weight": args[2], "bias": args[3]},
             )
-            node.replace_all_uses_with(new_node, delete_user_cb=lambda node: node != new_node)
-            node_users = list(new_node.users.keys())
-            for node_user in node_users:
-                node_user.replace_all_uses_with(new_node)
-            return None
+            return g.call_function(target_wrappers.pack_to_tuple, (new_node,))
 
         if node.target == torch.ops.aten.ones.default:
             return g.call_function(ttnn.ones, args=args, kwargs={"device": TtnnDevice()})
@@ -1188,6 +1236,6 @@ def call(self, gm: torch.fx.GraphModule):
         gm = ReplaceMoreTt(gm, self.device, self.use_less_ttnn_op_types).transform()
 
         # Replace patterns manually
-        gm = ReplaceMoreTtManually(gm, self.use_less_ttnn_op_types)
+        gm = ReplaceMoreTtManually(gm, self.device, self.use_less_ttnn_op_types)
 
         return PassResult(gm, True)
diff --git a/torch_ttnn/utils.py b/torch_ttnn/utils.py
index 2c570656b..3ff2354db 100644
--- a/torch_ttnn/utils.py
+++ b/torch_ttnn/utils.py
@@ -62,6 +62,11 @@ def __repr__(self):
         return f"ttnn_bfloat16"
 
 
+class TtnnBfloat8_B:
+    def __repr__(self):
+        return f"ttnn_bfloat8_b"
+
+
 class TtnnDramMemoryConfig:
     def __repr__(self):
         return f"ttnn_DRAM_MEMORY_CONFIG"
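
Note on the shard-shape guard in the native_group_norm lowering: the sketch below is illustrative only (not part of the patch) and reproduces the arithmetic the pass uses to decide whether a group_norm can be lowered. It assumes an 8x8 compute grid, matching the referenced tt-metal 8x8-grid test; the pass itself queries device.compute_with_storage_grid_size() at compile time.

# Illustrative sketch, not part of the patch: shard-shape arithmetic from the
# aten.native_group_norm lowering, with an assumed 8x8 core grid.
def block_sharded_shard_shape(input_shape, grid_size_x=8, grid_size_y=8):
    N, C, H, W = input_shape
    # The pass permutes NCHW -> NHWC and reshapes to (N, 1, H*W, C), so rows are
    # split across grid_size_x cores and channels across grid_size_y cores.
    return (N * H * W // grid_size_x, C // grid_size_y)

for shape in [(1, 320, 32, 32), (1, 1280, 16, 16), (1, 32, 2, 2), (2, 2, 2, 2)]:
    shard = block_sharded_shard_shape(shape)
    # The pass returns None (keeps aten.native_group_norm) when either dim is 0,
    # which is why (1, 32, 2, 2) and (2, 2, 2, 2) are marked is_lowered=False in the test.
    print(shape, "->", shard, "lowered" if 0 not in shard else "not lowered")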