Bug Report:
This bug is triggered by the HLO-to-StableHLO conversion. The HLO module runs successfully via run_hlo_module --input_format=hlo, and the conversion via hlo-translate --hlo-to-mlir also succeeds. However, running the translated StableHLO via run_hlo_module --input_format=stablehlo fails with:
INTERNAL: during context [Unknown]: The async-done expects the shape of output to match the async shape at index {1} ((f32[]) vs f32[]).
Environment
- XLA commit: 5ce7908a2d32a9f91fd99380435cda1b645c8cc7
- CPU: Intel(R) Core(TM) i9-14900HX
- GPU: NVIDIA GeForce RTX 4060 Laptop GPU
- CUDA Driver: 580.126.09
run_hlo_module (HLO) — Success
HLO:
HloModule AsyncCustomCallInstructionsWithoutSideEffect, entry_computation_layout={()->()}
ENTRY AsyncCustomCallInstructionsWithoutSideEffect {
  ROOT tuple = () tuple()
  custom-call.cloned.1.call-start = ((), f32[], u32[]) custom-call-start(), custom_call_target="foo"
  custom-call.cloned.1.call-done = f32[] custom-call-done(custom-call.cloned.1.call-start)
}
Execution Command:
run_hlo_module \
--platform=CPU \
--reference_platform= \
--input_format=hlo \
AsyncCustomCallInstructionsWithoutSideEffect_b4a0c8f4_12.hlo
Output:
** Running AsyncCustomCallInstructionsWithoutSideEffect_b4a0c8f4_12.hlo**
Running HLO module with runner Host...
... compiled and ran in 0.00704402s.
Skipping reference runner
run_hlo_module (StableHLO) — FAIL
Translation Command:
hlo-translate \
--hlo-to-mlir \
AsyncCustomCallInstructionsWithoutSideEffect_b4a0c8f4_12.hlo \
-o \
AsyncCustomCallInstructionsWithoutSideEffect_b4a0c8f4_12.mlir
IR After Translation:
module @AsyncCustomCallInstructionsWithoutSideEffect attributes {mhlo.cross_program_prefetches = [], mhlo.input_output_alias = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} {
  func.func private @async_wrapped() -> tensor<f32> attributes {execution_thread = "main"} {
    %0 = stablehlo.custom_call @foo() {backend_config = ""} : () -> tensor<f32>
    return %0 : tensor<f32>
  }
  func.func @main() {
    %0 = "mhlo.async_start"() <{called_computation = @async_wrapped, execution_thread = "main"}> {xla_shape = "((), f32[], u32[])"} : () -> !mhlo.async_bundle<tuple<>, tensor<f32>, tensor<ui32>>
    %1 = "mhlo.async_done"(%0) {called_computation = @async_wrapped, execution_thread = "main"} : (!mhlo.async_bundle<tuple<>, tensor<f32>, tensor<ui32>>) -> tensor<f32>
    return
  }
}
Execution Command:
run_hlo_module \
--platform=CPU \
--reference_platform= \
--input_format=stablehlo \
AsyncCustomCallInstructionsWithoutSideEffect_b4a0c8f4_12.mlir
Output:
** Running AsyncCustomCallInstructionsWithoutSideEffect_b4a0c8f4_12.mlir**
INTERNAL: during context [Unknown]: The async-done expects the shape of output to match the async shape at index {1} ((f32[]) vs f32[]).
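Note on the error message: in HLO shape notation, (f32[]) is a one-element tuple wrapping a scalar, which is a different shape from the bare scalar f32[]. The sketch below is a toy Python model of that comparison, not XLA's verifier code. The names Array, Tup, and check_async_done are hypothetical, and the ordering of the two shapes in the message (output first, async shape at index {1} second) is an assumption read off the message wording.

```python
from dataclasses import dataclass
from typing import Tuple, Union


@dataclass(frozen=True)
class Array:
    """Toy array shape, e.g. f32[] for a scalar (hypothetical model)."""
    dtype: str
    dims: Tuple[int, ...] = ()

    def __str__(self) -> str:
        return f"{self.dtype}[{','.join(map(str, self.dims))}]"


@dataclass(frozen=True)
class Tup:
    """Toy tuple shape, e.g. (f32[]) for a one-element tuple."""
    elements: Tuple["Shape", ...] = ()

    def __str__(self) -> str:
        return "(" + ", ".join(str(e) for e in self.elements) + ")"


Shape = Union[Array, Tup]


def check_async_done(output_shape: Shape, async_shape: Tup) -> None:
    """Mimic the reported check: output must equal async_shape at index 1."""
    expected = async_shape.elements[1]
    if output_shape != expected:
        # Assumed ordering: output shape first, then the async shape element.
        raise ValueError(
            "The async-done expects the shape of output to match the async "
            f"shape at index {{1}} ({output_shape} vs {expected})."
        )


f32 = Array("f32")
u32 = Array("u32")
bundle = Tup((Tup(), f32, u32))  # ((), f32[], u32[]) as in the HLO above

check_async_done(f32, bundle)  # scalar vs scalar: passes silently

try:
    check_async_done(Tup((f32,)), bundle)  # one-element tuple vs scalar
except ValueError as err:
    message = str(err)
    print(message)
```

In this model the mismatch arises only when one side is wrapped in a one-element tuple, which suggests that somewhere in the StableHLO import path a result shape is being tupled. Where exactly that happens is not established by this report.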
Contact
- Email: ch395@njit.edu, zhihao.yao@njit.edu, benquike@gmail.com