[TorchToTosa] Refine argmax integer casting and update tests#4489
[TorchToTosa] Refine argmax integer casting and update tests#4489
Conversation
Ensure tosa.argmax receives f32 inputs by inserting tosa.cast in argmax lowering (and min/max-dim argmax paths), and fix the axis attribute type. Add a bf16 argmax conversion test in basic.mlir to validate the cast+extsi sequence. Signed-off-by: Cathal Corbett <cathal.corbett@arm.com> Change-Id: I5f8847cc7400152c20001905dfacc92af0c1583c
Signed-off-by: Cathal Corbett <cathal.corbett@arm.com> Change-Id: I5e240a4dbaaee76d11e5895c82d93b718aabeff2
The cast is there to satisfy the TOSA type constraints for tosa.argmax. In the base TOSA floating‑point profile, argmax only accepts fp16/fp32 inputs, so int32/int64 isn’t legal so we normalize to fp32 to keep the op spec‑compliant across targets. EDIT: I have removed reference to bf16 casting now in the patch as it should be targeted through the enabling of the BF15 extension set in TOSA which should be in a separate MR i believe. |
Change-Id: I9e4e8c44f77c1802935236f7c01f24896c39d808 Signed-off-by: Cathal Corbett <cathal.corbett@arm.com>
Change-Id: I570e33de7f85806f8101f9a3c0fe193a87cfe776 Signed-off-by: Cathal Corbett <cathal.corbett@arm.com>
Change-Id: I5f7e32e146f4ab42d87445ec3684f52dcf071881 Signed-off-by: Cathal Corbett <cathal.corbett@arm.com>
sahas3
left a comment
There was a problem hiding this comment.
Thanks for the clarification.
I agree with enhancing the bf16 support as part of a separate PR. For hardware targets that can support bf16, there should be no need to cast it to f32.
Change-Id: I742f82270727bf674726e3df8627aa96e88329b4 Signed-off-by: Cathal Corbett <cathal.corbett@arm.com>
| return tosa::CastOp::create(rewriter, op->getLoc(), castedSrcType, src); | ||
| } | ||
|
|
||
| Value ensureF32Input(PatternRewriter &rewriter, Operation *op, Value input) { |
There was a problem hiding this comment.
Could we rename ensureF32Input?
The current name implies "always convert to f32", but the function actually leaves floats and i8 unchanged and only casts some integer types. Something like legalizeArgMaxInputType or normalizeArgMaxInputForTosa may be clearer.
Change-Id: I0d355828b18cc1fc63d93b9bed8d2ebf4b03c305 Signed-off-by: Cathal Corbett <cathal.corbett@arm.com>
Change-Id: I5f8847cc7400152c20001905dfacc92af0c1583c