Skip to content

Try lowering aten.nll_loss_forward to ttnn.moreh_nll_loss #676

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 48 additions & 0 deletions tests/lowering/misc/test_nll_loss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import torch
import torch_ttnn
import pytest
import ttnn

from tests.utils import assert_with_pcc


class NllLossModule(torch.nn.Module):
    """Minimal nn.Module that delegates straight to torch.nn.functional.nll_loss.

    Exists only so the loss call appears inside a module's forward pass and can
    be captured by torch.compile.
    """

    def __init__(self):
        super().__init__()

    def forward(self, *args, **kwargs):
        # Forward every positional and keyword argument untouched.
        return torch.nn.functional.nll_loss(*args, **kwargs)


@pytest.mark.parametrize(
    "input_shape, has_weight, reduction, ignore_index",
    (
        ((19, 256008), True, "mean", -100),
        ((19, 256008), False, "mean", -100),
    ),
)
def test_nll_loss(device, input_shape, has_weight, reduction, ignore_index):
    """Check that aten.nll_loss_forward lowers to ttnn.moreh_nll_loss.

    Args:
        device: ttnn device fixture (supplied by the test harness).
        input_shape: (batch, channels) shape of the log-probability input.
        has_weight: whether to pass a per-class weight tensor.
        reduction: reduction mode forwarded to nll_loss ("none"/"mean"/"sum").
        ignore_index: target value that is excluded from the loss.
    """
    module = NllLossModule()
    batch, channels = input_shape
    # nll_loss expects log-probabilities, so normalize random logits first.
    # Renamed from `input` to avoid shadowing the builtin of that name.
    log_probs = torch.nn.functional.log_softmax(
        torch.rand((batch, channels), dtype=torch.bfloat16), dim=1
    )

    target = torch.randint(0, channels, (batch,), dtype=torch.long)
    weight = torch.rand((channels,), dtype=torch.bfloat16) if has_weight else None
    result_before = module.forward(log_probs, target, weight, reduction=reduction, ignore_index=ignore_index)

    option = torch_ttnn.TorchTtnnOption(device=device)

    # The compilation is lazy, so we need to run forward once to trigger the compilation
    module = torch.compile(module, backend=torch_ttnn.backend, options=option)

    result_after = module.forward(log_probs, target, weight, reduction=reduction, ignore_index=ignore_index)
    print(option._out_fx_graphs[0])

    # Check the graph has been rewritten and contains ttnn ops
    nodes = [node.target for node in option._out_fx_graphs[0].nodes]
    assert torch.ops.aten.nll_loss_forward.default not in nodes
    assert nodes.count(ttnn.operations.moreh.nll_loss) == 1

    # Check inference result
    assert_with_pcc(result_before, result_after, pcc=0.99)
2 changes: 1 addition & 1 deletion torch_ttnn/passes/lowering/add_data_move_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,12 +192,12 @@ def is_tt_compute(node) -> bool:
ttnn.zeros_like,
ttnn.mean,
ttnn.moreh_cumsum,
ttnn.moreh_nll_loss,
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! Using ttnn.moreh_* wrappers work!

ttnn.clip,
ttnn.squeeze,
ttnn.full,
ttnn.as_tensor,
ttnn.expand,
ttnn.moreh_cumsum,
ttnn.sum,
ttnn.typecast,
ttnn.argmax,
Expand Down
10 changes: 10 additions & 0 deletions torch_ttnn/passes/lowering/to_tt_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,6 +355,16 @@ def __init__(self, target, args, kwargs):

return self.call_function_prop_meta(ttnn.reshape, (tensor, size))

