I am attempting to use the persistent compilation cache to avoid losing time on frequent JIT recompilation. I am trying to follow the instructions in the provided user guide, but I cannot even get the quick start example to work.
I am running the following code:
import jax
import jax.numpy as jnp
import time
import os
os.environ["JAX_COMPILATION_CACHE_DIR"] = "./tmp/jax_cache"
jax.config.update("jax_compilation_cache_dir", "./tmp/jax_cache")
from jax.experimental.compilation_cache import compilation_cache as cc
cc.set_cache_dir("./tmp/jax_cache")
jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1)
jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)
jax.config.update("jax_persistent_cache_enable_xla_caches", "xla_gpu_per_fusion_autotune_cache_dir")
@jax.jit
def f(x):
    return x**2
x = float(5)
f(x)
which is essentially copied from the user guide. After running the script, a JIT cache file is indeed created in the configured directory. When I re-run the script, however, the function is recompiled rather than loaded from the cache, which is evident from the runtime. My understanding is that the cache should be detected and loaded automatically, which does not seem to be the case. I would be very thankful to anyone able to point out where I am going wrong. I am running computations on a CPU, but as far as I know that should not be an issue.
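Besides the runtime, one way to check whether a run writes to the cache is to compare the contents of the cache directory before and after calling `f(x)`. Below is a minimal sketch using only the Python standard library. The helper names `snapshot_cache` and `new_or_changed` are hypothetical and not part of JAX, and the result is only a hint, since JAX may also touch bookkeeping files in that directory.

```python
import os


def snapshot_cache(cache_dir):
    """Map each file under cache_dir to its (size, mtime_ns).

    Hypothetical helper, not part of JAX. Returns an empty dict if
    the directory does not exist yet.
    """
    entries = {}
    for root, _dirs, files in os.walk(cache_dir):
        for name in files:
            path = os.path.join(root, name)
            stat = os.stat(path)
            rel = os.path.relpath(path, cache_dir)
            entries[rel] = (stat.st_size, stat.st_mtime_ns)
    return entries


def new_or_changed(before, after):
    """Return the sorted relative paths that were added or modified."""
    return sorted(k for k, v in after.items() if before.get(k) != v)
```

In the script above, this would mean calling `before = snapshot_cache("./tmp/jax_cache")` ahead of `f(x)` and `after = snapshot_cache("./tmp/jax_cache")` afterwards, then printing `new_or_changed(before, after)`. If the second run of the script reports new entries again, the cache key is probably changing between runs; if it reports nothing, no new entry was written for that compilation.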