See jax-ml/jax#7654
We should deduplicate identical reducer computations when converting from MHLO to HLO. For example, compare:
In [1]: import jax
In [2]: import jax.numpy as jnp
In [3]: def f(x, y): return jnp.sum(x) + jnp.sum(y)
In [4]: print(jax.jit(f).lower(jnp.arange(10), jnp.arange(15)).compiler_ir())
module @jit_f.2 {
  func.func public @main(%arg0: tensor<10xi32>, %arg1: tensor<15xi32>) -> tensor<i32> {
    %0 = mhlo.constant dense<0> : tensor<i32>
    %1 = mhlo.reduce(%arg0 init: %0) across dimensions = [0] : (tensor<10xi32>, tensor<i32>) -> tensor<i32>
     reducer(%arg2: tensor<i32>, %arg3: tensor<i32>) {
      %5 = mhlo.add %arg2, %arg3 : tensor<i32>
      "mhlo.return"(%5) : (tensor<i32>) -> ()
    }
    %2 = mhlo.constant dense<0> : tensor<i32>
    %3 = mhlo.reduce(%arg1 init: %2) across dimensions = [0] : (tensor<15xi32>, tensor<i32>) -> tensor<i32>
     reducer(%arg2: tensor<i32>, %arg3: tensor<i32>) {
      %5 = mhlo.add %arg2, %arg3 : tensor<i32>
      "mhlo.return"(%5) : (tensor<i32>) -> ()
    }
    %4 = mhlo.add %1, %3 : tensor<i32>
    return %4 : tensor<i32>
  }
}
and
In [6]: print(jax.jit(f).lower(jnp.arange(10), jnp.arange(15)).compiler_ir(dialect="hlo").as_hlo_text())
HloModule jit_f.4, entry_computation_layout={(s32[10]{0},s32[15]{0})->s32[]}

region_0.4 {
  Arg_0.5 = s32[] parameter(0)
  Arg_1.6 = s32[] parameter(1)
  ROOT add.7 = s32[] add(Arg_0.5, Arg_1.6)
}

region_1.9 {
  Arg_0.10 = s32[] parameter(0)
  Arg_1.11 = s32[] parameter(1)
  ROOT add.12 = s32[] add(Arg_0.10, Arg_1.11)
}

ENTRY main.15 {
  Arg_0.1 = s32[10]{0} parameter(0)
  constant.3 = s32[] constant(0)
  reduce.8 = s32[] reduce(Arg_0.1, constant.3), dimensions={0}, to_apply=region_0.4
  Arg_1.2 = s32[15]{0} parameter(1)
  reduce.13 = s32[] reduce(Arg_1.2, constant.3), dimensions={0}, to_apply=region_1.9
  ROOT add.14 = s32[] add(reduce.8, reduce.13)
}
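As a rough, text-level sanity check (not a real HLO pass), the sketch below lowers the same function as above, strips the ".<id>" suffixes XLA appends to names, and counts structurally identical computations. The helper normalized_computations is hypothetical, introduced only for this illustration; the lowering calls are the ones used above.

import re
from collections import Counter

import jax
import jax.numpy as jnp

def f(x, y):
    return jnp.sum(x) + jnp.sum(y)

lowered = jax.jit(f).lower(jnp.arange(10), jnp.arange(15))
hlo_text = lowered.compiler_ir(dialect="hlo").as_hlo_text()

def normalized_computations(hlo):
    # Collect each computation's body with value/computation ids removed,
    # so that structurally identical reducers compare equal. This is a
    # crude text comparison, not structural equality on real HLO.
    comps, body = [], []
    for line in hlo.splitlines():
        stripped = line.strip()
        if stripped.endswith("{") and "=" not in stripped:
            body = []  # start of a computation, e.g. "region_0.4 {"
        elif stripped == "}":
            comps.append("\n".join(body))  # end of a computation
        else:
            body.append(re.sub(r"\.\d+", "", stripped))
    return comps

for body, count in Counter(normalized_computations(hlo_text)).items():
    if count > 1 and body:
        print(f"{count} copies of:\n{body}\n")

On the example above this reports two copies of the add reducer; on larger programs the same count can grow into the hundreds mentioned below.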
It would be great to merge region_0.4 and region_1.9 for readability of the HLO; some computations end up with hundreds of such reducers.
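For illustration only (hand-edited from the dump above, not actual compiler output), a deduplicated module would keep a single reducer and point both reduce ops at it, roughly:

HloModule jit_f.4, entry_computation_layout={(s32[10]{0},s32[15]{0})->s32[]}

region_0.4 {
  Arg_0.5 = s32[] parameter(0)
  Arg_1.6 = s32[] parameter(1)
  ROOT add.7 = s32[] add(Arg_0.5, Arg_1.6)
}

ENTRY main.15 {
  Arg_0.1 = s32[10]{0} parameter(0)
  constant.3 = s32[] constant(0)
  reduce.8 = s32[] reduce(Arg_0.1, constant.3), dimensions={0}, to_apply=region_0.4
  Arg_1.2 = s32[15]{0} parameter(1)
  reduce.13 = s32[] reduce(Arg_1.2, constant.3), dimensions={0}, to_apply=region_0.4
  ROOT add.14 = s32[] add(reduce.8, reduce.13)
}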