-
Notifications
You must be signed in to change notification settings - Fork 70
Fix NestedTensor amin/amax/argmin operations for integer dtypes #2764
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull request overview
This PR fixes integer type support in amin/amax/argmin operations by addressing invalid type conversions between double-float padding values and integer tensor values. The fix clamps padding values to the valid min/max range of the target dtype.
Changes:
- Replaced direct static_cast of padding_value with a call to
_get_padding_valuehelper function that properly handles type conversion for both floating-point and integer types
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
src/ATen/native/nested/xpu/sycl/NestedTensorTransformerFunctionKernels.cpp
Show resolved
Hide resolved
CuiYifeng
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM.
|
Please update PR title since the fixing is for NestedTensor, thanks |
guangyey
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice!
|
New Passing Known Issues in #2703: |
72804b1 to
d0c69e4
Compare
Performance outliers, please check!
|
Fix issues: #2703
The issue is caused by an invalid type conversion when the input tensor values are integers while the padding value is of double-float type, resulting in incorrect casting behavior. This is fixed by clamping the padding value to the valid min/max range of the target input dtype.