Rich interactive visualization of JAX and XLA computational graph IRs.
- Dump XLA computational graph and IRs
- Profile with NVTX, Nsight Systems, and Nsight Compute
- Display the interactive XLA Graphviz graph (jax.XlaComputation.as_hlo_dot_graph)
# Install dependencies and start the web UI:
npm install
npm run dev  # Start the Parcel web server
# Run jaxviz on an example script:
python3 jaxviz/jaxviz.py --file jaxviz/examples/cartpole_mlp/1_dot/main.py
# Dump the HLO as HTML and text after every XLA pass into ./dump:
XLA_FLAGS='--xla_dump_to=dump --xla_dump_hlo_as_html --xla_dump_hlo_as_text --xla_dump_hlo_pass_re=.*' python3 jaxviz/examples/cartpole_mlp/99_full/main.py
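The same dump flags can also be set from inside Python instead of on the command line. The sketch below assumes the variable is set before `jax` is imported, so that XLA sees it when it first reads its flags. The helper name `build_xla_flags` is hypothetical and not part of jaxviz.

```python
import os


def build_xla_flags(dump_dir: str = "dump", pass_regex: str = ".*") -> str:
    """Compose the XLA_FLAGS value used in the shell command above.

    Hypothetical helper for illustration; not part of jaxviz.
    """
    flags = [
        f"--xla_dump_to={dump_dir}",             # directory that receives the dumps
        "--xla_dump_hlo_as_html",                # write HTML renderings of the HLO
        "--xla_dump_hlo_as_text",                # write plain-text HLO
        f"--xla_dump_hlo_pass_re={pass_regex}",  # dump after passes matching this regex
    ]
    return " ".join(flags)


# Set the variable first, then import jax (left commented out here).
os.environ["XLA_FLAGS"] = build_xla_flags()
# import jax
```

Building the string from a list keeps each flag on its own line, which makes it easier to comment one out while experimenting.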
# In container: profile with Nsight Systems (20-second capture)
nsys profile --force-overwrite=true --duration=20 --stats=true -o my_test python3 jaxviz/examples/cartpole_mlp/99_full/main.py
# On laptop: copy the Nsight Systems report (newer nsys versions write my_test.nsys-rep instead of .qdrep)
scp eco-13:/home/dans/jaxviz/my_test.qdrep .
# In container: profile invocation 6 of kernel fusion_3 with Nsight Compute
ncu -o profile -f --set full --kernel-id ::fusion_3:6 python3 jaxviz/examples/cartpole_mlp/99_full/main.py
# On laptop: copy the Nsight Compute report
scp eco-13:/home/dans/jaxviz/profile.ncu-rep .
# On laptop: forward ports 8101-8104 from eco-13 in the background
ssh -N -L 8101:localhost:8101 -L 8102:localhost:8102 -L 8103:localhost:8103 -L 8104:localhost:8104 eco-13 &
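If the set of forwarded ports changes often, the ssh command above can be generated rather than typed out. This is a minimal sketch using only the standard library. The function name `ssh_forward_command` is hypothetical, and it assumes each remote port maps to the same local port number, as in the command above.

```python
import shlex


def ssh_forward_command(host, ports):
    """Build an ssh argument list that forwards each port to the same port on localhost.

    Hypothetical helper for illustration; not part of jaxviz.
    """
    args = ["ssh", "-N"]  # -N: forward ports only, do not run a remote command
    for port in ports:
        args += ["-L", f"{port}:localhost:{port}"]
    args.append(host)
    return args


cmd = ssh_forward_command("eco-13", range(8101, 8105))
print(shlex.join(cmd))  # requires Python 3.8+
```

The argument list can be passed directly to `subprocess.Popen` to keep the tunnel running in the background, which mirrors the trailing `&` in the shell version.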