Skip to content

Commit 5ddd2a2

Browse files
rebel-jonghewk and rebel-jonghewk authored
fix: Re-enable standalone ctx (#407)
* Revert "Revert "fix(core): Sampler with `RBLN_CTX_STANDALONE` (#401)" (#405)"

  This reverts commit 81f7d47.

* Revert "fix RBLN_CTX_STANDALONE=0"

  This reverts commit defd60e.

---------

Co-authored-by: rebel-jonghewk <jonghewk@rebellions.in>
1 parent bca048d commit 5ddd2a2

3 files changed

Lines changed: 24 additions & 7 deletions

File tree

vllm_rbln/platform.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -135,7 +135,7 @@ def validate_and_setup_prerequisite(cls, vllm_config: VllmConfig) -> None:
135135
"RBLN_PROFILER is not supported when using vLLM model parallel "
136136
"(TP, DP, EP, or PP)."
137137
)
138-
os.environ["RBLN_CTX_STANDALONE"] = "0"
138+
os.environ["RBLN_CTX_STANDALONE"] = "1"
139139
os.environ["RBLN_FORCE_CCL_ASYNC"] = "1"
140140

141141
@classmethod

vllm_rbln/v1/sample/rbln_sampler.py

Lines changed: 18 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -128,7 +128,12 @@ def rbln_top_k_top_p_sample(
128128

129129

130130
class RBLNTopKTopPSampler(nn.Module):
131-
def __init__(self, logprobs_mode: LogprobsMode = "raw_logprobs", seed: int = 42):
131+
def __init__(
132+
self,
133+
logprobs_mode: LogprobsMode = "raw_logprobs",
134+
seed: int = 42,
135+
compile_context: rebel.CompileContext = None,
136+
):
132137
# TODO(rbln): Merge more ops to rbln context.
133138
# Currently, we only have softmax in rbln context.
134139
super().__init__()
@@ -139,7 +144,11 @@ def __init__(self, logprobs_mode: LogprobsMode = "raw_logprobs", seed: int = 42)
139144
)
140145

141146
rebel.manual_seed(seed)
142-
options = {"compile_context": rebel.CompileContext()}
147+
options = {
148+
"compile_context": compile_context
149+
if compile_context
150+
else rebel.CompileContext()
151+
}
143152
if envs.VLLM_RBLN_COMPILE_STRICT_MODE:
144153
options["mode"] = "strict"
145154
self._compiled_rbln_topk_topp_sampler = torch.compile(
@@ -175,11 +184,16 @@ def forward_rbln(
175184

176185

177186
class RBLNSampler(VLLMSampler):
178-
def __init__(self, logprobs_mode: LogprobsMode = "raw_logprobs", seed: int = 42):
187+
def __init__(
188+
self,
189+
logprobs_mode: LogprobsMode = "raw_logprobs",
190+
seed: int = 42,
191+
compile_context: rebel.CompileContext = None,
192+
):
179193
super().__init__()
180194
if logprobs_mode in ("raw_logprobs", "raw_logits"):
181195
self.topk_topp_sampler = RBLNTopKTopPSampler(
182-
logprobs_mode=logprobs_mode, seed=seed
196+
logprobs_mode=logprobs_mode, seed=seed, compile_context=compile_context
183197
)
184198
else:
185199
logger.warning_once(

vllm_rbln/v1/worker/rbln_model_runner.py

Lines changed: 5 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -272,13 +272,18 @@ def __init__(
272272
else:
273273
self.max_encoder_len = 0
274274

275+
from rebel.compile_context import CompileContext
276+
277+
self.compile_context = CompileContext(use_weight_sharing=True)
278+
275279
# Sampler
276280
self.use_rbln_sampler = envs.VLLM_RBLN_SAMPLER
277281
if self.use_rbln_sampler:
278282
logger.info("Using RBLN sampler: %s", self.use_rbln_sampler)
279283
sampler = RBLNSampler(
280284
logprobs_mode=self.model_config.logprobs_mode,
281285
seed=self.vllm_config.model_config.seed,
286+
compile_context=self.compile_context,
282287
)
283288
else:
284289
logger.info("Using default vLLM sampler.")
@@ -3211,9 +3216,7 @@ def model_wrapper(
32113216
# RBLN compile context to mark static address for kv cache tensor
32123217
# if tensor is set to have static address,
32133218
# similar to RBLN kv cache binding
3214-
from rebel.compile_context import CompileContext
32153219

3216-
self.compile_context = CompileContext(use_weight_sharing=True)
32173220
compiled_graph = self._compile_model(model_wrapper)
32183221
self.model_executable = compiled_graph
32193222

0 commit comments

Comments
 (0)