Add flag for showing inline stack frame in HLO output #36953

@Findus23

This is not specific to JAX, but I am using JAX to illustrate what I mean:

Since JAX 0.9.0 (so bb760b0), and I think specifically since #34060, XLA seems to use a more compact format for stack metadata in its HLO output.

For example, running

import jax
from jax import numpy as jnp

jitted_fn = jax.jit(lambda x: (jnp.sin(x) / x).sum())

input = jnp.zeros(10)

comp = jitted_fn.lower(input).compile()
print(comp.as_text())

I get the following HLO output:

HloModule jit__lambda, is_scheduled=true, entry_computation_layout={(f32[10]{0})->f32[]}, allow_spmd_sharding_propagation_to_parameters={true}, allow_spmd_sharding_propagation_to_output={true}

FileNames
1 "/.../tmptest.py"

FunctionNames
1 "<module>"
2 "<lambda>"

FileLocations
1 {file_name_id=1 function_name_id=1 line=8 end_line=8 column=7 end_column=29}
2 {file_name_id=1 function_name_id=2 line=4 end_line=4 column=30 end_column=52}
3 {file_name_id=1 function_name_id=2 line=4 end_line=4 column=31 end_column=41}

StackFrames
1 {file_location_id=1 parent_frame_id=1}
2 {file_location_id=2 parent_frame_id=2}
3 {file_location_id=3 parent_frame_id=2}


%region_0.1 (reduce_sum.3: f32[], reduce_sum.4: f32[]) -> f32[] {
  %reduce_sum.3 = f32[] parameter(0), metadata={op_name="reduce_sum"}
  %reduce_sum.4 = f32[] parameter(1), metadata={op_name="reduce_sum"}
  ROOT %reduce_sum.5 = f32[] add(%reduce_sum.3, %reduce_sum.4), metadata={op_name="jit(<lambda>)/reduce_sum" stack_frame_id=2}
}

%fused_computation (param_0.3: f32[10]) -> f32[] {
  %param_0.3 = f32[10]{0} parameter(0)
  %sin.0 = f32[10]{0} sine(%param_0.3), metadata={op_name="jit(<lambda>)/sin" stack_frame_id=3}
  %div.0 = f32[10]{0} divide(%sin.0, %param_0.3), metadata={op_name="jit(<lambda>)/div" stack_frame_id=3}
  %constant.0 = f32[] constant(0)
  ROOT %reduce_sum.0 = f32[] reduce(%div.0, %constant.0), dimensions={0}, to_apply=%region_0.1, metadata={op_name="jit(<lambda>)/reduce_sum" stack_frame_id=2}
}

ENTRY %main.2 (x.1: f32[10]) -> f32[] {
  %x.1 = f32[10]{0} parameter(0), metadata={op_name="x"}
  ROOT %divide_reduce_fusion = f32[] fusion(%x.1), kind=kLoop, calls=%fused_computation, metadata={op_name="jit(<lambda>)/reduce_sum" stack_frame_id=2}
}

So the individual HLO operations now carry only a stack_frame_id, and the tables at the top of the output say what each ID points to.
This is of course far more compact and less redundant than the previous output, but it also makes quick debugging (especially of Python/JAX code) less straightforward.
In this trivial example the lookup is simple, but in more complex code, quickly seeing where an XLA operation comes from means either writing code that parses both the HLO and the stack tables and joins them, or scrolling around the output to resolve the stack indices by hand.
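The parse-and-join workaround can be sketched in a few lines of Python. This is only a rough sketch, assuming the table layout is exactly what the output above shows. The function names `parse_tables` and `inline_stack_frames`, and the `source_file`/`source_line`/`function` keys it writes, are made up for illustration and are not XLA names. It resolves only the frame an operation points at and ignores `parent_frame_id`, because the numbering convention of that field is not clear from this output.

```python
import re

# Section headers as they appear at the top of the HLO text above.
SECTIONS = ("FileNames", "FunctionNames", "FileLocations", "StackFrames")


def parse_tables(hlo_text):
    """Collect the index tables printed before the HLO computations."""
    tables = {name: {} for name in SECTIONS}
    current = None
    for line in hlo_text.splitlines():
        stripped = line.strip()
        if stripped in SECTIONS:
            current = stripped
            continue
        if current is None or not stripped:
            continue
        match = re.match(r"(\d+)\s+(.*)$", stripped)
        if match is None:
            # First line that is not "<id> <entry>" ends the current table.
            current = None
            continue
        idx, body = int(match.group(1)), match.group(2)
        if body.startswith('"'):
            # FileNames / FunctionNames entries are quoted strings.
            tables[current][idx] = body.strip('"')
        else:
            # FileLocations / StackFrames entries are {key=int ...} records.
            tables[current][idx] = {
                key: int(value) for key, value in re.findall(r"(\w+)=(\d+)", body)
            }
    return tables


def inline_stack_frames(hlo_text):
    """Replace each stack_frame_id=N with the file, line and function it names."""
    tables = parse_tables(hlo_text)

    def describe(match):
        frame = tables["StackFrames"].get(int(match.group(1)), {})
        loc = tables["FileLocations"].get(frame.get("file_location_id"))
        if loc is None:
            return match.group(0)  # unknown ID: leave the text untouched
        file_name = tables["FileNames"].get(loc.get("file_name_id"), "?")
        func = tables["FunctionNames"].get(loc.get("function_name_id"), "?")
        # Output keys are an arbitrary choice for this sketch.
        return (
            f'source_file="{file_name}" source_line={loc.get("line")} '
            f'function="{func}"'
        )

    return re.sub(r"stack_frame_id=(\d+)", describe, hlo_text)
```

With the example above, this would be called as `print(inline_stack_frames(comp.as_text()))`. It works, but it depends on the exact text layout, which is why a built-in option would be preferable.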

What is the best solution for this?

Would it be possible to add an option (I assume as an XLA flag) that, for debugging, uses the less compact output and prints the file name and line number directly into the metadata of every operation?
