
cache_miss caused by lx.linear_solve #169

@maxecharles


Hi there,

I have discovered that the `TridiagonalLinearOperator` causes a cache miss to occur under `jit`, which leads to a silent recompilation. This is similar to this JAX issue. However, when I tested with the flags `jax_log_compiles=True` and `jax_explain_cache_misses=True`, nothing showed up (a truly pathological silent error), and I'm still not entirely sure why.
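A quick way to confirm that these flags work in a given environment is to trigger a cache miss deliberately. The minimal sketch below uses plain JAX (no lineax), and `double` and `n_traces` are hypothetical names for illustration. The third call changes the input shape, so it should produce a compilation log and a cache-miss explanation for `double`. The Python-side counter only increments while JAX is tracing the body, so it also acts as a retrace detector that does not depend on the logs (similar in spirit to `eqx.debug.assert_max_traces`).

```python
import jax
import jax.numpy as jnp

# Same debugging flags as in the issue.
jax.config.update("jax_log_compiles", True)
jax.config.update("jax_explain_cache_misses", True)

n_traces = 0  # incremented only while JAX traces the Python body


@jax.jit
def double(x):
    global n_traces
    n_traces += 1  # Python side effect: runs at trace time, not on cached calls
    return 2.0 * x


double(jnp.ones(5)).block_until_ready()  # first call: trace + compile
double(jnp.ones(5)).block_until_ready()  # same shape/dtype: served from cache
double(jnp.ones(6)).block_until_ready()  # new shape: deliberate cache miss
print("traces:", n_traces)  # expected: 2
```

If the counter stays at the expected value while calls are still slow, the slowdown is not coming from retracing the Python body.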

I originally noticed this in an interpax function that was slowing down my code by roughly an order of magnitude (I have opened an issue on interpax that goes into more detail), but after delving into the interpax source code I found that it was the `lineax.TridiagonalLinearOperator` that was causing the cache miss (here).

I diagnosed this by profiling the code with Perfetto, where you can see the `cache_miss` function being called on every call.
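Alongside a Perfetto trace, a coarse per-call timing can show whether later calls stay much slower than expected. This is a minimal sketch assuming plain JAX; the jitted function `h` and the helper `time_calls` are hypothetical stand-ins for the real workload.

```python
import time

import jax
import jax.numpy as jnp


@jax.jit
def h(x):
    # Stand-in workload; substitute the function under investigation.
    return jnp.cumsum(x) * 2.0


def time_calls(fn, *args, n=5):
    """Return wall-clock seconds for each of `n` consecutive calls."""
    times = []
    for _ in range(n):
        t0 = time.perf_counter()
        jax.block_until_ready(fn(*args))  # wait for async dispatch to finish
        times.append(time.perf_counter() - t0)
    return times


times = time_calls(h, jnp.ones(5))
for i, t in enumerate(times):
    print(f"call {i}: {t * 1e6:.1f} us")
```

The first call includes tracing and compilation. If the subsequent calls on such a tiny input remain far slower than a comparable plain-JAX function, per-call dispatch overhead (rather than the computation itself) is worth investigating.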

Here is a minimal reproducible example using only lineax (plus JAX and Equinox). I am using lineax==0.0.8, jax==0.7.0, and equinox==0.13.0.

```python
import lineax as lx
import equinox as eqx
import jax
from jax import numpy as jnp

# flags for debugging
jax.config.update("jax_log_compiles", True)
jax.config.update("jax_explain_cache_misses", True)


@eqx.filter_jit
@eqx.debug.assert_max_traces(max_traces=1)
def f(diag, lower_diag, upper_diag, b):
    A = lx.TridiagonalLinearOperator(diag, lower_diag, upper_diag)
    solve = lambda b: lx.linear_solve(A, b, lx.Tridiagonal()).value
    fx = jnp.vectorize(solve, signature="(n)->(n)")(b.T).T
    return fx


# setting up inputs
n = 5
diag = jnp.ones(n)
lower_diag = jnp.zeros(n - 1)
upper_diag = jnp.zeros(n - 1)
b = jnp.linspace(0, 1, n)

# compiling
f(diag, lower_diag, upper_diag, b)
print("Compilation done.")

# running five times and tracing with perfetto
with jax.profiler.trace("/tmp/jax-trace", create_perfetto_link=True):
    for i in range(5):
        f(diag, lower_diag, upper_diag, b).block_until_ready()
```
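To check whether the per-call overhead comes from the `lx.linear_solve` machinery rather than from the tridiagonal solve itself, one option is to profile a lineax-free baseline on the same inputs. The sketch below swaps in `jax.lax.linalg.tridiagonal_solve`, a JAX-level routine that bypasses lineax entirely; `f_lax` is a hypothetical name, and the off-diagonal padding follows the convention documented for that function (`dl[0] = 0`, `du[-1] = 0`).

```python
import jax
import jax.numpy as jnp


@jax.jit
def f_lax(diag, lower_diag, upper_diag, b):
    # lax.linalg.tridiagonal_solve takes full-length off-diagonals:
    # dl[i] = A[i, i-1] with dl[0] = 0, and du[i] = A[i, i+1] with du[-1] = 0.
    zero = jnp.zeros(1, dtype=diag.dtype)
    dl = jnp.concatenate([zero, lower_diag])
    du = jnp.concatenate([upper_diag, zero])
    # The right-hand side must be 2-D with shape (n, k); solve one column here.
    x = jax.lax.linalg.tridiagonal_solve(dl, diag, du, b[:, None])
    return x[:, 0]


# Same inputs as the lineax example: A is the 5x5 identity matrix.
n = 5
diag = jnp.ones(n)
lower_diag = jnp.zeros(n - 1)
upper_diag = jnp.zeros(n - 1)
b = jnp.linspace(0, 1, n)

x = f_lax(diag, lower_diag, upper_diag, b)
print(x)  # identity system, so x should equal b
```

Running this inside the same `jax.profiler.trace` loop gives a reference trace to compare against the lineax version.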

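The output of this script is attached at the end. It contains many logs from the `jax_log_compiles` flag, but all of them are emitted during the initial compiling call (before "Compilation done." is printed) and none during the five profiled calls, while `jax_explain_cache_misses` is completely silent.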

This is a screenshot of the Perfetto trace. I've also attached the Perfetto trace file here.
[Image: screenshot of the Perfetto trace]

Do you have any idea what could be causing this?

Thanks,
Max

Script output:

WARNING:2025-08-19 18:15:49,152:jax._src.dispatch:198: Finished tracing + transforming convert_element_type for pjit in 0.000162125 sec
WARNING:2025-08-19 18:15:49,162:jax._src.interpreters.pxla:1861: Compiling jit(convert_element_type) with global shapes and types [ShapedArray(float32[])]. Argument mapping: (UnspecifiedValue,).
WARNING:2025-08-19 18:15:49,200:jax._src.dispatch:198: Finished jaxpr to MLIR module conversion jit(convert_element_type) in 0.037410975 sec
WARNING:2025-08-19 18:15:49,252:jax._src.dispatch:198: Finished XLA compilation of jit(convert_element_type) in 0.052269220 sec
WARNING:2025-08-19 18:15:49,255:jax._src.dispatch:198: Finished tracing + transforming broadcast_in_dim for pjit in 0.000164032 sec
WARNING:2025-08-19 18:15:49,255:jax._src.interpreters.pxla:1861: Compiling jit(broadcast_in_dim) with global shapes and types [ShapedArray(float32[])]. Argument mapping: (UnspecifiedValue,).
WARNING:2025-08-19 18:15:49,259:jax._src.dispatch:198: Finished jaxpr to MLIR module conversion jit(broadcast_in_dim) in 0.003476143 sec
WARNING:2025-08-19 18:15:49,303:jax._src.dispatch:198: Finished XLA compilation of jit(broadcast_in_dim) in 0.043696880 sec
WARNING:2025-08-19 18:15:49,303:jax._src.dispatch:198: Finished tracing + transforming broadcast_in_dim for pjit in 0.000151873 sec
WARNING:2025-08-19 18:15:49,304:jax._src.interpreters.pxla:1861: Compiling jit(broadcast_in_dim) with global shapes and types [ShapedArray(float32[])]. Argument mapping: (UnspecifiedValue,).
WARNING:2025-08-19 18:15:49,306:jax._src.dispatch:198: Finished jaxpr to MLIR module conversion jit(broadcast_in_dim) in 0.001924992 sec
WARNING:2025-08-19 18:15:49,312:jax._src.dispatch:198: Finished XLA compilation of jit(broadcast_in_dim) in 0.006244898 sec
WARNING:2025-08-19 18:15:49,313:jax._src.dispatch:198: Finished tracing + transforming subtract for pjit in 0.000132084 sec
WARNING:2025-08-19 18:15:49,313:jax._src.dispatch:198: Finished tracing + transforming true_divide for pjit in 0.000101089 sec
WARNING:2025-08-19 18:15:49,314:jax._src.dispatch:198: Finished tracing + transforming true_divide for pjit in 0.000108004 sec
WARNING:2025-08-19 18:15:49,314:jax._src.dispatch:198: Finished tracing + transforming subtract for pjit in 0.000113964 sec
WARNING:2025-08-19 18:15:49,314:jax._src.dispatch:198: Finished tracing + transforming multiply for pjit in 0.000083208 sec
WARNING:2025-08-19 18:15:49,314:jax._src.dispatch:198: Finished tracing + transforming add for pjit in 0.000087023 sec
WARNING:2025-08-19 18:15:49,315:jax._src.dispatch:198: Finished tracing + transforming _linspace for pjit in 0.002079010 sec
WARNING:2025-08-19 18:15:49,315:jax._src.interpreters.pxla:1861: Compiling jit(_linspace) with global shapes and types [ShapedArray(int32[], weak_type=True), ShapedArray(int32[], weak_type=True)]. Argument mapping: (UnspecifiedValue, UnspecifiedValue).
WARNING:2025-08-19 18:15:49,319:jax._src.dispatch:198: Finished jaxpr to MLIR module conversion jit(_linspace) in 0.004170895 sec
WARNING:2025-08-19 18:15:49,341:jax._src.dispatch:198: Finished XLA compilation of jit(_linspace) in 0.021125078 sec
WARNING:2025-08-19 18:15:49,342:jax._src.dispatch:198: Finished tracing + transforming _fn for pjit in 0.000157833 sec
WARNING:2025-08-19 18:15:49,342:jax._src.dispatch:198: Finished tracing + transforming _fn for pjit in 0.000091791 sec
WARNING:2025-08-19 18:15:49,342:jax._src.dispatch:198: Finished tracing + transforming _fn for pjit in 0.000079870 sec
WARNING:2025-08-19 18:15:49,343:jax._src.dispatch:198: Finished tracing + transforming _fn for pjit in 0.000072956 sec
WARNING:2025-08-19 18:15:49,344:jax._src.dispatch:198: Finished tracing + transforming _squeeze for pjit in 0.000031948 sec
WARNING:2025-08-19 18:15:49,344:jax._src.dispatch:198: Finished tracing + transforming <lambda> for pjit in 0.000031948 sec
WARNING:2025-08-19 18:15:49,346:jax._src.dispatch:198: Finished tracing + transforming <lambda> for pjit in 0.000039816 sec
WARNING:2025-08-19 18:15:49,346:jax._src.dispatch:198: Finished tracing + transforming <lambda> for pjit in 0.000141859 sec
WARNING:2025-08-19 18:15:49,347:jax._src.dispatch:198: Finished tracing + transforming greater for pjit in 0.000206947 sec
WARNING:2025-08-19 18:15:49,347:jax._src.dispatch:198: Finished tracing + transforming subtract for pjit in 0.000099897 sec
WARNING:2025-08-19 18:15:49,347:jax._src.dispatch:198: Finished tracing + transforming _broadcast_arrays for pjit in 0.000038147 sec
WARNING:2025-08-19 18:15:49,348:jax._src.dispatch:198: Finished tracing + transforming _where for pjit in 0.000409126 sec
WARNING:2025-08-19 18:15:49,348:jax._src.dispatch:198: Finished tracing + transforming less for pjit in 0.000158072 sec
WARNING:2025-08-19 18:15:49,349:jax._src.dispatch:198: Finished tracing + transforming add for pjit in 0.000149965 sec
WARNING:2025-08-19 18:15:49,350:jax._src.dispatch:198: Finished tracing + transforming multiply for pjit in 0.000894070 sec
WARNING:2025-08-19 18:15:49,350:jax._src.dispatch:198: Finished tracing + transforming add for pjit in 0.000102282 sec
WARNING:2025-08-19 18:15:49,351:jax._src.dispatch:198: Finished tracing + transforming multiply for pjit in 0.000064850 sec
WARNING:2025-08-19 18:15:49,353:jax._src.dispatch:198: Finished tracing + transforming isfinite for pjit in 0.000086069 sec
WARNING:2025-08-19 18:15:49,354:jax._src.dispatch:198: Finished tracing + transforming invert for pjit in 0.000070810 sec
WARNING:2025-08-19 18:15:49,354:jax._src.dispatch:198: Finished tracing + transforming _reduce_any for pjit in 0.000190973 sec
WARNING:2025-08-19 18:15:49,354:jax._src.dispatch:198: Finished tracing + transforming _reduce_any for pjit in 0.000198126 sec
WARNING:2025-08-19 18:15:49,355:jax._src.dispatch:198: Finished tracing + transforming convert_element_type for pjit in 0.000075817 sec
WARNING:2025-08-19 18:15:49,355:jax._src.interpreters.pxla:1861: Compiling jit(convert_element_type) with global shapes and types [ShapedArray(bool[])]. Argument mapping: (UnspecifiedValue,).
WARNING:2025-08-19 18:15:49,357:jax._src.dispatch:198: Finished jaxpr to MLIR module conversion jit(convert_element_type) in 0.001612186 sec
WARNING:2025-08-19 18:15:49,362:jax._src.dispatch:198: Finished XLA compilation of jit(convert_element_type) in 0.005482912 sec
WARNING:2025-08-19 18:15:49,363:jax._src.dispatch:198: Finished tracing + transforming bitwise_and for pjit in 0.000178814 sec
WARNING:2025-08-19 18:15:49,363:jax._src.dispatch:198: Finished tracing + transforming _broadcast_arrays for pjit in 0.000038147 sec
WARNING:2025-08-19 18:15:49,364:jax._src.dispatch:198: Finished tracing + transforming _where for pjit in 0.000247240 sec
WARNING:2025-08-19 18:15:49,365:jax._src.dispatch:198: Finished tracing + transforming equal for pjit in 0.000129938 sec
WARNING:2025-08-19 18:15:49,368:jax._src.dispatch:198: Finished tracing + transforming not_equal for pjit in 0.000437021 sec
WARNING:2025-08-19 18:15:49,407:jax._src.dispatch:198: Finished tracing + transforming <lambda> for pjit in 0.000070810 sec
WARNING:2025-08-19 18:15:49,408:jax._src.dispatch:198: Finished tracing + transforming branched_error_if_impl for pjit in 0.038447857 sec
WARNING:2025-08-19 18:15:49,408:jax._src.dispatch:198: Finished tracing + transforming _fn for pjit in 0.062253237 sec
WARNING:2025-08-19 18:15:49,408:jax._src.dispatch:198: Finished tracing + transforming linear_solve for pjit in 0.064262867 sec
WARNING:2025-08-19 18:15:49,409:jax._src.dispatch:198: Finished tracing + transforming f for pjit in 0.067122936 sec
WARNING:2025-08-19 18:15:49,409:jax._src.interpreters.pxla:1861: Compiling jit(f) with global shapes and types [ShapedArray(float32[5]), ShapedArray(float32[4]), ShapedArray(float32[4]), ShapedArray(float32[5])]. Argument mapping: (UnspecifiedValue, UnspecifiedValue, UnspecifiedValue, UnspecifiedValue).
WARNING:2025-08-19 18:15:49,412:jax._src.dispatch:198: Finished tracing + transforming <lambda> for pjit in 0.000206947 sec
WARNING:2025-08-19 18:15:49,412:jax._src.dispatch:198: Finished tracing + transforming <lambda> for pjit in 0.000045061 sec
WARNING:2025-08-19 18:15:49,413:jax._src.dispatch:198: Finished tracing + transforming greater for pjit in 0.000105143 sec
WARNING:2025-08-19 18:15:49,413:jax._src.dispatch:198: Finished tracing + transforming subtract for pjit in 0.000081062 sec
WARNING:2025-08-19 18:15:49,413:jax._src.dispatch:198: Finished tracing + transforming _broadcast_arrays for pjit in 0.000029325 sec
WARNING:2025-08-19 18:15:49,413:jax._src.dispatch:198: Finished tracing + transforming _where for pjit in 0.000193119 sec
WARNING:2025-08-19 18:15:49,414:jax._src.dispatch:198: Finished tracing + transforming less for pjit in 0.000079870 sec
WARNING:2025-08-19 18:15:49,414:jax._src.dispatch:198: Finished tracing + transforming add for pjit in 0.000105858 sec
WARNING:2025-08-19 18:15:49,415:jax._src.dispatch:198: Finished tracing + transforming multiply for pjit in 0.000109911 sec
WARNING:2025-08-19 18:15:49,415:jax._src.dispatch:198: Finished tracing + transforming subtract for pjit in 0.000125885 sec
WARNING:2025-08-19 18:15:49,415:jax._src.dispatch:198: Finished tracing + transforming true_divide for pjit in 0.000073195 sec
WARNING:2025-08-19 18:15:49,415:jax._src.dispatch:198: Finished tracing + transforming add for pjit in 0.000076056 sec
WARNING:2025-08-19 18:15:49,417:jax._src.dispatch:198: Finished tracing + transforming multiply for pjit in 0.000058889 sec
WARNING:2025-08-19 18:15:49,418:jax._src.dispatch:198: Finished tracing + transforming isfinite for pjit in 0.000057936 sec
WARNING:2025-08-19 18:15:49,418:jax._src.dispatch:198: Finished tracing + transforming invert for pjit in 0.000138998 sec
WARNING:2025-08-19 18:15:49,419:jax._src.dispatch:198: Finished tracing + transforming _reduce_any for pjit in 0.000274181 sec
WARNING:2025-08-19 18:15:49,419:jax._src.dispatch:198: Finished tracing + transforming _reduce_any for pjit in 0.000200033 sec
WARNING:2025-08-19 18:15:49,419:jax._src.dispatch:198: Finished tracing + transforming convert_element_type for pjit in 0.000107050 sec
WARNING:2025-08-19 18:15:49,420:jax._src.interpreters.pxla:1861: Compiling jit(convert_element_type) with global shapes and types [ShapedArray(bool[])]. Argument mapping: (UnspecifiedValue,).
WARNING:2025-08-19 18:15:49,421:jax._src.dispatch:198: Finished jaxpr to MLIR module conversion jit(convert_element_type) in 0.001550674 sec
WARNING:2025-08-19 18:15:49,425:jax._src.dispatch:198: Finished XLA compilation of jit(convert_element_type) in 0.003887892 sec
WARNING:2025-08-19 18:15:49,426:jax._src.dispatch:198: Finished tracing + transforming bitwise_and for pjit in 0.000257015 sec
WARNING:2025-08-19 18:15:49,426:jax._src.dispatch:198: Finished tracing + transforming _broadcast_arrays for pjit in 0.000034094 sec
WARNING:2025-08-19 18:15:49,426:jax._src.dispatch:198: Finished tracing + transforming _where for pjit in 0.000252962 sec
WARNING:2025-08-19 18:15:49,427:jax._src.dispatch:198: Finished tracing + transforming equal for pjit in 0.000099182 sec
WARNING:2025-08-19 18:15:49,427:jax._src.dispatch:198: Finished tracing + transforming not_equal for pjit in 0.000099182 sec
WARNING:2025-08-19 18:15:49,434:jax._src.dispatch:198: Finished tracing + transforming <lambda> for pjit in 0.000093937 sec
WARNING:2025-08-19 18:15:49,434:jax._src.dispatch:198: Finished tracing + transforming branched_error_if_impl for pjit in 0.006947041 sec
WARNING:2025-08-19 18:15:49,445:jax._src.dispatch:198: Finished tracing + transforming _reduce_any for pjit in 0.000211954 sec
WARNING:2025-08-19 18:15:49,446:jax._src.dispatch:198: Finished tracing + transforming _reduce_max for pjit in 0.000173092 sec
WARNING:2025-08-19 18:15:49,449:jax._src.dispatch:198: Finished jaxpr to MLIR module conversion jit(f) in 0.040294886 sec
WARNING:2025-08-19 18:15:49,498:jax._src.dispatch:198: Finished XLA compilation of jit(f) in 0.047907829 sec
Compilation done.
2025-08-19 18:15:49.499140: E external/xla/xla/python/profiler/internal/python_hooks.cc:416] Can't import tensorflow.python.profiler.trace
2025-08-19 18:15:49.501289: E external/xla/xla/python/profiler/internal/python_hooks.cc:416] Can't import tensorflow.python.profiler.trace
Open URL in browser: https://ui.perfetto.dev/#!/?url=http://127.0.0.1:9001/perfetto_trace.json.gz
127.0.0.1 - - [19/Aug/2025 18:16:00] code 501, message Unsupported method ('OPTIONS')
127.0.0.1 - - [19/Aug/2025 18:16:00] "OPTIONS /status HTTP/1.1" 501 -
127.0.0.1 - - [19/Aug/2025 18:16:00] code 404, message File not found
127.0.0.1 - - [19/Aug/2025 18:16:00] "POST /status HTTP/1.1" 404 -
127.0.0.1 - - [19/Aug/2025 18:16:00] code 501, message Unsupported method ('OPTIONS')
127.0.0.1 - - [19/Aug/2025 18:16:00] "OPTIONS /perfetto_trace.json.gz HTTP/1.1" 501 -
127.0.0.1 - - [19/Aug/2025 18:16:00] "GET /perfetto_trace.json.gz HTTP/1.1" 200 -
