Broadcast aten.maximum.default and aten.minimum.default inputs #586
base: main
Conversation
Force-pushed from dec0d74 to a6822d3.
Title changed from "aten.maximum.default inputs" to "aten.maximum.default and aten.minimum.default inputs".
```python
if len(args) > 1:
    other_tensor = args[1]
else:
    other_tensor = kwargs["other"]
```
Suggested change:

```python
other_tensor = None  # Explicitly initialize to a default value.
if len(args) > 1:
    other_tensor = args[1]
else:
    other_tensor = kwargs["other"]
```
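The pattern under discussion, reading a positional-or-keyword operand of a binary aten op, can be isolated as a small helper. This is a minimal sketch with a hypothetical name (`get_other_tensor`), not code from the PR:

```python
def get_other_tensor(args, kwargs):
    """Return the second operand of a binary aten op such as
    aten.maximum.default.

    Mirrors the suggested diff: initialize to an explicit default,
    then prefer the positional argument, falling back to the
    "other" keyword. Raises KeyError if neither form is present.
    """
    other_tensor = None  # explicit default, as the suggestion recommends
    if len(args) > 1:
        other_tensor = args[1]
    else:
        other_tensor = kwargs["other"]
    return other_tensor


# Both aten call conventions resolve to the same operand:
assert get_other_tensor(("self", "other"), {}) == "other"
assert get_other_tensor(("self",), {"other": "other"}) == "other"
```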
```python
if new_shape is not None or new_dtype is not None:
    shape = new_shape if new_shape is not None else new_node.meta["val"].size()
    dtype = new_dtype if new_dtype is not None else new_node.meta["val"].dtype
    fake_mode = FakeTensorMode()
    fake_tensor = fake_mode.from_tensor(torch.zeros(shape, dtype=dtype))
    new_node.meta["val"] = fake_tensor
```
can you clarify the need for this?
call_function creates a new_node and assigns it the meta of the current_node being traversed, but new_node's shape & dtype may not match current_node's (for example, when new_node.target is aten.expand applied to current_node, the shape changes), so this gives the caller the option to specify the correct shape & dtype.
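The meta-propagation logic described above can be sketched without torch. The `Node` class and `call_function` signature below are simplified stand-ins for the real FX machinery, shown only to illustrate why the caller-supplied overrides are needed:

```python
class Node:
    """Minimal stand-in for an FX node carrying shape/dtype metadata."""
    def __init__(self, target, shape, dtype):
        self.target = target
        self.meta = {"val": {"shape": shape, "dtype": dtype}}


def call_function(current_node, target, new_shape=None, new_dtype=None):
    """Create a new node, inheriting meta from current_node unless the
    caller overrides shape/dtype -- needed when the new op (e.g.
    aten.expand) produces a different shape than its input."""
    meta = current_node.meta["val"]
    shape = new_shape if new_shape is not None else meta["shape"]
    dtype = new_dtype if new_dtype is not None else meta["dtype"]
    return Node(target, shape, dtype)


src = Node("aten.maximum.default", (1,), "float32")
# aten.expand changes the output shape, so the caller passes the new one:
expanded = call_function(src, "aten.expand.default", new_shape=(4, 4))
assert expanded.meta["val"]["shape"] == (4, 4)
assert expanded.meta["val"]["dtype"] == "float32"
```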
```python
if input_tensor_shape == torch.Size([]):
    input_tensor_shape = torch.Size([1])
```
can you clarify the need for this?
The code below cannot handle torch.Size([]), and expanding a shape of [] gives the same result as expanding [1], so I treat [] as [1].
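Why the []-to-[1] normalization is safe can be checked with plain shape arithmetic. This is a pure-Python sketch of expand semantics (the `expand_shape` helper is hypothetical, not part of the PR), showing that a 0-d shape and shape [1] expand identically:

```python
def expand_shape(src, target):
    """Compute the shape produced by expanding `src` to `target`,
    following PyTorch expand semantics: -1 keeps the source dim,
    and only size-1 source dims may be broadcast. The scalar shape
    () is treated as (1,), matching the workaround in this PR."""
    if src == ():          # torch.Size([]) -> torch.Size([1])
        src = (1,)
    # Right-align src against target by prepending size-1 dims.
    padded = (1,) * (len(target) - len(src)) + tuple(src)
    out = []
    for s, t in zip(padded, target):
        if t == -1:
            out.append(s)
        elif s == 1 or s == t:
            out.append(t)
        else:
            raise ValueError(f"cannot expand dim {s} to {t}")
    return tuple(out)


# A 0-d scalar and a shape-[1] tensor expand to the same result:
assert expand_shape((), (2, 3)) == expand_shape((1,), (2, 3)) == (2, 3)
```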
ayerofieiev-tt left a comment:
I hesitate to move forward with this.
For this one, I feel it is better to wait for the proper broadcasting fix in TT-NN.
The change looks fairly intrusive to me.
OK, then I'll cancel this PR and just wait for tt-metal support: tenstorrent/tt-metal#12852
Pull request was converted to draft
Ticket
#592
Problem description
aten.maximum has a broadcasting issue (tenstorrent/tt-metal#12852). I work around it by using aten.expand to broadcast its inputs beforehand; the aten.expand may be lowered to ttnn later.
What's changed
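The workaround pre-expands both inputs of aten.maximum/aten.minimum to a common shape. The target of that expansion follows the standard PyTorch/NumPy broadcasting rules, which can be sketched in pure Python (the `broadcast_shape` helper is illustrative, not code from this PR):

```python
def broadcast_shape(a, b):
    """Compute the broadcast shape of two tensor shapes under the
    standard PyTorch/NumPy rules; this is the shape both inputs of
    aten.maximum/aten.minimum would be expanded to by the workaround."""
    a, b = tuple(a), tuple(b)
    # Pad the shorter shape with leading 1s so the ranks match.
    if len(a) < len(b):
        a = (1,) * (len(b) - len(a)) + a
    else:
        b = (1,) * (len(a) - len(b)) + b
    out = []
    for x, y in zip(a, b):
        if x == y or x == 1 or y == 1:
            out.append(max(x, y))
        else:
            raise ValueError(f"shapes not broadcastable: {x} vs {y}")
    return tuple(out)


# Both inputs would be aten.expand-ed to (2, 3, 4) before aten.maximum:
assert broadcast_shape((2, 1, 4), (3, 4)) == (2, 3, 4)
```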