-
Notifications
You must be signed in to change notification settings - Fork 7
Try lowering aten.nll_loss_forward
to ttnn.moreh_nll_loss
#676
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
jdh8
wants to merge
3
commits into
main
Choose a base branch
from
jdh8/nll_loss
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
3 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
import torch | ||
import torch_ttnn | ||
import pytest | ||
import ttnn | ||
|
||
from tests.utils import assert_with_pcc | ||
|
||
|
||
class NllLossModule(torch.nn.Module): | ||
def __init__(self): | ||
super().__init__() | ||
|
||
def forward(self, *args, **kwargs): | ||
return torch.nn.functional.nll_loss(*args, **kwargs) | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"input_shape, has_weight, reduction, ignore_index", | ||
( | ||
((19, 256008), True, "mean", -100), | ||
((19, 256008), False, "mean", -100), | ||
), | ||
) | ||
def test_nll_loss(device, input_shape, has_weight, reduction, ignore_index): | ||
module = NllLossModule() | ||
batch, channels = input_shape | ||
input = torch.rand((batch, channels), dtype=torch.bfloat16) | ||
input = torch.nn.functional.log_softmax(input, dim=1) | ||
|
||
target = torch.randint(0, channels, (batch,), dtype=torch.long) | ||
weight = torch.rand((channels,), dtype=torch.bfloat16) if has_weight else None | ||
result_before = module.forward(input, target, weight, reduction=reduction, ignore_index=ignore_index) | ||
|
||
option = torch_ttnn.TorchTtnnOption(device=device) | ||
|
||
# The compilation is lazy, so we need to run forward once to trigger the compilation | ||
module = torch.compile(module, backend=torch_ttnn.backend, options=option) | ||
|
||
result_after = module.forward(input, target, weight, reduction=reduction, ignore_index=ignore_index) | ||
print(option._out_fx_graphs[0]) | ||
|
||
# Check the graph has be rewritten and contain ttnn ops | ||
nodes = [node.target for node in option._out_fx_graphs[0].nodes] | ||
assert torch.ops.aten.nll_loss_forward.default not in nodes | ||
assert nodes.count(ttnn.operations.moreh.nll_loss) == 1 | ||
|
||
# Check inference result | ||
assert_with_pcc(result_before, result_after, pcc=0.99) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -355,6 +355,16 @@ def __init__(self, target, args, kwargs): | |
|
||
return self.call_function_prop_meta(ttnn.reshape, (tensor, size)) | ||
|
||
if target == torch.ops.aten.nll_loss_forward.default: | ||
input, target, weight, reduction, ignore_index = args | ||
args = input, target, ("none", "mean", "sum")[reduction] | ||
kwargs = { | ||
"divisor_tensor": torch.tensor([0], dtype=get_dtype(input)), | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
"weight_tensor": weight, | ||
"ignore_index": ignore_index, | ||
} | ||
return self.call_function_prop_meta(ttnn.moreh_nll_loss, args, kwargs) | ||
|
||
return self.call_function_prop_meta(target, args, kwargs) | ||
|
||
|
||
|
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks! Using
ttnn.moreh_*
wrappers work!