add dummy variable for logprob to avoid cache collision #1895

Merged
sixiang-google merged 1 commit into main from sixiang/sc-disagg-debug
Mar 12, 2026

sixiang-google (Collaborator) commented Mar 10, 2026

Description

Including logprobs causes a compilation failure in disagg.
From the compilation log:

...
(EngineCore_DP0 pid=979443) INFO 03-10 03:33:17 [compilation_manager.py:75] Precompile worker0 sample --> {'num_reqs': 8, 'do_sampling': True}
(EngineCore_DP0 pid=979443) WARNING:2026-03-10 03:33:18,021:jax._src.pjit:1108: TRACING CACHE MISS at /home/sixiang_google_com/tpu-inference/tpu_inference/runner/compilation_manager.py:77:17 (CompilationManager._run_compilation) costing 27.935 ms because:
(EngineCore_DP0 pid=979443)   never seen function:
(EngineCore_DP0 pid=979443)     sample id=125890840909344 defined at /home/sixiang_google_com/tpu-inference/tpu_inference/layers/jax/sample/sampling.py:29
(EngineCore_DP0 pid=979443) WARNING:jax._src.pjit:TRACING CACHE MISS at /home/sixiang_google_com/tpu-inference/tpu_inference/runner/compilation_manager.py:77:17 (CompilationManager._run_compilation) costing 27.935 ms because:
(EngineCore_DP0 pid=979443)   never seen function:
(EngineCore_DP0 pid=979443)     sample id=125890840909344 defined at /home/sixiang_google_com/tpu-inference/tpu_inference/layers/jax/sample/sampling.py:29
(EngineCore_DP0 pid=979443) WARNING:2026-03-10 03:33:18,064:jax._src.compiler:112: PERSISTENT COMPILATION CACHE MISS for 'jit_sample' with key 'jit_sample-e9af046e56b26fbac4b1d0dee05f0a846eefe65864fbd46661ea43f985f4aee5'
(EngineCore_DP0 pid=979443) WARNING:jax._src.compiler:PERSISTENT COMPILATION CACHE MISS for 'jit_sample' with key 'jit_sample-e9af046e56b26fbac4b1d0dee05f0a846eefe65864fbd46661ea43f985f4aee5'
(EngineCore_DP0 pid=979443) WARNING:2026-03-10 03:33:19,665:jax._src.compilation_cache:277: Writing jit_sample to persistent compilation cache with key 'jit_sample-e9af046e56b26fbac4b1d0dee05f0a846eefe65864fbd46661ea43f985f4aee5'
**(EngineCore_DP0 pid=979443) WARNING:jax._src.compilation_cache:Writing jit_sample to persistent compilation cache with key 'jit_sample-e9af046e56b26fbac4b1d0dee05f0a846eefe65864fbd46661ea43f985f4aee5'**
(EngineCore_DP0 pid=979443) INFO 03-10 03:33:19 [compilation_manager.py:80] Compilation finished in 1.69 [secs].
(EngineCore_DP0 pid=979443) INFO 03-10 03:33:19 [compilation_manager.py:75] Precompile worker0 sample --> {'num_reqs': 8, 'do_sampling': True}
(EngineCore_DP0 pid=979443) WARNING:2026-03-10 03:33:19,696:jax._src.pjit:1108: TRACING CACHE MISS at /home/sixiang_google_com/tpu-inference/tpu_inference/runner/compilation_manager.py:77:17 (CompilationManager._run_compilation) costing 9.182 ms because:
(EngineCore_DP0 pid=979443)   for sample defined at /home/sixiang_google_com/tpu-inference/tpu_inference/layers/jax/sample/sampling.py:29
(EngineCore_DP0 pid=979443)   all previously seen cache keys are different. Closest previous key:
(EngineCore_DP0 pid=979443)   * key with different input pytree:
(EngineCore_DP0 pid=979443)       now: PyTreeDef(((*, *, CustomNode(TPUSupportedSamplingMetadata[(True, False)],...
(EngineCore_DP0 pid=979443)       before: PyTreeDef(((*, *, CustomNode(TPUSupportedSamplingMetadata[(True, True)], ...
(EngineCore_DP0 pid=979443)       * at args[2], now <class 'tpu_inference.layers.jax.sample.sampling_metadata.TPUSupportedSamplingMetadata'> with pytree metadata (True, False) and before <class 'tpu_inference.layers.jax.sample.sampling_metadata.TPUSupportedSamplingMetadata'> with pytree metadata (True, True), so the pytree node metadata does not match
(EngineCore_DP0 pid=979443) WARNING:jax._src.pjit:TRACING CACHE MISS at /home/sixiang_google_com/tpu-inference/tpu_inference/runner/compilation_manager.py:77:17 (CompilationManager._run_compilation) costing 9.182 ms because:
(EngineCore_DP0 pid=979443)   for sample defined at /home/sixiang_google_com/tpu-inference/tpu_inference/layers/jax/sample/sampling.py:29
(EngineCore_DP0 pid=979443)   all previously seen cache keys are different. Closest previous key:
(EngineCore_DP0 pid=979443)   * key with different input pytree:
(EngineCore_DP0 pid=979443)       now: PyTreeDef(((*, *, CustomNode(TPUSupportedSamplingMetadata[(True, False)],...
(EngineCore_DP0 pid=979443)       before: PyTreeDef(((*, *, CustomNode(TPUSupportedSamplingMetadata[(True, True)], ...
(EngineCore_DP0 pid=979443)       * at args[2], now <class 'tpu_inference.layers.jax.sample.sampling_metadata.TPUSupportedSamplingMetadata'> with pytree metadata (True, False) and before <class 'tpu_inference.layers.jax.sample.sampling_metadata.TPUSupportedSamplingMetadata'> with pytree metadata (True, True), so the pytree node metadata does not match
(EngineCore_DP0 pid=979443) INFO 03-10 03:33:19 [compilation_manager.py:80] Compilation finished in 0.09 [secs].
(EngineCore_DP0 pid=979443) INFO 03-10 03:33:19 [compilation_manager.py:75] Precompile worker0 sample --> {'num_reqs': 8, 'do_sampling': False}
...

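In the second sample precompile, the tracing cache misses because the pytree metadata changed from (True, True) to (True, False), but no persistent compilation cache miss is logged and compilation finishes in 0.09 secs. It seems that introducing the logprobs = True flag involves no additional compilation, hence removing it.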
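
To illustrate the behavior in the log, here is a minimal JAX sketch. It assumes a custom pytree node whose static metadata is a pair of flags, similar in spirit to the (True, True) / (True, False) metadata reported for TPUSupportedSamplingMetadata. The names ToySamplingMetadata, toy_sample, and trace_log are hypothetical and are not taken from tpu_inference. Changing only the metadata flag forces a new trace, even though the flag does not feed into any operation.

```python
import jax
import jax.numpy as jnp
from jax import tree_util


@tree_util.register_pytree_node_class
class ToySamplingMetadata:
    """Hypothetical stand-in; not the real TPUSupportedSamplingMetadata."""

    def __init__(self, temperature, do_sampling, logprobs):
        self.temperature = temperature  # array leaf, traced by jit
        self.do_sampling = do_sampling  # static pytree metadata
        self.logprobs = logprobs        # static pytree metadata

    def tree_flatten(self):
        # (children, aux_data): aux_data becomes part of the PyTreeDef.
        return (self.temperature,), (self.do_sampling, self.logprobs)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(children[0], *aux_data)


trace_log = []  # appended to only while jit traces the function


@jax.jit
def toy_sample(logits, meta):
    trace_log.append((meta.do_sampling, meta.logprobs))
    if not meta.do_sampling:
        return jnp.argmax(logits, axis=-1)
    # meta.logprobs does not affect any operation below.
    return jnp.argmax(logits / meta.temperature[:, None], axis=-1)


logits = jnp.ones((8, 16))
temperature = jnp.ones((8,))
with_logprobs = ToySamplingMetadata(temperature, True, True)
without_logprobs = ToySamplingMetadata(temperature, True, False)

toy_sample(logits, with_logprobs)
traces_after_first = len(trace_log)
toy_sample(logits, without_logprobs)  # metadata (True, False): traced again
traces_after_second = len(trace_log)
toy_sample(logits, without_logprobs)  # same PyTreeDef: tracing cache hit
traces_after_repeat = len(trace_log)
print(traces_after_first, traces_after_second, traces_after_repeat)

# Compare the lowered programs for the two metadata variants.
hlo_a = toy_sample.lower(logits, with_logprobs).as_text()
hlo_b = toy_sample.lower(logits, without_logprobs).as_text()
print("lowered text identical:", hlo_a == hlo_b)
```

If the two lowered texts match, the two variants would plausibly map to the same persistent compilation cache key, which would be consistent with the second precompile above finishing without a persistent cache miss. That link is an assumption of this sketch and has not been checked against tpu_inference.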

Tests

Please describe how you tested this change, and include any instructions and/or
commands to reproduce.

Checklist

Before submitting this PR, please make sure:

  • I have performed a self-review of my code.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have made or will make corresponding changes to any relevant documentation.

sixiang-google added the "ready" label (ONLY add when PR is ready to merge/full CI is needed) Mar 10, 2026
sixiang-google force-pushed the sixiang/sc-disagg-debug branch 2 times, most recently from 496e06d to 1cbe63c, March 11, 2026 22:24
sixiang-google changed the title from "remove logprob compilation in sampling" to "add dummy variable for logprob to avoid cache collision" Mar 11, 2026
sixiang-google force-pushed the sixiang/sc-disagg-debug branch from 1cbe63c to 6233966, March 11, 2026 22:34
sixiang-google merged commit b3a6268 into main Mar 12, 2026
42 checks passed
sixiang-google deleted the sixiang/sc-disagg-debug branch March 12, 2026 02:58