
Feat/data parallelism #832

Closed

JamesBrianD wants to merge 13 commits into main from feat/data-parallelism

Conversation

@JamesBrianD
Collaborator

Motivation

Modifications

Accuracy Tests

Benchmarking and Profiling

Checklist

  • Please use English, otherwise it will be closed.
  • The purpose of the PR, or link existing issues this PR will resolve.
  • The test plan, such as providing test command.
  • (Optional) The necessary documentation update.

JamesBrianD marked this pull request as ready for review on February 17, 2026 08:58
@gemini-code-assist

Summary of Changes

Hello @JamesBrianD, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request adds comprehensive Data Parallelism (DP) support to the SGLang JAX inference engine. The change lets the system distribute requests and KV cache across multiple devices, improving scalability and resource utilization for large-scale inference workloads. The modifications span core components, from memory management and request scheduling to model execution and benchmarking utilities.

Highlights

  • Data Parallelism Core Integration: Introduced comprehensive Data Parallelism (DP) support across memory management, scheduling, and model execution components, enabling requests and KV cache to be distributed across multiple devices for enhanced scalability.
  • DP-Aware Memory Management: Updated KV cache allocators (TokenToKVPoolAllocator, PagedTokenToKVPoolAllocator, SWATokenToKVPoolAllocator) and prefix caching mechanisms (RadixCache, SWARadixCache) to manage resources and namespaces on a per-DP rank basis.
  • DP-Enabled Request Scheduling: Enhanced the scheduler to distribute incoming requests across DP ranks using configurable policies (round-robin, min-running-queue) and adapted ScheduleBatch to handle requests and metadata in a DP-aware manner.
  • Multimodal and Benchmarking Enhancements: Extended benchmarking tools to support image-based datasets, Mooncake traces, and advanced profiling options, including new arguments for LoRA distribution, multi-turn conversations, and detailed throughput plotting.
  • MoE and Attention Layer Updates: Modified Mixture-of-Experts (MoE) layers (Bailing, Grok, Qwen2, Qwen3) and attention mechanisms to correctly handle DP sharding for inputs, outputs, and intermediate states, ensuring proper data distribution during model execution.
  • Improved Debugging and Testing: Added new unit and integration test files (test_paged_allocator_multi_dp.py, test_flashattention_dp.py, test_schedule_batch_dp.py) and a script (inspect_expert_dist.py) to validate DP functionality and aid in debugging expert distribution.
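
As a rough illustration of the DP scheduling highlight above, here is a minimal sketch of what round-robin and min-running-queue rank selection can look like. The class names and the running_reqs_per_rank structure are hypothetical; the PR implements the same idea inside the Scheduler via _select_round_robin_dp and _select_min_running_dp.

```python
# Hypothetical sketch of the two DP schedule policies described above.
# Names and structure are illustrative, not the PR's actual identifiers.

class RoundRobinSelector:
    """Assign incoming requests to DP ranks in a fixed rotation."""

    def __init__(self, dp_size: int):
        self.dp_size = dp_size
        self._next = 0

    def select(self, running_reqs_per_rank: list[list]) -> int:
        rank = self._next
        self._next = (self._next + 1) % self.dp_size
        return rank


class MinRunningSelector:
    """Assign each request to the DP rank with the fewest running requests."""

    def select(self, running_reqs_per_rank: list[list]) -> int:
        return min(
            range(len(running_reqs_per_rank)),
            key=lambda rank: len(running_reqs_per_rank[rank]),
        )


# Usage: pick a rank for a new request and record it on the request,
# mirroring the dp_rank field added to Req / TokenizedGenerateReqInput.
running = [[], ["req-a"], []]       # per-rank running queues (toy data)
selector = MinRunningSelector()
dp_rank = selector.select(running)  # -> 0 (first rank with the fewest running requests)
```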


Changelog
  • .gitignore
    • Added new entries for eplb/expert_balance_... directories.
  • benchmark/moe/bench_fused_moe.py
    • Updated _estimate_vmem_bytes to include se_w1_scale, se_w3_scale, se_w2_scale in total_bytes.
    • Introduced _bt_allowed function and applied it to filter bt_candidates.
    • Added validation for bt in validate function.
    • Modified run_all to accept return_results and token_valid_ratios, and to return a list of results.
    • Initialized results list in run_all.
    • Added default_ms tracking in run_with_mask.
    • Populated results list with benchmark metrics.
  • python/sgl_jax/bench_one_batch.py
    • Modified extend function to wrap reqs in a list for reqs_per_dp and explicitly set dp_size=1 and chunked_reqs=None in ScheduleBatch.init_new.
  • python/sgl_jax/bench_serving.py
    • Added imports: copy, importlib.util, io, shutil, uuid, Callable, deepcopy, replace, lru_cache, pybase64, datasets, PIL.Image.
    • Removed imports: requests.adapters, urllib3.util.retry.
    • Added constants: _ROUTING_KEY_HEADER, TERM_PLOTLIB_AVAILABLE.
    • Modified RequestFuncInput to include image_data, timestamp, routing_key, and support list types for prompt.
    • Modified RequestFuncOutput to include text_chunks, start_time.
    • Modified get_auth_headers to also check for API_KEY environment variable.
    • Added parse_custom_headers and get_request_headers functions.
    • Updated async_request_trt_llm, async_request_openai_completions, async_request_truss, async_request_sglang_generate to use get_request_headers, include routing_key in headers, image_data in payload, and provide more detailed error messages.
    • Added async_request_openai_chat_completions for OpenAI chat API.
    • Modified async_request_profile to send specific profiling activities and parameters.
    • Added _build_profile_urls and _call_profile_pd for PD separated profiling.
    • Added get_processor for multimodal models.
    • Modified get_dataset to support image, mooncake, and custom datasets, and updated mmmu dataset handling to use processor.
    • Added sglang-oai-chat, vllm-chat, lmdeploy-chat backends to ASYNC_REQUEST_FUNCS.
    • Modified BenchmarkMetrics to include total_input_text, total_input_vision, p90_e2e_latency_ms, max_output_tokens_per_s, max_concurrent_requests.
    • Updated SHAREGPT_URL to use SHAREGPT_REPO_ID and SHAREGPT_FILENAME from Hugging Face. Added MOONCAKE_DATASET_URL.
    • Modified download_and_cache_file to use requests.get directly. Added download_and_cache_hf_file.
    • Modified DatasetRow to include text_prompt_len, vision_prompt_len, timestamp, routing_key.
    • Added get_mooncake_request_over_time for Mooncake dataset scheduling.
    • Added compute_random_lens, parse_image_resolution, create_mm_data_row, sample_image_requests, get_available_tokens, gen_mm_prompt.
    • Included args.seed in the cache key for get_gen_prefix_cache_path.
    • Modified sample_generated_shared_prefix_requests to add range_ratio, gsp_send_routing_key, gsp_num_turns, gsp_ordered arguments and logic for multi-turn conversations and routing keys.
    • Modified get_request to add use_trace_timestamps and slowdown_factor for Mooncake dataset.
    • Modified calculate_metrics to include accept_length, plot_throughput, total_input_text, total_input_vision, max_output_tokens_per_s, max_concurrent_requests, p90_e2e_latency_ms, mean_tpot_ms, median_tpot_ms, p99_tpot_ms, and implemented throughput plotting.
    • Added MULTI_TURN_BACKENDS and wrap_multi_turn_request_func.
    • Modified benchmark to add lora_request_distribution, lora_zipf_alpha, use_trace_timestamps, mooncake_slowdown_factor, mooncake_num_rounds, profile_prefill_url, profile_decode_url arguments. Updated warmup logic for Mooncake. Added LoRA distribution logic. Updated profiler start/stop for PD separated mode.
    • Modified run_benchmark to add plot_throughput, use_trace_timestamps, mooncake_slowdown_factor, mooncake_num_rounds, served_model_name, print_requests to args. Updated get_dataset call. Added LoRA distribution validation.
    • Added numerous new CLI arguments for image datasets, Mooncake dataset, GSP dataset, LoRA distribution, profiling, and custom headers in add_cli_args.
  • python/sgl_jax/srt/entrypoints/engine.py
    • Removed the conditional on server_args.dp_size == 1 around launching scheduler processes and threads, so scheduler launch is handled uniformly regardless of DP size.
  • python/sgl_jax/srt/eplb/expert_location.py
    • Updated sharding for logical_to_rank_dispatch_physical_map, logical_to_all_physical_map, logical_to_all_physical_map_num_valid, physical_to_logical_map to P(None).
    • Added out_sharding=P("data", None) for jax.Array.at operations on logical_to_all_physical_map_num_valid and logical_to_all_physical_map.
  • python/sgl_jax/srt/kernels/fused_moe/v1/kernel.py
    • Added validation to validate_fused_moe_block_config to ensure bt (batch size) is 2, 4, 8, or a multiple of 8.
  • python/sgl_jax/srt/kernels/fused_moe/v1/tuned_block_configs.py
    • Added new tuned block configurations for 32768 tokens for both 256 and 288 experts.
  • python/sgl_jax/srt/kernels/ragged_paged_attention/ragged_paged_attention.py
    • Added more descriptive assertion messages for shape and dtype mismatches in ref_ragged_paged_attention_fused, ref_ragged_paged_attention, _ragged_paged_attention_kernel, strided_load, strided_load_bkv, broadcast_minor, merge_kv, prepare_kv, prepare_inputs, ragged_paged_attention, prepare_kv_cache_fused.
  • python/sgl_jax/srt/kernels/update_kv_cache/update_kv_cache.py
    • Extracted the core logic of kv_cache_update into a new function kv_cache_update_impl.
    • Modified kv_cache_update to call kv_cache_update_impl.
  • python/sgl_jax/srt/layers/attention/flashattention_backend.py
    • Removed num_seqs from FlashAttentionMetadata.
    • Added attention_data_partition_axis to FlashAttentionConfig.
    • Modified get_forward_metadata to handle DP ranks for page_indices, cu_q_lens, cu_kv_lens, and distribution. Updated sharding for these to P("data").
    • Updated in_specs and out_specs in __call__ to include self.attention_data_partition_axis.
  • python/sgl_jax/srt/layers/linear.py
    • Updated output_pspec and in_specs/out_specs to include "data" partition for sharding.
  • python/sgl_jax/srt/layers/logits_processor.py
    • Added logits_indices to LogitsMetadata.
    • Modified from_model_worker_batch to use P("data") sharding for device arrays.
    • Added _select_hidden_states method for sharded hidden state selection.
    • Modified __call__ to use logits_metadata.logits_indices for last_index and _select_hidden_states for pruning.
    • Updated jnp.dot to specify out_sharding=NamedSharding(self.mesh, P("data", "tensor")).
  • python/sgl_jax/srt/layers/moe.py
    • Modified quantize_weights for MoEBlock and FusedMoEBlock to use scale_placeholder_shape and ep_scale_placeholder_shape that are divisible by the mesh size for static checkpoints.
    • Updated output resharding in FusedMoEBlock.__call__ to P("data", None).
  • python/sgl_jax/srt/layers/routed_experts_capturer.py
    • Added mesh argument to RoutedExpertsCapturer.create and _RoutedExpertsCapturerReal.__init__.
    • Modified on_forward_end to handle real_bs_per_dp and per_dp_bs_size for DP.
    • Added reset method to _ExpertBalanceAnalyzer.
    • Modified add_decode_step and add_topk_ids in _ExpertBalanceAnalyzer and _ExpertDistributionRecorder to handle per-DP rank batch sizes and token counts.
  • python/sgl_jax/srt/layers/sampler.py
    • Updated logits resharding to P("data", None) and batch_next_token_ids, log_probs resharding to P("data") and P("data", "tensor") respectively.
  • python/sgl_jax/srt/managers/io_struct.py
    • Added dp_rank: int | None to TokenizedGenerateReqInput.
  • python/sgl_jax/srt/managers/schedule_policy.py
    • Changed jax.numpy to numpy.
    • Modified RadixKey to include dp_rank.
    • Updated _compute_prefix_matches to use dp_rank in RadixKey.
    • Modified SchedulerPolicy to handle per-DP rank token offsets (rem_total_token_offset, cur_rem_token_offset, rem_chunk_tokens_list), req_states, can_run_list, new_chunked_reqs as lists/dicts indexed by DP rank.
    • Updated rem_total_tokens and cur_rem_tokens properties to aggregate across DP ranks.
    • Modified _update_prefill_budget to accept dp_rank.
    • Updated add_chunked_req and add_one_req to use dp_rank for allocator and budget updates.
  • python/sgl_jax/srt/managers/scheduler.py
    • Added dp_size, dp_schedule_policy to Scheduler.__init__.
    • Modified create_device_mesh to use [self.dp_size, self.tp_size // self.dp_size] for ici_parallelism.
    • Adjusted max_running_requests to be divisible by dp_size.
    • Initialized running_batch and chunked_reqs to be DP-aware.
    • Added _select_round_robin_dp, _select_min_running_dp, select_dp_for_request for DP scheduling.
    • Modified event_loop_normal and event_loop_overlap to use select_dp_for_request.
    • Updated tmp_batch initialization in event_loop_overlap.
    • Added dp_rank to Req initialization in handle_generate_request.
    • Updated get_internal_state to reflect DP changes.
    • Modified _batch_size and flush_cache to be DP-aware.
    • Updated _get_token_info and _get_swa_token_info to sum across DP ranks.
    • Modified get_next_batch_to_run to handle chunked requests and filtering per DP rank.
    • Updated run_batch to use _extract_dp_output_ids and collect extend_input_len_per_req, extend_logprob_start_len_per_req from all DP ranks.
    • Added _extract_dp_output_ids function.
    • Modified get_idle_batch to be DP-aware.
    • Updated abort_request and pause_generation to handle DP-aware request lists.
  • python/sgl_jax/srt/managers/scheduler_metrics_mixin.py
    • Modified log_decode_stats to use batch.batch_size() and added per-DP running request count.
  • python/sgl_jax/srt/managers/scheduler_output_processor_mixin.py
    • Modified process_batch_result_prefill and process_batch_result_decode to iterate over DP ranks and handle dp_output_ids and req.dp_rank.
  • python/sgl_jax/srt/managers/tokenizer_manager.py
    • Changed _Communicator dp_size argument to 1 for all communicators, implying they operate globally or are not DP-sharded.
  • python/sgl_jax/srt/managers/tp_worker.py
    • Added dp_size to TPWorker.__init__.
    • Defined attention_tp_size.
    • Modified num_kv_heads calculation to use attention_tp_size.
    • Added dp_size to init_memory_pool.
    • Modified attn_backend_limit calculation to multiply by dp_size.
    • Adjusted max_running_requests to be divisible by dp_size.
    • Modified precompile_bs_paddings to ensure bs >= self.dp_size.
    • Updated precompile_cache_loc_paddings calculation.
    • Modified normalize_token_paddings to multiply default token paddings by dp_size and ensure items are divisible by dp_size.
    • Added dp_size and per_dp_bs_size to generate_model_worker_batch.
    • Modified resolve_future_token_ids and set_future_token_ids to accept mesh.
    • Updated sampling_metadata initialization in forward_batch_generation to use 0 for padding_size.
  • python/sgl_jax/srt/managers/tp_worker_overlap_thread.py
    • Added async_gather_fn for jax.jit with replicated sharding.
    • Modified resolve_future_token_ids and set_future_token_ids to accept mesh.
    • Updated next_token_ids to be gathered using async_gather_fn.
    • Updated sampling_metadata initialization in forward_batch_generation to use 0 for padding_size.
  • python/sgl_jax/srt/managers/utils.py
    • Added NamedSharding and P imports.
    • Modified resolve_future_token_ids and set_future_token_ids to accept mesh and use jax.sharding.reshard for DP compatibility.
  • python/sgl_jax/srt/mem_cache/allocator.py
    • Added dp_size to BaseAllocator.__init__ and size_per_rank.
    • Modified available_size, alloc, free, merge_and_sort_free, clear to accept dp_rank and operate on per-DP rank data structures.
    • Updated SWATokenToKVPoolAllocator to be DP-aware.
  • python/sgl_jax/srt/mem_cache/base_prefix_cache.py
    • Modified evictable_size, full_evictable_size, swa_evictable_size, protected_size to accept dp_rank.
  • python/sgl_jax/srt/mem_cache/chunk_cache.py
    • Modified cache_finished_req and cache_unfinished_req to use req.dp_rank for allocator operations.
  • python/sgl_jax/srt/mem_cache/common.py
    • Modified alloc_token_slots, alloc_paged_token_slots_extend, evict_from_tree_cache, available_and_evictable_str to accept dp_rank.
  • python/sgl_jax/srt/mem_cache/memory_pool.py
    • Added dp_size to MHATokenToKVPool.__init__.
    • Modified _create_buffers to assert size % dp_size == 0 and adjust fused_buffer_shape.
    • Modified _calculate_memory_usage and get_kv_size_bytes to use dp_size.
    • Added attention_data_partition_axis to _set_fused_kv_buffer and update_fused_kv_cache.
    • Modified update_fused_kv_cache_vectorized to use jax.shard_map with data_partition_axis.
  • python/sgl_jax/srt/mem_cache/radix_cache.py
    • Added dp_rank to RadixKey.
    • Modified _check_composite_key to validate dp_rank.
    • Updated get_child_key to incorporate dp_rank for namespace isolation.
    • Modified reset to use defaultdict(int) for evictable_size_ and protected_size_.
    • Updated match_prefix, insert, cache_finished_req, cache_unfinished_req, evict, inc_lock_ref, dec_lock_ref, evictable_size, protected_size to handle dp_rank.
  • python/sgl_jax/srt/mem_cache/swa_radix_cache.py
    • Modified _swa_eff_len to accept dp_rank.
    • Updated reset to use defaultdict(int) for size tracking.
    • Modified match_prefix, insert, cache_finished_req, cache_unfinished_req, evict, inc_lock_ref, dec_lock_ref, full_evictable_size, swa_evictable_size, full_protected_size, swa_protected_size to handle dp_rank.
  • python/sgl_jax/srt/model_executor/forward_batch_info.py
    • Updated sharding for device_array calls to P("data").
  • python/sgl_jax/srt/model_executor/model_runner.py
    • Added dp_size to ModelRunner.__init__.
    • Defined attention_tp_size.
    • Modified num_kv_heads calculation to use attention_tp_size.
    • Added dp_size to init_memory_pool.
    • Modified max_total_num_tokens calculation to multiply by dp_size.
    • Updated token_to_kv_pool and token_to_kv_pool_allocator initialization to pass dp_size.
    • Modified _set_kv_cache_after_forward to use attention_data_partition_axis.
    • Updated MockModelRunner to include dp_size and attention_tp_size.
  • python/sgl_jax/srt/models/bailing_moe.py
    • Added mesh to BailingMoEAttention.__init__.
    • Updated q, k, v reshape operations to include out_sharding with P("data", "tensor").
    • Resharded topk_ids to P(None).
  • python/sgl_jax/srt/models/grok.py
    • Updated q, k, v reshape operations to include out_sharding with P("data", "tensor").
    • Resharded topk_ids to P(None).
  • python/sgl_jax/srt/models/qwen2_moe.py
    • Added mesh to Qwen2MoeAttention.__init__.
    • Updated q, k, v reshape operations to include out_sharding with P("data", "tensor").
    • Resharded topk_ids to P(None).
  • python/sgl_jax/srt/models/qwen3.py
    • Added mesh to Qwen3Attention.__init__.
    • Updated q, k, v reshape operations to include out_sharding with P("data", "tensor", None).
  • python/sgl_jax/srt/models/qwen3_moe.py
    • Added mesh to Qwen3MoeAttention.__init__.
    • Updated q, k, v reshape operations to include out_sharding with P("data", "tensor").
    • Resharded topk_ids to P(None).
  • python/sgl_jax/srt/sampling/penaltylib/orchestrator.py
    • Modified reqs property to collect requests from all DP ranks.
  • python/sgl_jax/srt/sampling/sampling_batch_info.py
    • Modified from_model_worker_batch to use P("data") sharding.
    • Updated from_schedule_batch to accept ScheduleReqsInfo and the full ScheduleBatch for penalty orchestrator.
  • python/sgl_jax/srt/server_args.py
    • Added dp_schedule_policy to ServerArgs and add_cli_args.
  • python/sgl_jax/srt/speculative/eagle_worker.py
    • Added a comment about temperatures reshaping.
  • python/sgl_jax/test/mem_cache/test_kv_cache.py
    • Modified generate_test_data to use NamedSharding with P("data", "tensor", None) and P("data").
    • Modified expected_update_kv_cache to convert to numpy for reference and back to JAX array with sharding.
    • Removed redundant padding verification.
    • Added sharding to loc in test_all_padding_tokens.
    • Added sharding to k, v, k_cache, v_cache, loc in test_kv_cache_update_multiple_segments_with_padding.
    • Resharded updated_k_cache and k to P() for comparison.
  • python/sgl_jax/test/mem_cache/test_paged_allocator_multi_dp.py
    • New file, adds tests for PagedTokenToKVPoolAllocator with multiple DP ranks.
  • python/sgl_jax/test/mem_cache/test_radix_cache.py
    • Added dp_rank to RadixKey in MockRequest.
    • Added several new test cases for DP rank namespace isolation, including test_dp_rank_namespace_isolation, test_dp_rank_none_shared_namespace, test_dp_rank_none_vs_explicit_rank, test_combined_extra_key_and_dp_rank, test_dp_rank_preserves_on_slicing.
  • python/sgl_jax/test/mem_cache/test_swa_radix_cache.py
    • Added _create_swa_allocator helper.
    • Modified _swa_eff_len to accept dp_rank.
    • Updated reset to use defaultdict(int) for size tracking.
    • Modified match_prefix, insert, cache_finished_req, cache_unfinished_req, evict, inc_lock_ref, dec_lock_ref, full_evictable_size, swa_evictable_size, full_protected_size, swa_protected_size to handle dp_rank.
    • Added new test cases for DP rank namespace isolation, including test_dp_rank_namespace_isolation, test_dp_rank_none_shared_namespace, test_dp_rank_none_vs_explicit_rank, test_combined_extra_key_and_dp_rank, test_dp_rank_preserves_on_slicing, test_swa_free_group_batching_multi_rank.
  • python/sgl_jax/test/test_flashattention.py
    • Added logits_indices, dp_size, per_dp_bs_size to generate_model_worker_batch.
    • Removed forward_batch.attn_backend.forward_metadata.num_seqs from run_test.
  • python/sgl_jax/test/test_flashattention_dp.py
    • New file, adds tests for FlashAttention with Data Parallelism.
  • scripts/inspect_expert_dist.py
    • New file, a script to inspect expert distribution .npy files.
  • test/srt/run_suite.py
    • Added TestFile("python/sgl_jax/test/mem_cache/test_paged_allocator_multi_dp.py", 1).
  • test/srt/test_schedule_batch_dp.py
    • New file, adds unit tests for ScheduleBatch Data Parallelism merging logic.
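
The changelog entries above repeatedly mention a ("data", "tensor") mesh and P("data") sharding (forward_batch_info, flashattention_backend, linear, sampler). Below is a minimal JAX sketch of that layout, assuming 8 local devices with dp_size=2 and tp_size=8, so the device grid is [2, 4] in line with the [dp_size, tp_size // dp_size] split described for create_device_mesh; the array shapes are toy values.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Assumes 8 local devices with dp_size=2, tp_size=8, so the device grid is
# [dp_size, tp_size // dp_size] = [2, 4]; axis 0 is "data", axis 1 is "tensor".
dp_size, tp_size = 2, 8
devices = np.array(jax.devices()[:tp_size]).reshape(dp_size, tp_size // dp_size)
mesh = Mesh(devices, axis_names=("data", "tensor"))

# Batch-dimension arrays (token ids, cache locations, cu_q_lens, ...) are
# sharded along "data", so each DP rank sees only its slice of the batch.
tokens = jnp.arange(16)  # toy flattened batch
tokens = jax.device_put(tokens, NamedSharding(mesh, P("data")))

# Activations of shape [tokens, hidden] can be sharded ("data", "tensor"),
# as in the attention QKV reshapes listed in the model entries above.
hidden = jnp.zeros((16, 32))
hidden = jax.device_put(hidden, NamedSharding(mesh, P("data", "tensor")))
```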
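
Likewise, a minimal sketch of the per-DP-rank bookkeeping that the allocator.py entry describes: the pool is split into size_per_rank = size // dp_size slices, and alloc/free/available_size take a dp_rank argument. This toy token-slot allocator is only illustrative and ignores paging.

```python
import numpy as np

class ToyDPTokenAllocator:
    """Toy per-DP-rank KV token-slot allocator (illustrative only)."""

    def __init__(self, size: int, dp_size: int):
        assert size % dp_size == 0, "pool size must divide evenly across DP ranks"
        self.dp_size = dp_size
        self.size_per_rank = size // dp_size
        # Each rank owns a contiguous slot range [rank * size_per_rank, ...).
        self.free_slots = [
            list(range(rank * self.size_per_rank, (rank + 1) * self.size_per_rank))
            for rank in range(dp_size)
        ]

    def available_size(self, dp_rank: int) -> int:
        return len(self.free_slots[dp_rank])

    def alloc(self, need: int, dp_rank: int) -> np.ndarray | None:
        if need > self.available_size(dp_rank):
            return None  # caller must evict from this rank's tree cache first
        out, self.free_slots[dp_rank] = (
            self.free_slots[dp_rank][:need],
            self.free_slots[dp_rank][need:],
        )
        return np.array(out, dtype=np.int32)

    def free(self, slots: np.ndarray, dp_rank: int) -> None:
        self.free_slots[dp_rank].extend(int(s) for s in slots)


alloc = ToyDPTokenAllocator(size=16, dp_size=2)
slots = alloc.alloc(4, dp_rank=1)  # slots come from rank 1's half of the pool
alloc.free(slots, dp_rank=1)
```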
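
And a small sketch of the namespace isolation described in the radix_cache.py and swa_radix_cache.py entries: folding dp_rank into the child-lookup key keeps prefixes cached on one DP rank invisible to the others, while dp_rank=None remains a shared namespace. The RadixKey shown here is a simplified stand-in for the PR's class.

```python
from dataclasses import dataclass

# Simplified stand-in for the PR's RadixKey: dp_rank (and any extra_key) is
# folded into the child-lookup key, so two keys only share a radix subtree
# when they agree on both fields.

@dataclass(frozen=True)
class RadixKey:
    token_ids: tuple[int, ...]
    extra_key: str | None = None
    dp_rank: int | None = None

    def child_key(self) -> tuple:
        # First token plus the namespace fields.
        return (self.token_ids[0], self.extra_key, self.dp_rank)


a = RadixKey((1, 2, 3), dp_rank=0)
b = RadixKey((1, 2, 3), dp_rank=1)
c = RadixKey((1, 2, 3), dp_rank=None)

assert a.child_key() != b.child_key()  # rank 0 and rank 1 never share prefixes
assert a.child_key() != c.child_key()  # an explicit rank is distinct from the shared namespace
```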
Activity
  • JamesBrianD opened this pull request.
  • All changes were introduced in the initial commit.

JamesBrianD force-pushed the feat/data-parallelism branch from ce35b74 to 24ffa11 on February 17, 2026 09:12
JamesBrianD and others added 6 commits March 2, 2026 17:21
feat: allocator with dp (#591)

refactor allocator (#608)

change schedule batch (#613)

change logitprocessor and fix scheduler bugs

feat: rpa dp (#638)

* feat: rpa dp

* add debug logs

fix overlap dp (#663)

fix: merge cache loc (#665)

fix: overlap schedule req over max batch size of per dp rank (#669)

fix: precompile dp (#671)

fix: communicator fan-out always be 1 if dp_size > 1 (#672)

fix: memory leak (#673)

feat: multi host support dp (#678)

* mhdp

fix: sharding not match when enable dp (#680)

Co-Authored-By: Prayer <prayer@primatrix.ai>
Co-authored-by: leos <leos@primatrix.ai>

bench mark fused moe and change decode block size

min_running_queue schedule for dp

change bench fused scripts
* dp schedule

* chore: ignore .worktrees/ directory

* Add DP-safe mixin scheduling and FA metadata checks
JamesBrianD force-pushed the feat/data-parallelism branch 2 times, most recently from 78d04e5 to e3b40ff on March 3, 2026 06:39
JamesBrianD force-pushed the feat/data-parallelism branch 2 times, most recently from bff50e8 to 34c9671 on March 3, 2026 07:28
JamesBrianD force-pushed the feat/data-parallelism branch 11 times, most recently from b6d7606 to 19cec01 on March 3, 2026 09:38
JamesBrianD force-pushed the feat/data-parallelism branch from 19cec01 to 464d10f on March 3, 2026 09:43
JamesBrianD and others added 2 commits March 4, 2026 14:55
Refactor get_top_logprobs and get_token_ids_logprobs to return flat
tensors instead of per-request nested structures. In DP mode, padding
slots caused shape mismatches when jnp.array() tried to concatenate
arrays with different dimensions. Now logits_processor returns flat
tensors and consumers slice by logprob_pt offset on CPU side.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Apply the same flat tensor pattern to sampler.py's get_top_logprobs
and get_token_ids_logprobs (decode stage). Padding slots with k=0
caused ragged shape mismatches. Now return [batch_size, max_k] flat
tensors and truncate to per-request k on CPU side.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
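
A rough sketch of the flat-tensor pattern these two commits describe: compute top-k once over the whole padded batch as a [batch_size, max_k] array on device, then truncate to each request's own k on the CPU side. The helper names and the top_logprobs_nums argument are illustrative stand-ins, not the PR's exact code.

```python
import jax
import jax.numpy as jnp
import numpy as np

def get_top_logprobs_flat(logprobs: jax.Array, max_k: int):
    """Return flat [batch_size, max_k] top logprobs and ids, padding slots included.

    Keeping the on-device result rectangular avoids the ragged-shape errors
    that per-request nested structures caused under DP padding.
    """
    values, indices = jax.lax.top_k(logprobs, max_k)
    return values, indices


def slice_per_request(values, indices, top_logprobs_nums):
    """CPU side: truncate each row to that request's own k (k=0 marks a padding slot)."""
    values = np.asarray(values)
    indices = np.asarray(indices)
    return [
        (values[i, :k].tolist(), indices[i, :k].tolist())
        for i, k in enumerate(top_logprobs_nums)
    ]


logprobs = jnp.log(jax.nn.softmax(jnp.ones((4, 8))))  # toy [batch, vocab]
vals, ids = get_top_logprobs_flat(logprobs, max_k=3)
per_req = slice_per_request(vals, ids, top_logprobs_nums=[3, 1, 0, 2])
```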
JamesBrianD force-pushed the feat/data-parallelism branch from d5058d9 to 2ef0833 on March 4, 2026 08:29
eec2620 introduced _gather_next_token_ids which gathers sharded JAX
arrays to replicated sharding, but did not convert the result to CPU.
This left next_token_ids as a JAX on-device array, causing downstream
unhashable type errors when token ids were used in set lookups
(check_finished). Add device_get + tolist at the source.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
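
A minimal sketch of the fix this commit describes: after gathering the sharded next-token ids to a replicated array, copy them to host with jax.device_get and convert to a plain Python list so downstream set lookups (as in check_finished) operate on hashable ints. The gather_and_materialize helper is hypothetical.

```python
import jax
import jax.numpy as jnp

def gather_and_materialize(next_token_ids_sharded: jax.Array) -> list[int]:
    # The PR's _gather_next_token_ids resolves the sharded array to a
    # replicated one; the missing piece was the device_get + tolist below.
    replicated = next_token_ids_sharded   # stand-in for the jitted gather
    host_ids = jax.device_get(replicated) # copy to host as a numpy array
    return host_ids.tolist()              # plain Python ints -> hashable


token_ids = gather_and_materialize(jnp.array([42, 7, 7, 0]))
finished = {0, 2}                         # toy stop-token set
done = [tid in finished for tid in token_ids]  # safe set lookups on ints
```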
JamesBrianD force-pushed the feat/data-parallelism branch from e56c621 to 1e7e066 on March 4, 2026 08:40