-
Notifications
You must be signed in to change notification settings - Fork 12
other: improve optimum-compile #591
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. Weβll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
rebel-seinpark
wants to merge
76
commits into
dev
Choose a base branch
from
add-pytest
base: dev
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from 75 commits
Commits
Show all changes
76 commits
Select commit
Hold shift + click to select a range
370ddad
allow to use rbln_config in addtional_config
rebel-eunji 4598c28
remove print
rebel-eunji 9212557
fix block_size in pooling models
rebel-eunji e50ed36
update qwen3_vl and whisper
rebel-eunji 7699953
idefics3
rebel-eunji 6db53ce
fix multimodal
rebel-eunji edb4162
apply rbln
rebel-eunji b177fc2
add runtime exception
rebel-eunji 7cbec8b
fix rbln_config
rebel-eunji 580d7ca
sync in compile
rebel-eunji 3bb4304
fix unpack
rebel-eunji 9ffc5cf
align the setting param after compilation
rebel-eunji b296182
fix bug
rebel-eunji e505b2a
refactor
rebel-eunji d0a6b9b
fix
rebel-eunji 0253787
refactor rbln_params
rebel-eunji c809f88
remove unused class
rebel-eunji 0f6dea6
remove duplicate one
rebel-eunji 73fc7a5
decouple the logic
rebel-eunji 9f3a108
prepare_vllm_for_compile logic
rebel-eunji 1233806
refactor sync_to_vllm, sync_from_vllm
rebel-eunji b2cce5c
fix utils -> common and pre-commit
rebel-eunji 29fef76
decouple registry and compile logic
rebel-eunji 29cdc14
refactor
rebel-eunji e9c9948
decouple compile and mapping value logic
rebel-eunji f0b3f4a
move multimodal and refactor block setting
rebel-eunji c1cc8cb
move logics - compile
rebel-eunji cef3a2b
deprecate configuration.py
rebel-eunji f8fe930
refactor
rebel-eunji 33a0cfc
remove helper py
rebel-eunji 544a7ee
fix sync from vllm/optimum
rebel-eunji 2984334
bucketing
rebel-eunji bdb1299
fix sync from vllm/optimum
rebel-eunji d5624a9
move to optimum
rebel-eunji 5dea678
refactor
rebel-eunji 9e33950
add comment
rebel-eunji 04732fd
decouple
rebel-eunji 7b82f19
fix pytest
rebel-eunji 852b6cb
downgrade optimum-rbln version (tmp)
rebel-eunji 6fa6351
remove validate
rebel-eunji 523bd3f
fix log
rebel-eunji 66e4fa0
Apply suggestion from @rebel-eunji
rebel-eunji 97a336e
Apply suggestion from @rebel-eunji
rebel-eunji 7e2c70a
add missing part to sync max_num_batched_tokens
rebel-eunji c5ddca8
refactor init_model
rebel-eunji b479655
fix model_path bug
rebel-eunji dc57c01
move compilation
rebel-eunji c9e6623
make dispatch.py slimmer
rebel-eunji 33be2f2
clarify rbln_config of optimum and addtional_config[rbln_config]
rebel-eunji 142bede
fix make clear of dispatch
rebel-eunji 039d61d
log the parameters for compilation
rebel-eunji 667f3df
update default block_size
rebel-eunji eb87b8a
deep merge
rebel-eunji 3086d90
use optimum flexible prefill_chunk_size
rebel-eunji c16c2f0
Merge branch 'dev' into update/compile-model-optimum
rebel-eunji dcba306
pre-commit
rebel-eunji df3cadb
fix block size
rebel-eunji 09860da
remove print
rebel-eunji 6e30f5f
Update vllm_rbln/utils/optimum/converter/from_vllm.py
rebel-eunji cc238ab
change to force setting block_size
rebel-eunji d4b93c3
boolean user_specified_block_size
rebel-eunji cd9e8a4
fix conftest.py to skip sync_vllm_and_optimum
rebel-eunji e18f2c8
revert
rebel-eunji 1a56f5f
fix conftest.py
rebel-eunji cef437e
refactor param parsing and add pytest
rebel-eunji 734d57c
apply pre-commit fix
rebel-eunji 857c6e0
remove assertion of block_size
rebel-eunji 4311357
add pytest
rebel-eunji 54833c1
fix command
rebel-eunji 5f1f019
fix test
rebel-eunji d5a6860
available block_size to pooling, enc-dec
rebel-eunji 729b9c7
add pytest
rebel-seinpark d44ea5c
fix
rebel-seinpark 93a4a5f
rm
rebel-seinpark 9115ba9
update example models
rebel-seinpark 4bf7f01
Merge branch 'dev' into add-pytest
rebel-seinpark File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,13 @@ | ||
| # Copyright 2025 Rebellions Inc. All rights reserved. | ||
|
|
||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at: | ||
|
|
||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
|
|
||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,13 @@ | ||
| # Copyright 2025 Rebellions Inc. All rights reserved. | ||
|
|
||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at: | ||
|
|
||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
|
|
||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,13 @@ | ||
| # Copyright 2025 Rebellions Inc. All rights reserved. | ||
|
|
||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at: | ||
|
|
||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
|
|
||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. |
283 changes: 283 additions & 0 deletions
283
tests/model_executor/models/optimum/test_compilation.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,283 @@ | ||
| # Copyright 2025 Rebellions Inc. All rights reserved. | ||
|
|
||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at: | ||
|
|
||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
|
|
||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| from types import SimpleNamespace | ||
|
|
||
| import pytest | ||
| from optimum.rbln import ( | ||
| RBLNAutoModelForCausalLM, | ||
| RBLNAutoModelForImageTextToText, | ||
| RBLNAutoModelForSpeechSeq2Seq, | ||
| RBLNAutoModelForVision2Seq, | ||
| RBLNBertModel, | ||
| RBLNQwen3Model, | ||
| ) | ||
|
|
||
| from vllm_rbln.model_executor.models.optimum import compilation | ||
| from vllm_rbln.model_executor.models.optimum.compilation import ( | ||
| RBLNCompileSpec, | ||
| _deep_merge, | ||
| ) | ||
|
|
||
|
|
||
| def _hf(arch: str, **extra) -> SimpleNamespace: | ||
| return SimpleNamespace(architectures=[arch], **extra) | ||
|
|
||
|
|
||
| class TestDeepMerge: | ||
| def test_top_level_overwrite(self): | ||
| base = {"a": 1, "b": 2} | ||
| _deep_merge(base, {"b": 99}) | ||
| assert base == {"a": 1, "b": 99} | ||
|
|
||
| def test_nested_merge_preserves_untouched_subkeys(self): | ||
| base = {"language_model": {"batch_size": 4, "max_seq_len": 1024}} | ||
| _deep_merge(base, {"language_model": {"max_seq_len": 2048}}) | ||
| assert base == { | ||
| "language_model": {"batch_size": 4, "max_seq_len": 2048} | ||
| } | ||
|
|
||
| def test_nondict_overrides_dict(self): | ||
| base = {"x": {"nested": True}} | ||
| _deep_merge(base, {"x": "scalar"}) | ||
| assert base == {"x": "scalar"} | ||
|
|
||
| def test_new_keys_added(self): | ||
| base = {"a": 1} | ||
| _deep_merge(base, {"b": 2}) | ||
| assert base == {"a": 1, "b": 2} | ||
|
|
||
| def test_empty_overrides_is_noop(self): | ||
| base = {"a": {"b": 1}} | ||
| _deep_merge(base, {}) | ||
| assert base == {"a": {"b": 1}} | ||
|
|
||
|
|
||
| class TestForArchitectureDispatch: | ||
| def test_unknown_architecture_raises(self): | ||
| with pytest.raises(NotImplementedError): | ||
| RBLNCompileSpec.for_architecture( | ||
| _hf("DefinitelyNotARealArch"), | ||
| batch_size=1, | ||
| block_size=128, | ||
| max_model_len=128, | ||
| tp_size=1, | ||
| ) | ||
|
|
||
| def test_generation_dispatches_to_decoder(self): | ||
| spec = RBLNCompileSpec.for_architecture( | ||
| _hf("LlamaForCausalLM"), | ||
| batch_size=4, | ||
| block_size=128, | ||
| max_model_len=1024, | ||
| tp_size=1, | ||
| ) | ||
| assert spec.model_cls is RBLNAutoModelForCausalLM | ||
|
|
||
| def test_pooling_dispatches_to_pooling(self): | ||
| spec = RBLNCompileSpec.for_architecture( | ||
| _hf("BertModel"), | ||
| batch_size=4, | ||
| block_size=128, | ||
| max_model_len=128, | ||
| tp_size=1, | ||
| ) | ||
| assert spec.model_cls is RBLNBertModel | ||
|
|
||
| def test_multimodal_dispatches_to_multimodal(self): | ||
| spec = RBLNCompileSpec.for_architecture( | ||
| _hf("LlavaForConditionalGeneration"), | ||
| batch_size=2, | ||
| block_size=128, | ||
| max_model_len=2048, | ||
| tp_size=1, | ||
| ) | ||
| # LlavaForConditionalGeneration -> RBLNAutoModelForVision2Seq. | ||
| assert spec.model_cls is RBLNAutoModelForVision2Seq | ||
|
|
||
| def test_gemma3_multimodal_uses_image_text_to_text(self): | ||
| spec = RBLNCompileSpec.for_architecture( | ||
| _hf("Gemma3ForConditionalGeneration"), | ||
| batch_size=2, | ||
| block_size=128, | ||
| max_model_len=2048, | ||
| tp_size=1, | ||
| ) | ||
| assert spec.model_cls is RBLNAutoModelForImageTextToText | ||
|
|
||
| def test_enc_dec_dispatches_to_enc_dec(self): | ||
| spec = RBLNCompileSpec.for_architecture( | ||
| _hf("WhisperForConditionalGeneration", max_length=448), | ||
| batch_size=2, | ||
| block_size=448, | ||
| max_model_len=448, | ||
| tp_size=1, | ||
| ) | ||
| assert spec.model_cls is RBLNAutoModelForSpeechSeq2Seq | ||
|
|
||
| def test_rbln_overrides_are_deep_merged(self): | ||
| spec = RBLNCompileSpec.for_architecture( | ||
| _hf("LlamaForCausalLM"), | ||
| batch_size=4, | ||
| block_size=128, | ||
| max_model_len=1024, | ||
| tp_size=1, | ||
| rbln_overrides={"batch_size": 9, "extra_key": "value"}, | ||
| ) | ||
| assert spec.rbln_config["batch_size"] == 9 # overridden | ||
| assert spec.rbln_config["extra_key"] == "value" # added | ||
| assert spec.rbln_config["max_seq_len"] == 1024 # untouched | ||
|
|
||
|
|
||
| class TestForDecoder: | ||
| def test_no_partition_when_block_size_equals_max_model_len(self): | ||
| spec = RBLNCompileSpec._for_decoder( | ||
| batch_size=4, block_size=1024, max_model_len=1024, tp_size=1 | ||
| ) | ||
| assert spec.rbln_config == { | ||
| "tensor_parallel_size": 1, | ||
| "batch_size": 4, | ||
| "max_seq_len": 1024, | ||
| } | ||
|
|
||
| def test_flash_attn_when_block_size_smaller_than_max_model_len(self): | ||
| spec = RBLNCompileSpec._for_decoder( | ||
| batch_size=4, block_size=128, max_model_len=1024, tp_size=2 | ||
| ) | ||
| assert spec.rbln_config == { | ||
| "tensor_parallel_size": 2, | ||
| "batch_size": 4, | ||
| "max_seq_len": 1024, | ||
| "kvcache_partition_len": 128, | ||
| "attn_impl": "flash_attn", | ||
| } | ||
|
|
||
|
|
||
| class TestForPooling: | ||
| def test_non_qwen3_no_flash_attn_even_when_block_size_differs(self): | ||
| spec = RBLNCompileSpec._for_pooling( | ||
| _hf("BertModel"), | ||
| batch_size=4, | ||
| block_size=128, | ||
| max_model_len=512, | ||
| tp_size=1, | ||
| ) | ||
| assert spec.model_cls is RBLNBertModel | ||
| assert "kvcache_partition_len" not in spec.rbln_config | ||
| assert "attn_impl" not in spec.rbln_config | ||
|
|
||
| def test_qwen3_model_with_smaller_block_uses_flash_attn(self): | ||
| spec = RBLNCompileSpec._for_pooling( | ||
| _hf("Qwen3Model"), | ||
| batch_size=4, | ||
| block_size=128, | ||
| max_model_len=2048, | ||
| tp_size=1, | ||
| ) | ||
| assert spec.model_cls is RBLNQwen3Model | ||
| assert spec.rbln_config["kvcache_partition_len"] == 128 | ||
| assert spec.rbln_config["attn_impl"] == "flash_attn" | ||
|
|
||
| def test_qwen3_model_no_flash_attn_when_block_equals_max(self): | ||
| spec = RBLNCompileSpec._for_pooling( | ||
| _hf("Qwen3Model"), | ||
| batch_size=4, | ||
| block_size=512, | ||
| max_model_len=512, | ||
| tp_size=1, | ||
| ) | ||
| assert "kvcache_partition_len" not in spec.rbln_config | ||
| assert "attn_impl" not in spec.rbln_config | ||
|
|
||
|
|
||
| class TestForEncDec: | ||
| def test_happy_path_produces_whisper_spec(self): | ||
| spec = RBLNCompileSpec._for_enc_dec( | ||
| _hf("WhisperForConditionalGeneration", max_length=448), | ||
| batch_size=2, | ||
| block_size=448, | ||
| max_model_len=448, | ||
| tp_size=1, | ||
| ) | ||
| assert spec.model_cls is RBLNAutoModelForSpeechSeq2Seq | ||
| assert spec.rbln_config == { | ||
| "tensor_parallel_size": 1, | ||
| "batch_size": 2, | ||
| "token_timestamps": False, | ||
| } | ||
|
|
||
| def test_block_size_must_equal_max_model_len(self): | ||
| with pytest.raises(AssertionError, match="block_size"): | ||
| RBLNCompileSpec._for_enc_dec( | ||
| _hf("WhisperForConditionalGeneration", max_length=448), | ||
| batch_size=2, | ||
| block_size=128, | ||
| max_model_len=448, | ||
| tp_size=1, | ||
| ) | ||
|
|
||
| def test_max_model_len_must_match_hf_max_length(self): | ||
| with pytest.raises(AssertionError, match="max_length"): | ||
| RBLNCompileSpec._for_enc_dec( | ||
| _hf("WhisperForConditionalGeneration", max_length=448), | ||
| batch_size=2, | ||
| block_size=512, | ||
| max_model_len=512, | ||
| tp_size=1, | ||
| ) | ||
|
|
||
|
|
||
| class TestForMultimodal: | ||
| def test_unknown_alias_raises(self, monkeypatch): | ||
| # Force get_rbln_model_info to return a model alias missing from | ||
| # _COMPILE_MULTIMODAL_FNS. | ||
| monkeypatch.setattr( | ||
| compilation, | ||
| "get_rbln_model_info", | ||
| lambda config: ("definitely_unknown_alias", "RBLNDoesntMatter"), | ||
| ) | ||
| with pytest.raises(ValueError, match="multimodal model alias"): | ||
| RBLNCompileSpec._for_multimodal( | ||
| _hf("LlavaForConditionalGeneration"), | ||
| batch_size=2, | ||
| block_size=128, | ||
| max_model_len=2048, | ||
| tp_size=1, | ||
| ) | ||
|
|
||
| def test_dispatches_to_compile_fn_with_forwarded_args(self, monkeypatch): | ||
| captured = {} | ||
|
|
||
| def fake_compile_fn(batch_size, max_model_len, block_size, tp_size): | ||
| captured["args"] = (batch_size, max_model_len, block_size, tp_size) | ||
| return {"sentinel": True} | ||
|
|
||
| # Patch the dispatch table on the imported module so the real fn | ||
| # doesn't run (and so the assertion can compare without aliasing). | ||
| monkeypatch.setitem( | ||
| compilation._COMPILE_MULTIMODAL_FNS, "llava", fake_compile_fn | ||
| ) | ||
|
|
||
| spec = RBLNCompileSpec._for_multimodal( | ||
| _hf("LlavaForConditionalGeneration"), | ||
| batch_size=2, | ||
| block_size=128, | ||
| max_model_len=2048, | ||
| tp_size=4, | ||
| ) | ||
| # Note the unusual argument order in `_for_multimodal`: | ||
| # (batch_size, max_model_len, block_size, tp_size). | ||
| assert captured["args"] == (2, 2048, 128, 4) | ||
| assert spec.rbln_config == {"sentinel": True} | ||
| assert spec.model_cls is RBLNAutoModelForVision2Seq | ||
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What is this test for?