Conversation
DescriptionStart with a short description of what the PR does and how this is a change from The rest of the description includes relevant details and context, examples:
If the change fixes a Github issue, please include a link, e.g.,: TestsPlease describe how you tested this change, and include any instructions and/or ChecklistBefore submitting this PR, please make sure:
|
a070f83 to
9ca3e99
Compare
tpu_inference/runner/tpu_runner.py
Outdated
| self.devices = devices | ||
| self.dtype = self.model_config.dtype | ||
| self.maybe_forbid_compile = runner_utils.ForbidCompile( | ||
| self.maybe_forbid_compile = jax.no_tracing( |
There was a problem hiding this comment.
This "disallow tracing for JIT compilation". How does it raise an Exception when a recompilation happens?
There was a problem hiding this comment.
this line gets triggered: https://github.com/jax-ml/jax/blob/c0b2687098239f19c53a2d55e2f8ec42eafa3e15/jax/_src/pjit.py#L252-L253
but i found that the error is not correctly caught and vllm throws some random error message that isn't really useful. introducing manual try / except code yields better error message like this. let me see if i can clean up that code and push it to this branch.
RuntimeError("re-tracing function compute_logits_func at /home/kyuyeunk_google_com/workspace/tpu-inference/tpu_inference/models/vllm/vllm_model_wrapper.py:292 for `jit`, but 'no_tracing' is set")
There was a problem hiding this comment.
Do you know why the previous approach doesn't work? I remember I fixed some recompilation issue caught by the previous approach.
There was a problem hiding this comment.
not 100% sure what kind of changes were introduced to jax that might have caused the change. but regardless of why, the previous approach was to override the private function (jax._src.interpreters.pxla._cached_lowering_to_hlo ) and check if that gets invoked or not - which is really fragile because 1. JAX might reset the overriding 2. it's using private function so no guarantee that it will always work.
Meanwhile, jax.no_tracing is an official api provided by JAX to check for recompilation. so it should be much more robust.
Signed-off-by: Kyuyeun Kim <kyuyeunk@google.com>
9ca3e99 to
f254a62
Compare
Description
Start with a short description of what the PR does and how this is a change from
the past.
The rest of the description includes relevant details and context, examples:
If the change fixes a Github issue, please include a link, e.g.,:
FIXES: #123456
Tests
Please describe how you tested this change, and include any instructions and/or
commands to reproduce.
Checklist
Before submitting this PR, please make sure: