[TorchToTosa] Refine argmax integer casting and update tests#4489

Open
catcor01 wants to merge 7 commits into llvm:main from catcor01:argmax_cast

Conversation

@catcor01
Contributor

@catcor01 catcor01 commented Mar 6, 2026

  • Cast only i32/i64 inputs to f32 before tosa.argmax
  • Update argmax e2e test to exercise integer casting

Change-Id: I5f8847cc7400152c20001905dfacc92af0c1583c

@catcor01
Contributor Author

catcor01 commented Mar 9, 2026

@sahas3 @Lallapallooza

Member

@sahas3 sahas3 left a comment


Thanks for the change @catcor01. The code changes look fine to me but can you clarify why the cast to fp32 is required?

catcor01 added 2 commits March 9, 2026 14:29
Ensure tosa.argmax receives f32 inputs by inserting tosa.cast in argmax
lowering (and min/max-dim argmax paths), and fix the axis attribute type.
Add a bf16 argmax conversion test in basic.mlir to validate the cast+extsi
sequence.

Signed-off-by: Cathal Corbett <cathal.corbett@arm.com>
Change-Id: I5f8847cc7400152c20001905dfacc92af0c1583c
Signed-off-by: Cathal Corbett <cathal.corbett@arm.com>
Change-Id: I5e240a4dbaaee76d11e5895c82d93b718aabeff2
@catcor01
Contributor Author

catcor01 commented Mar 9, 2026

Thanks for the change @catcor01. The code changes look fine to me but can you clarify why the cast to fp32 is required?

The cast is there to satisfy the TOSA type constraints for tosa.argmax. In the base TOSA floating-point profile, argmax only accepts fp16/fp32 inputs, so int32/int64 inputs aren't legal; we normalize them to fp32 to keep the op spec-compliant across targets.
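As a rough sketch of the resulting IR (shapes, SSA names, and the axis value here are illustrative, not taken from the actual test), the commit message's cast+extsi sequence for an i64 input would look something like:

```mlir
// Hypothetical shapes and axis; the real e2e test may differ.
%cast = tosa.cast %input : (tensor<3x4xi64>) -> tensor<3x4xf32>
%idx  = tosa.argmax %cast {axis = 1 : i32} : (tensor<3x4xf32>) -> tensor<3xi32>
// tosa.argmax yields i32 indices; widen back to the torch-level i64 index type.
%res  = arith.extsi %idx : tensor<3xi32> to tensor<3xi64>
```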

EDIT: I have now removed the bf16 casting from the patch, as bf16 should be targeted by enabling the BF16 extension in TOSA, which I believe should be a separate PR.

Change-Id: I9e4e8c44f77c1802935236f7c01f24896c39d808
Signed-off-by: Cathal Corbett <cathal.corbett@arm.com>
Change-Id: I570e33de7f85806f8101f9a3c0fe193a87cfe776
Signed-off-by: Cathal Corbett <cathal.corbett@arm.com>
@catcor01 catcor01 changed the title [TorchToTosa] Cast argmax inputs to f32 and add bf16 test [TorchToTosa] Refine argmax integer casting and update tests Mar 10, 2026
Change-Id: I5f7e32e146f4ab42d87445ec3684f52dcf071881
Signed-off-by: Cathal Corbett <cathal.corbett@arm.com>
@catcor01 catcor01 requested a review from sahas3 March 10, 2026 12:05
Member

@sahas3 sahas3 left a comment


Thanks for the clarification.

I agree with enhancing the bf16 support as part of a separate PR. For hardware targets that can support bf16, there should be no need to cast it to f32.

Change-Id: I742f82270727bf674726e3df8627aa96e88329b4
Signed-off-by: Cathal Corbett <cathal.corbett@arm.com>
@catcor01 catcor01 requested a review from sahas3 March 12, 2026 07:57
Member

@Lallapallooza Lallapallooza left a comment


LGTM; one small nit.

return tosa::CastOp::create(rewriter, op->getLoc(), castedSrcType, src);
}

Value ensureF32Input(PatternRewriter &rewriter, Operation *op, Value input) {
Member


Could we rename ensureF32Input?

The current name implies "always convert to f32", but the function actually leaves floats and i8 unchanged and only casts some integer types. Something like legalizeArgMaxInputType or normalizeArgMaxInputForTosa may be clearer.

Change-Id: I0d355828b18cc1fc63d93b9bed8d2ebf4b03c305
Signed-off-by: Cathal Corbett <cathal.corbett@arm.com>
@catcor01 catcor01 requested a review from Lallapallooza March 12, 2026 13:59