if target == torch.ops.aten.nll_loss_forward.default:
input, target, weight, reduction, ignore_index = args
args = input, target, ("none", "mean", "sum")[reduction]
kwargs = {
"divisor_tensor": torch.tensor([0], dtype=get_dtype(input)),
Copy link
Contributor Author

@jdh8 jdh8 Dec 27, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. Isn't the default divisor 1?
  2. We arrive at a segfault (even after I tried 1 instead of 0)
tests/lowering/misc/test_nll_loss.py Fatal Python error: Segmentation fault

Thread 0x00007f1a8a7e4700 (most recent call first):
  File "/usr/lib/python3.8/threading.py", line 306 in wait
  File "/usr/lib/python3.8/threading.py", line 558 in wait
  File "/home/jdh8/tt-metal/python_env/lib/python3.8/site-packages/tqdm/_monitor.py", line 60 in run
  File "/usr/lib/python3.8/threading.py", line 932 in _bootstrap_inner
  File "/usr/lib/python3.8/threading.py", line 890 in _bootstrap

Current thread 0x00007f1ba3a19740 (most recent call first):
  File "/home/jdh8/tt-metal/ttnn/ttnn/decorators.py", line 329 in __call__
  File "/home/jdh8/tt-metal/ttnn/ttnn/operations/core.py", line 233 in from_torch
  File "/home/jdh8/tt-metal/ttnn/ttnn/decorators.py", line 329 in __call__
  File "<eval_with_key>.15", line 8 in forward
  File "/home/jdh8/tt-metal/python_env/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1520 in _call_impl
  File "/home/jdh8/tt-metal/python_env/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1511 in _wrapped_call_impl
  File "/home/jdh8/tt-metal/python_env/lib/python3.8/site-packages/torch/fx/graph_module.py", line 304 in __call__
  File "/home/jdh8/tt-metal/python_env/lib/python3.8/site-packages/torch/fx/graph_module.py", line 738 in call_wrapped
  File "/home/jdh8/tt-metal/python_env/lib/python3.8/site-packages/torch/_functorch/_aot_autograd/utils.py", line 81 in g
  File "/home/jdh8/tt-metal/python_env/lib/python3.8/site-packages/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py", line 118 in rng_functionalization_wrapper
  File "/home/jdh8/tt-metal/python_env/lib/python3.8/site-packages/torch/_functorch/_aot_autograd/utils.py", line 105 in call_func_at_runtime_with_args
  File "/home/jdh8/tt-metal/python_env/lib/python3.8/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 94 in runtime_wrapper
  File "/home/jdh8/tt-metal/python_env/lib/python3.8/site-packages/torch/_functorch/_aot_autograd/utils.py", line 81 in g
  File "/home/jdh8/tt-metal/python_env/lib/python3.8/site-packages/torch/_functorch/aot_autograd.py", line 901 in forward
  File "/home/jdh8/tt-metal/python_env/lib/python3.8/site-packages/torch/_dynamo/external_utils.py", line 17 in inner
  File "/home/jdh8/tt-metal/python_env/lib/python3.8/site-packages/torch/_dynamo/eval_frame.py", line 489 in _fn
  File "/home/jdh8/pytorch2.0_ttnn/tests/lowering/misc/test_nll_loss.py", line 13 in forward
  File "/home/jdh8/tt-metal/python_env/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1520 in _call_impl
  File "/home/jdh8/tt-metal/python_env/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1511 in _wrapped_call_impl
  File "/home/jdh8/tt-metal/python_env/lib/python3.8/site-packages/torch/_dynamo/eval_frame.py", line 489 in _fn
  File "/home/jdh8/pytorch2.0_ttnn/tests/lowering/misc/test_nll_loss.py", line 39 in test_nll_loss
  File "/home/jdh8/tt-metal/python_env/lib/python3.8/site-packages/_pytest/python.py", line 195 in pytest_pyfunc_call
  File "/home/jdh8/tt-metal/python_env/lib/python3.8/site-packages/pluggy/_callers.py", line 103 in _multicall
  File "/home/jdh8/tt-metal/python_env/lib/python3.8/site-packages/pluggy/_manager.py", line 120 in _hookexec
  File "/home/jdh8/tt-metal/python_env/lib/python3.8/site-packages/pluggy/_hooks.py", line 513 in __call__
  File "/home/jdh8/tt-metal/python_env/lib/python3.8/site-packages/_pytest/python.py", line 1789 in runtest
  File "/home/jdh8/tt-metal/python_env/lib/python3.8/site-packages/_pytest/runner.py", line 167 in pytest_runtest_call
  File "/home/jdh8/tt-metal/python_env/lib/python3.8/site-packages/pluggy/_callers.py", line 103 in _multicall
  File "/home/jdh8/tt-metal/python_env/lib/python3.8/site-packages/pluggy/_manager.py", line 120 in _hookexec
  File "/home/jdh8/tt-metal/python_env/lib/python3.8/site-packages/pluggy/_hooks.py", line 513 in __call__
  File "/home/jdh8/tt-metal/python_env/lib/python3.8/site-packages/_pytest/runner.py", line 260 in <lambda>
  File "/home/jdh8/tt-metal/python_env/lib/python3.8/site-packages/_pytest/runner.py", line 339 in from_call
  File "/home/jdh8/tt-metal/python_env/lib/python3.8/site-packages/_pytest/runner.py", line 259 in call_runtest_hook
  File "/home/jdh8/tt-metal/python_env/lib/python3.8/site-packages/_pytest/runner.py", line 220 in call_and_report
  File "/home/jdh8/tt-metal/python_env/lib/python3.8/site-packages/_pytest/runner.py", line 131 in runtestprotocol
  File "/home/jdh8/tt-metal/python_env/lib/python3.8/site-packages/_pytest/runner.py", line 112 in pytest_runtest_protocol
  File "/home/jdh8/tt-metal/python_env/lib/python3.8/site-packages/pluggy/_callers.py", line 103 in _multicall
  File "/home/jdh8/tt-metal/python_env/lib/python3.8/site-packages/pluggy/_manager.py", line 120 in _hookexec
  File "/home/jdh8/tt-metal/python_env/lib/python3.8/site-packages/pluggy/_hooks.py", line 513 in __call__
  File "/home/jdh8/tt-metal/python_env/lib/python3.8/site-packages/_pytest/main.py", line 349 in pytest_runtestloop
  File "/home/jdh8/tt-metal/python_env/lib/python3.8/site-packages/pluggy/_callers.py", line 103 in _multicall
  File "/home/jdh8/tt-metal/python_env/lib/python3.8/site-packages/pluggy/_manager.py", line 120 in _hookexec
  File "/home/jdh8/tt-metal/python_env/lib/python3.8/site-packages/pluggy/_hooks.py", line 513 in __call__
  File "/home/jdh8/tt-metal/python_env/lib/python3.8/site-packages/_pytest/main.py", line 324 in _main
  File "/home/jdh8/tt-metal/python_env/lib/python3.8/site-packages/_pytest/main.py", line 270 in wrap_session
  File "/home/jdh8/tt-metal/python_env/lib/python3.8/site-packages/_pytest/main.py", line 317 in pytest_cmdline_main
  File "/home/jdh8/tt-metal/python_env/lib/python3.8/site-packages/pluggy/_callers.py", line 103 in _multicall
  File "/home/jdh8/tt-metal/python_env/lib/python3.8/site-packages/pluggy/_manager.py", line 120 in _hookexec
  File "/home/jdh8/tt-metal/python_env/lib/python3.8/site-packages/pluggy/_hooks.py", line 513 in __call__
  File "/home/jdh8/tt-metal/python_env/lib/python3.8/site-packages/_pytest/config/__init__.py", line 167 in main
  File "/home/jdh8/tt-metal/python_env/lib/python3.8/site-packages/_pytest/config/__init__.py", line 190 in console_main
  File "/home/jdh8/tt-metal/python_env/bin/pytest", line 8 in <module>
Segmentation fault (core dumped)

"weight_tensor": weight,
"ignore_index": ignore_index,
}
return self.call_function_prop_meta(ttnn.moreh_nll_loss, args, kwargs)

return self.call_function_prop_meta(target, args, kwargs)


Expand Down
Loading