
update gpt-oss-0.12 changes#293

Merged
rebel-ykchoi merged 28 commits into dev-0.12 from dev-0.12-merge-gpt-oss
Jan 27, 2026
Merged

update gpt-oss-0.12 changes#293
rebel-ykchoi merged 28 commits intodev-0.12from
dev-0.12-merge-gpt-oss

Conversation

Contributor

@rebel-wonsubkim commented Jan 22, 2026

  1. Support the gpt-oss model: the mxfp4 quantization method, some custom kernels, etc.
  2. Fix v1 vLLM online serving for model parallelism (data parallel & pipeline parallel)
  3. Modify batch bucketing for consistency (by default, the decode batch bucket is disabled)
  4. Change the KV cache max-num-blocks estimation for hybrid attention models
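Item 3 can be sketched as follows. This is an illustrative sketch only: the function name, bucket list, and the exact disable-by-default behavior are assumptions based on the PR description, not the actual vllm-rbln implementation.

```python
def pick_decode_bucket(num_reqs: int, buckets: list[int], enabled: bool) -> int:
    """Return the padded decode batch size for the current step.

    With bucketing disabled (the new default per this PR), every decode step
    runs at the single maximum bucket, keeping behavior consistent across
    steps; with bucketing enabled, the smallest fitting bucket is chosen.
    """
    max_bucket = max(buckets)
    if num_reqs > max_bucket:
        raise ValueError(f"{num_reqs} requests exceed max bucket {max_bucket}")
    if not enabled:
        return max_bucket
    # Smallest bucket that can hold the currently scheduled requests.
    return min(b for b in buckets if b >= num_reqs)
```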

rebel-ykchoi and others added 21 commits January 6, 2026 16:46
- remove DP padding support in v1 worker
- add validation for DP implementation constraints in v1 worker
- apply token mask to custom MOE kernel router logits
- update default environment variables:
  - VLLM_RBLN_DP_IMPL: "dummy_prefill" -> "padded_decode"
  - VLLM_RBLN_USE_MOE_TOKENS_MASK: False -> True
- fix DP metadata handling in forward context
- add is_prefills field to RBLNFlashAttentionMetadata
+ add expert_map to handle vllm model parallel

Signed-off-by: wonsub kim <subang0@rebellions.ai>
Signed-off-by: wonsub kim <subang0@rebellions.ai>
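The token mask applied to the custom MOE kernel's router logits (enabled by the VLLM_RBLN_USE_MOE_TOKENS_MASK default change above) might look like the following plain-Python sketch; the function name and the choice of zeroing masked rows are assumptions for illustration, not the kernel's actual behavior.

```python
def mask_router_logits(router_logits, tokens_mask):
    """Apply a token validity mask to MoE router logits.

    router_logits: per-token expert scores, shape [num_tokens][num_experts].
    tokens_mask:   per-token flags; False marks a padding token added for DP.
    Masked tokens get zeroed logits so the router spends no meaningful expert
    capacity on padding; their outputs are assumed to be discarded later.
    """
    return [
        row if keep else [0.0] * len(row)
        for row, keep in zip(router_logits, tokens_mask)
    ]
```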
+ gpt_oss MLPBlock tp missing

Signed-off-by: wonsub kim <subang0@rebellions.ai>
Signed-off-by: wonsub kim <subang0@rebellions.ai>
+ change available dram size for REBEL architecture
  - ATOM - 16GB
  - REBEL - 140GB

Signed-off-by: wonsub kim <subang0@rebellions.ai>
refactor: improve intermediate tensors management and dummy run logic

- add prepare_dummy_run and dummy_run methods for v1 dp online serving
- remove unused sync_and_slice_intermediate_tensors method
- separate intermediate_tensors into prefill_intermediate_tensors
  and decode_intermediate_tensors
- improve RBLNWorker device environment initialization
  - add support for Ray backend
  - add local_world_size calculation
  - improve device environment variable setup logic
    - make RBLN_DEVICES not coupled with VLLM_RBLN_TP_SIZE
- change LOCAL_RANK to rank in init_worker_distributed_environment
+ add necessary parameters
  --max-model-len, --block-size, --num-hidden-layers, --decode-batch

Signed-off-by: wonsub kim <subang0@rebellions.ai>
+ consider sliding window attention
  - DO NOT count sliding window attention block
  since it shares kv cache block with full attention

+ calculate max num blocks based on assumption that
  entire layers have full attention
  - when calculating available memory, count full attention layer
  not sliding window attention

Signed-off-by: wonsub kim <subang0@rebellions.ai>
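The accounting described above might look like the sketch below. Names and the exact formula are assumptions; the point it illustrates is that sliding-window layers are deliberately excluded from the per-block cost because they share KV cache blocks with the full-attention layers.

```python
def estimate_max_num_blocks(
    available_bytes: int,
    block_bytes_per_layer: int,
    num_full_attn_layers: int,
    num_swa_layers: int,
) -> int:
    """Estimate KV cache block capacity for a hybrid-attention model.

    num_swa_layers is accepted but deliberately unused: per the commit
    message, sliding window attention layers share KV cache blocks with
    full attention, so only full-attention layers are counted.
    """
    bytes_per_block = block_bytes_per_layer * num_full_attn_layers
    return available_bytes // bytes_per_block
```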
- remove unused attn_metadata parameter from RBLNDPMetadata.make()
- remove is_prefills field and related logic from DP metadata
- fix get_tokens_mask() for non-DP case
- refactor dummy run execution with DummyRunState and prepare_dummy_run
- update batch size calculation to account for pipeline parallel size
- add batch_pad parameter to attention metadata builder for PP support
+ consider following issues when calculating max_num_blocks
  - consider gpt-oss-20b scale merge for dequantized version
  - consider SWA(sliding window attention) block share with full attention
  - consider word_embedding param when calculating kernel size
  since it is not included in the device

Signed-off-by: wonsub kim <subang0@rebellions.ai>
+ batch_attention kernel is optimized version of flash attention kernel for large batch
  - batch attention kernel takes original sequence index
  - in compiler lowering, original sequence index is lowered into following items
    - seq_idx - cache target block index
    - seq_offset - cache target block offset
    - dyn_batch - valid batch count for each partition

Signed-off-by: wonsub kim <subang0@rebellions.ai>
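The lowering described above can be illustrated with a small sketch. The exact mapping (block index = position // block_size, offset = position % block_size, per-partition valid counts) and the use of -1 for padded slots are assumptions for illustration; the real lowering happens inside the compiler.

```python
def lower_seq_index(positions, block_size, partition_size):
    """Illustrative lowering of the batch_attention kernel inputs.

    positions: per-request decode position (the "original sequence index");
    -1 marks a padded slot. Returns the three lowered items from the commit
    message: seq_idx (cache target block index), seq_offset (offset inside
    that block), and dyn_batch (valid batch count per partition).
    """
    seq_idx = [p // block_size if p >= 0 else 0 for p in positions]
    seq_offset = [p % block_size if p >= 0 else 0 for p in positions]
    dyn_batch = [
        sum(p >= 0 for p in positions[i:i + partition_size])
        for i in range(0, len(positions), partition_size)
    ]
    return seq_idx, seq_offset, dyn_batch
```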
+ replace max_batch_size with decode_batch_bucket size
+ by default, disable batch bucketing
  - change limit of bucket

Signed-off-by: wonsub kim <subang0@rebellions.ai>
@rebel-jiwoopark added the torch.compile (torch.compile based implementation) label Jan 22, 2026
Comment thread vllm_rbln/v1/worker/rbln_model_runner.py Outdated
self.cache_config.block_size
] or kernel_block_sizes != [self.cache_config.block_size]:
assert self.cache_config.cpu_offload_gb == 0, (
"Cannot re-initialize the input batch when CPU weight "
Collaborator

Do you mean CPU KV cache offloading? Or does vLLM natively support weight offloading?

Comment thread vllm_rbln/v1/worker/rbln_worker.py Outdated
+ num_runtimes fix up
  - ATOM num_runtimes = 2 * VLLM_RBLN_TP_SIZE
  - REBEL num_runtimes = 2 * 4 (quad chiplet)

Signed-off-by: wonsub kim <subang0@rebellions.ai>
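The num_runtimes rule from this commit is simple enough to state as code. The factor of 2 presumably covering separate prefill and decode runtimes is an assumption; the per-arch values are taken directly from the commit message.

```python
def num_runtimes(arch: str, tp_size: int) -> int:
    """Runtime count per device architecture, per the commit message:
    ATOM scales with VLLM_RBLN_TP_SIZE, REBEL is fixed (quad chiplet)."""
    if arch == "ATOM":
        return 2 * tp_size
    if arch == "REBEL":
        return 2 * 4  # quad chiplet
    raise ValueError(f"unknown arch: {arch}")
```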
+ seq_idx SHOULD be padded if num_reqs < decode_batch size

Signed-off-by: wonsub kim <subang0@rebellions.ai>
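The padding requirement above can be sketched as follows; padding with 0 is a placeholder choice for illustration (the real pad value may differ), and the invariant from the commit is that the kernel always sees exactly decode_batch entries.

```python
def pad_seq_idx(seq_idx: list[int], decode_batch: int) -> list[int]:
    """Pad seq_idx up to the compiled decode batch size.

    When num_reqs < decode_batch, the batch attention kernel still expects
    decode_batch entries, so the tail is filled with a pad value.
    """
    assert len(seq_idx) <= decode_batch
    return seq_idx + [0] * (decode_batch - len(seq_idx))
```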
# batched attention - seq_lens[B, 1] == seq_idx, original sequence index
# otherwise - seq_lens[B, P] == dyn_size_for_partitions, dynamic size for each partition
if q_len == 1:
if self.is_batch_attention_opt:
Contributor

So is the old decode kernel replaced by the new one? If so, please update the custom op's doc and reference impl, and remove is_batch_attention_opt, etc.

Contributor Author

If the batch attention opt is used, the original sequence index [B, 1] is provided to paged flash attention; otherwise, the modified seq_lens_tensor [B, P] is provided. The custom kernel gets the same interface either way, but the compiler can distinguish the two.

Comment thread vllm_rbln/forward_context.py Outdated

scheduler_config = vllm_config.scheduler_config
max_pad = scheduler_config.max_num_batched_tokens
batchsize = num_tokens
Collaborator

Minor comment: is this a typo? The naming rule should be snake_case.

Contributor Author

Elsewhere in the code, batchsize itself is used as a single word. Should we change all those occurrences?

rebel-wonsubkim and others added 3 commits January 27, 2026 15:02
+ DO NOT count model warm up (prefill & decode batch bucket)

Signed-off-by: wonsub kim <subang0@rebellions.ai>
Signed-off-by: wonsub kim <subang0@rebellions.ai>
- implement specialized decode path that uses optimized padding when
  all requests are in decode stage
- add VLLM_RBLN_SPECIALIZE_MOE_DECODE environment variable to enable
  specialized handling for decode-only batches in MoE models
- refactor RBLNDPMetadata.max_pads_across_dp from int to torch.Tensor
  to differentiate specialized decode and normal decode
- add num_padded_tokens parameter to RBLNDPMetadata.make() and
  _set_forward_context()
- add specialized decode path to batch bucketing
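The gating for the specialized decode path above might look like this sketch; the helper name and the env-var value convention ("1" to enable) are assumptions, while the conditions (opt-in via VLLM_RBLN_SPECIALIZE_MOE_DECODE and a decode-only batch) come from the commit message.

```python
import os

def use_specialized_moe_decode(num_prefill_reqs: int, num_decode_reqs: int) -> bool:
    """Decide whether to take the specialized decode-only MoE path.

    Per the commit, the optimized padding path applies only when every
    scheduled request is in the decode stage and the env var opts in.
    """
    enabled = os.environ.get("VLLM_RBLN_SPECIALIZE_MOE_DECODE", "0") == "1"
    return enabled and num_prefill_reqs == 0 and num_decode_reqs > 0
```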
@rebel-ykchoi merged commit bbf95a2 into dev-0.12 Jan 27, 2026
1 check passed
@rebel-ykchoi deleted the dev-0.12-merge-gpt-oss branch January 27, 2026 12:07
rebel-jaehwang added a commit that referenced this pull request Jan 30, 2026
* fix: bump up v0 moe dp implementation to v1

- remove DP padding support in v1 worker
- add validation for DP implementation constraints in v1 worker
- apply token mask to custom MOE kernel router logits
- update default environment variables:
  - VLLM_RBLN_DP_IMPL: "dummy_prefill" -> "padded_decode"
  - VLLM_RBLN_USE_MOE_TOKENS_MASK: False -> True
- fix DP metadata handling in forward context
- add is_prefills field to RBLNFlashAttentionMetadata

* fix: mxfp4 kernel for model parallel

+ add expert_map to handle vllm model parallel

Signed-off-by: wonsub kim <subang0@rebellions.ai>

* modify expert_map position

Signed-off-by: wonsub kim <subang0@rebellions.ai>

* fix gpt_oss tensor parallel all_reduce

+ gpt_oss MLPBlock tp missing

Signed-off-by: wonsub kim <subang0@rebellions.ai>

* disable shared fused moe overlap for RBLN

Signed-off-by: wonsub kim <subang0@rebellions.ai>

* reference torch impl for gpt-oss ops

* apply VLLM_RBLN_USE_MOE_TOKENS_MASK to mxfp4 MOE

* adjust available dram size based on target arch

+ change available dram size for REBEL architecture
  - ATOM - 16GB
  - REBEL - 140GB

Signed-off-by: wonsub kim <subang0@rebellions.ai>

* fix v1 dp online serving
refactor: improve intermediate tensors management and dummy run logic

- add prepare_dummy_run and dummy_run methods for v1 dp online serving
- remove unused sync_and_slice_intermediate_tensors method
- separate intermediate_tensors into prefill_intermediate_tensors
  and decode_intermediate_tensors
- improve RBLNWorker device environment initialization
  - add support for Ray backend
  - add local_world_size calculation
  - improve device environment variable setup logic
    - make RBLN_DEVICES not coupled with VLLM_RBLN_TP_SIZE
- change LOCAL_RANK to rank in init_worker_distributed_environment

* add additional params for data_parallel.py script

+ add necessary parameters
  --max-model-len, --block-size, --num-hidden-layers, --decode-batch

Signed-off-by: wonsub kim <subang0@rebellions.ai>

* fix calculation of maximum num blocks

+ consider sliding window attention
  - DO NOT count sliding window attention block
  since it shares kv cache block with full attention

+ calculate max num blocks based on assumption that
  entire layers have full attention
  - when calculating available memory, count full attention layer
  not sliding window attention

Signed-off-by: wonsub kim <subang0@rebellions.ai>

* fix: port v0.12 scheduler code

* fix: limit decode bs to (max num seqs // pp size)

* tmp: pad decode inputs to max_num_seqs // pp_size

* add: simple offline benchmark script

* fix DPMetadata for tokens mask

- remove unused attn_metadata parameter from RBLNDPMetadata.make()
- remove is_prefills field and related logic from DP metadata
- fix get_tokens_mask() for non-DP case

* fix dp with pp dummy run logic

- refactor dummy run execution with DummyRunState and prepare_dummy_run
- update batch size calculation to account for pipeline parallel size
- add batch_pad parameter to attention metadata builder for PP support

* fix max_num_blocks calculation

+ consider following issues when calculating max_num_blocks
  - consider gpt-oss-20b scale merge for dequantized version
  - consider SWA(sliding window attention) block share with full attention
  - consider word_embedding param when calculating kernel size
  since it is not included in the device

Signed-off-by: wonsub kim <subang0@rebellions.ai>

* add optimized batch attention kernel

+ batch_attention kernel is optimized version of flash attention kernel for large batch
  - batch attention kernel takes original sequence index
  - in compiler lowering, original sequence index is lowered into following items
    - seq_idx - cache target block index
    - seq_offset - cache target block offset
    - dyn_batch - valid batch count for each partition

Signed-off-by: wonsub kim <subang0@rebellions.ai>

* resolve conflict between bucketing and dp

+ replace max_batch_size with decode_batch_bucket size
+ by default, disable batch bucketing
  - change limit of bucket

Signed-off-by: wonsub kim <subang0@rebellions.ai>

* fix num_runtimes

+ num_runtimes fix up
  - ATOM num_runtimes = 2 * VLLM_RBLN_TP_SIZE
  - REBEL num_runtimes = 2 * 4 (quad chiplet)

Signed-off-by: wonsub kim <subang0@rebellions.ai>

* pad seq_idx for batch attention

+ seq_idx SHOULD be padded if num_reqs < decode_batch size

Signed-off-by: wonsub kim <subang0@rebellions.ai>

* fixed batched decode func call

* remove unused code

* fix up RBLN_METRICS

+ DO NOT count model warm up (prefill & decode batch bucket)

Signed-off-by: wonsub kim <subang0@rebellions.ai>

* fix typo

Signed-off-by: wonsub kim <subang0@rebellions.ai>

* add specialized MoE decode optimization for DP

- implement specialized decode path that uses optimized padding when
  all requests are in decode stage
- add VLLM_RBLN_SPECIALIZE_MOE_DECODE environment variable to enable
  specialized handling for decode-only batches in MoE models
- refactor RBLNDPMetadata.max_pads_across_dp from int to torch.Tensor
  to differentiate specialized decode and normal decode
- add num_padded_tokens parameter to RBLNDPMetadata.make() and
  _set_forward_context()
- add specialized decode path to batch bucketing

---------

Signed-off-by: wonsub kim <subang0@rebellions.ai>
Co-authored-by: Youngkyu Choi <youngkyu.choi@rebellions.ai>
Co-authored-by: Jaehwang Jung <jaehwang.jung@rebellions.ai>
Co-authored-by: Huijong JEONG <huijong.jeong@squeezebits.com>
Co-authored-by: JaehunRyu <jaehun.ryu@rebellions.ai>
rebel-jaehwang added a commit that referenced this pull request Jan 30, 2026
rebel-jaehwang added a commit that referenced this pull request Jan 30, 2026
rebel-jiwoopark pushed a commit that referenced this pull request Feb 4, 2026
Labels

torch.compile torch.compile based implementation

7 participants