Skip to content

Commit e81a3d2

Browse files
authored
Quickfix for global context (#444)
1 parent 6278c7c commit e81a3d2

2 files changed

Lines changed: 14 additions & 1 deletion

File tree

vllm_rbln/v1/sample/rbln_sampler.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,14 @@
1414
# isort: off
1515
import torch
1616
import torch.nn as nn
17+
18+
try:
19+
import torch.rbln
20+
21+
has_torch_rbln = True
22+
except ImportError:
23+
has_torch_rbln = False
24+
1725
from vllm_rbln.logger import init_logger
1826
from vllm.v1.sample.metadata import SamplingMetadata
1927
from vllm.v1.sample.sampler import Sampler as VLLMSampler
@@ -151,6 +159,11 @@ def __init__(
151159
}
152160
if envs.VLLM_RBLN_COMPILE_STRICT_MODE:
153161
options["mode"] = "strict"
162+
163+
if has_torch_rbln:
164+
options["use_global_ctx"] = True
165+
options["global_device_id"] = 0
166+
154167
self._compiled_rbln_topk_topp_sampler = torch.compile(
155168
rbln_top_k_top_p_sample,
156169
dynamic=False,

vllm_rbln/v1/worker/rbln_model_runner.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1351,7 +1351,7 @@ def _compile_model(self, model):
13511351
)
13521352
options["cache_dir"] = os.path.join(envs.VLLM_CACHE_ROOT, "rbln")
13531353

1354-
if envs.VLLM_RBLN_AUTO_PORT and has_torch_rbln:
1354+
if has_torch_rbln:
13551355
options["use_global_ctx"] = True
13561356
# TODO(yunseong.kim): use device_id from current platform
13571357
# when vllm-rbln supports it

0 commit comments

Comments
 (0)