Description
This is not specific to JAX, but I am using JAX to illustrate what I mean:
Since JAX 0.9.0 (i.e. bb760b0), and I think specifically since #34060, XLA seems to use a more compact format for stack frame metadata in its HLO text output.
For example, with the following script:

```python
import jax
from jax import numpy as jnp

jitted_fn = jax.jit(lambda x: (jnp.sin(x) / x).sum())

input = jnp.zeros(10)

comp = jitted_fn.lower(input).compile()
print(comp.as_text())
```

I get the following HLO output:

```
HloModule jit__lambda, is_scheduled=true, entry_computation_layout={(f32[10]{0})->f32[]}, allow_spmd_sharding_propagation_to_parameters={true}, allow_spmd_sharding_propagation_to_output={true}

FileNames
1 "/.../tmptest.py"
FunctionNames
1 "<module>"
2 "<lambda>"
FileLocations
1 {file_name_id=1 function_name_id=1 line=8 end_line=8 column=7 end_column=29}
2 {file_name_id=1 function_name_id=2 line=4 end_line=4 column=30 end_column=52}
3 {file_name_id=1 function_name_id=2 line=4 end_line=4 column=31 end_column=41}
StackFrames
1 {file_location_id=1 parent_frame_id=1}
2 {file_location_id=2 parent_frame_id=2}
3 {file_location_id=3 parent_frame_id=2}

%region_0.1 (reduce_sum.3: f32[], reduce_sum.4: f32[]) -> f32[] {
  %reduce_sum.3 = f32[] parameter(0), metadata={op_name="reduce_sum"}
  %reduce_sum.4 = f32[] parameter(1), metadata={op_name="reduce_sum"}
  ROOT %reduce_sum.5 = f32[] add(%reduce_sum.3, %reduce_sum.4), metadata={op_name="jit(<lambda>)/reduce_sum" stack_frame_id=2}
}

%fused_computation (param_0.3: f32[10]) -> f32[] {
  %param_0.3 = f32[10]{0} parameter(0)
  %sin.0 = f32[10]{0} sine(%param_0.3), metadata={op_name="jit(<lambda>)/sin" stack_frame_id=3}
  %div.0 = f32[10]{0} divide(%sin.0, %param_0.3), metadata={op_name="jit(<lambda>)/div" stack_frame_id=3}
  %constant.0 = f32[] constant(0)
  ROOT %reduce_sum.0 = f32[] reduce(%div.0, %constant.0), dimensions={0}, to_apply=%region_0.1, metadata={op_name="jit(<lambda>)/reduce_sum" stack_frame_id=2}
}

ENTRY %main.2 (x.1: f32[10]) -> f32[] {
  %x.1 = f32[10]{0} parameter(0), metadata={op_name="x"}
  ROOT %divide_reduce_fusion = f32[] fusion(%x.1), kind=kLoop, calls=%fused_computation, metadata={op_name="jit(<lambda>)/reduce_sum" stack_frame_id=2}
}
```

So the individual HLO operations only carry a `stack_frame_id`, and the tables at the beginning of the module say what that id points to. For example, `stack_frame_id=3` on `%sin.0` refers to file location 3, i.e. `<lambda>` at line 4, columns 31 to 41 (`jnp.sin(x)`), whose parent is frame 2 (columns 30 to 52 of the same line, the whole `(jnp.sin(x) / x).sum()` expression).
This is of course far more compact and less redundant than the previous output, which repeated `source_file` and `source_line` in the metadata of every operation, but it also makes quick debugging (especially of Python/JAX code) harder.
In this trivial example that is easy, but in more complex code, finding out where a given XLA operation comes from means either writing code that parses both the HLO and the stack tables and joins them, or scrolling around the output and mentally resolving the stack indices.
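For reference, the manual join mentioned above can be sketched in a few lines of Python. This is a minimal, hypothetical sketch, not an XLA or JAX API: `parse_stack_tables`, `describe_frame` and `annotate` are made-up helper names, the parser assumes the header sections look exactly as in the output above (a section name on its own line, then one `id payload` entry per line), and since I am not sure which convention marks the root frame, it treats a `parent_frame_id` of 0, an unknown id, or an already visited id as the end of the chain.

```python
import re

# Hypothetical helpers; they assume the header layout exactly as printed above.
SECTIONS = ("FileNames", "FunctionNames", "FileLocations", "StackFrames")
ENTRY_RE = re.compile(r"^\s*(\d+)\s+(.*)$")          # e.g. '3 {file_location_id=3 ...}'
KEY_VALUE_RE = re.compile(r"(\w+)=(-?\d+)")          # integer fields inside {...}
FRAME_REF_RE = re.compile(r"\bstack_frame_id=(\d+)")  # reference in op metadata


def parse_stack_tables(hlo_text):
    """Collect the id -> entry tables from the header sections of the HLO text."""
    tables = {name: {} for name in SECTIONS}
    section = None
    for line in hlo_text.splitlines():
        if line.strip() in SECTIONS:
            section = line.strip()
            continue
        match = ENTRY_RE.match(line) if section else None
        if match is None:
            section = None  # any non-entry line ends the current section
            continue
        entry_id, body = int(match.group(1)), match.group(2).strip()
        if section in ("FileNames", "FunctionNames"):
            tables[section][entry_id] = body.strip('"')
        else:
            tables[section][entry_id] = {k: int(v) for k, v in KEY_VALUE_RE.findall(body)}
    return tables


def describe_frame(tables, frame_id):
    """Render a frame and its parents as 'file:line:column (function)' joined by ' <- '."""
    parts, seen = [], set()
    # Assumption: 0, an unknown id, or an already visited id ends the chain.
    while frame_id and frame_id not in seen:
        seen.add(frame_id)
        frame = tables["StackFrames"].get(frame_id)
        if frame is None:
            break
        loc = tables["FileLocations"].get(frame.get("file_location_id"), {})
        file_name = tables["FileNames"].get(loc.get("file_name_id"), "?")
        function = tables["FunctionNames"].get(loc.get("function_name_id"), "?")
        parts.append(f"{file_name}:{loc.get('line', '?')}:{loc.get('column', '?')} ({function})")
        frame_id = frame.get("parent_frame_id", 0)
    return " <- ".join(parts)


def annotate(hlo_text):
    """Append the resolved source location as a trailing '//' note to each referencing line."""
    tables = parse_stack_tables(hlo_text)
    out = []
    for line in hlo_text.splitlines():
        ref = FRAME_REF_RE.search(line)
        if ref:
            line = f"{line}  // {describe_frame(tables, int(ref.group(1)))}"
        out.append(line)
    return "\n".join(out)


# Excerpt of the output above, used as a self-contained example.
SAMPLE = """\
FileNames
1 "/.../tmptest.py"
FunctionNames
1 "<module>"
2 "<lambda>"
FileLocations
1 {file_name_id=1 function_name_id=1 line=8 end_line=8 column=7 end_column=29}
2 {file_name_id=1 function_name_id=2 line=4 end_line=4 column=30 end_column=52}
3 {file_name_id=1 function_name_id=2 line=4 end_line=4 column=31 end_column=41}
StackFrames
1 {file_location_id=1 parent_frame_id=1}
2 {file_location_id=2 parent_frame_id=2}
3 {file_location_id=3 parent_frame_id=2}
  %sin.0 = f32[10]{0} sine(%param_0.3), metadata={op_name="jit(<lambda>)/sin" stack_frame_id=3}
"""

if __name__ == "__main__":
    # With the JAX script above this would be: print(annotate(comp.as_text()))
    print(annotate(SAMPLE))
```

On the excerpt, the `%sin.0` line gets `// /.../tmptest.py:4:31 (<lambda>) <- /.../tmptest.py:4:30 (<lambda>)` appended. It works, but it is exactly the kind of post-processing I would rather not have to maintain next to every HLO dump.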
What is the best solution for this?
Would it be possible to add an option (I assume as an XLA flag) that, for debugging, switches back to the more verbose output and prints the file name and line number directly in the metadata of every operation?