Fix recompilation check code #1839

Open: kyuyeunk wants to merge 1 commit into main from kyuyeunk/fix_compute_sharding

kyuyeunk (Collaborator) commented Mar 3, 2026

Description

Start with a short description of what the PR does and how this is a change from
the past.

The rest of the description includes relevant details and context, examples:

  • why is this change being made,
  • the problem being solved and any relevant context,
  • why this is a good solution,
  • some information about the specific implementation,
  • shortcomings of the solution and possible future improvements.

If the change fixes a GitHub issue, please include a link, e.g.:
FIXES: #123456

Tests

Please describe how you tested this change, and include any instructions and/or
commands to reproduce.

Checklist

Before submitting this PR, please make sure:

  • I have performed a self-review of my code.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have made or will make corresponding changes to any relevant documentation.

kyuyeunk requested a review from vanbasten23 as a code owner March 3, 2026 10:35
kyuyeunk added the "ready" label (ONLY add when PR is ready to merge/full CI is needed) Mar 3, 2026
github-actions bot commented Mar 3, 2026

kyuyeunk marked this pull request as draft March 3, 2026 11:29
kyuyeunk force-pushed the kyuyeunk/fix_compute_sharding branch from a070f83 to 9ca3e99 March 3, 2026 11:37
kyuyeunk marked this pull request as ready for review March 3, 2026 11:37
kyuyeunk changed the title from "Fix compute logits sharding" to "Fix recompilation check code" Mar 3, 2026

  self.devices = devices
  self.dtype = self.model_config.dtype
- self.maybe_forbid_compile = runner_utils.ForbidCompile(
+ self.maybe_forbid_compile = jax.no_tracing(

Collaborator

This says "disallow tracing for JIT compilation". How does it raise an exception when a recompilation happens?

Collaborator Author

This line gets triggered: https://github.com/jax-ml/jax/blob/c0b2687098239f19c53a2d55e2f8ec42eafa3e15/jax/_src/pjit.py#L252-L253

But I found that the error is not caught correctly, and vLLM throws an unrelated error message that isn't really useful. Adding a manual try/except yields a better error message, like the one below. Let me see if I can clean up that code and push it to this branch.

RuntimeError("re-tracing function compute_logits_func at /home/kyuyeunk_google_com/workspace/tpu-inference/tpu_inference/models/vllm/vllm_model_wrapper.py:292 for `jit`, but 'no_tracing' is set")
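
A minimal sketch of the try/except idea described above, written in plain Python so it runs without JAX. `guarded_call` and `RecompilationError` are hypothetical names, and `fake_jitted` only imitates the RuntimeError text quoted above; the cleanup that actually lands in this branch may look different.

```python
class RecompilationError(RuntimeError):
    """Hypothetical error type that names the offending call site."""


def guarded_call(fn, *args, **kwargs):
    """Call fn and re-raise no_tracing failures with a clearer message."""
    try:
        return fn(*args, **kwargs)
    except RuntimeError as e:
        # Only rewrap the error produced by the no_tracing guard;
        # any other RuntimeError propagates unchanged.
        if "'no_tracing' is set" not in str(e):
            raise
        name = getattr(fn, "__name__", repr(fn))
        raise RecompilationError(
            f"Unexpected recompilation while calling {name}: {e}") from e


def fake_jitted(x):
    # Stand-in for a jitted function that needs to re-trace (assumption).
    raise RuntimeError(
        "re-tracing function compute_logits_func for `jit`, "
        "but 'no_tracing' is set")


try:
    guarded_call(fake_jitted, 1)
except RecompilationError as err:
    message = str(err)
```

Matching on the message text is brittle; it is used here only because the quoted error is a plain RuntimeError with no dedicated exception type.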

Collaborator

Do you know why the previous approach doesn't work? I remember I fixed some recompilation issue caught by the previous approach.

Collaborator Author

Not 100% sure which JAX changes might have caused this. Regardless of the cause, the previous approach overrode a private function (jax._src.interpreters.pxla._cached_lowering_to_hlo) and checked whether it was invoked, which is really fragile because (1) JAX might reset the override, and (2) it relies on a private function, so there is no guarantee it will keep working.
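
To make the fragility concrete, here is a toy illustration in plain Python. The names `lib`, `_lower`, and `compile` are hypothetical stand-ins, not JAX internals; the sketch only shows how an override of a private attribute can silently stop observing calls.

```python
import types

# Stand-in for a library module with a private helper (hypothetical names).
lib = types.SimpleNamespace()


def _lower(x):
    return f"hlo({x})"


lib._lower = _lower
# Internal callers often hold a direct reference to the original function.
lib.compile = lambda x, _impl=lib._lower: _impl(x)

calls = []


def _spy(x):
    # Record the call, then delegate to the original implementation.
    calls.append(x)
    return _lower(x)


# The checker overrides the private attribute to detect compilations.
lib._lower = _spy

lib._lower("a")   # observed by the spy
lib.compile("b")  # not observed: compile captured the original earlier

# The library re-assigns the attribute, which resets the override.
lib._lower = _lower
lib._lower("c")   # not observed any more
```

Only the first call reaches the spy, so a checker built this way can report "no recompilation" while compilations are in fact happening.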

Meanwhile, jax.no_tracing is an official API provided by JAX to check for recompilation, so it should be much more robust.

Signed-off-by: Kyuyeun Kim <kyuyeunk@google.com>
kyuyeunk force-pushed the kyuyeunk/fix_compute_sharding branch from 9ca3e99 to f254a62 March 6, 2026 04:30