126 changes: 126 additions & 0 deletions scripts/performance/measure_latency_advanced.py
@@ -0,0 +1,126 @@
# Copyright 2025 Rebellions Inc. All rights reserved.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at:

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Reference - https://github.com/vllm-project/vllm/blob/v0.9.1/benchmarks/benchmark_throughput.py
import argparse
import os
import time
import urllib.request
from typing import TYPE_CHECKING, Any

import torch
from transformers import AutoTokenizer

if TYPE_CHECKING:
from vllm import SamplingParams
from vllm.outputs import RequestOutput

MODEL_NAME = "meta-llama/Llama-3.2-1B"
PREFILL_CHUNK_SIZE = 128


def get_wiki_prompt():
wiki_txt_url = "https://raw.githubusercontent.com/huggingface/optimum-neuron/refs/heads/main/benchmark/text-generation/performance/wiki.txt"
with urllib.request.urlopen(wiki_txt_url) as resp:
source_data = resp.read().decode("utf-8")
return source_data


def generate_llm_args(batch_size: int):
return {
"model": "meta-llama/Llama-3.2-1B",
"max_model_len": 40 * 1024,
"enable_chunked_prefill": True,
"max_num_seqs": batch_size,
"block_size": 1024,
"max_num_batched_tokens": PREFILL_CHUNK_SIZE,
}


def generate_prompts(prompt_length: int, batch_size: int) -> list[str]:
wiki_prompt = get_wiki_prompt()
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
tokens = tokenizer(wiki_prompt, return_tensors="pt").input_ids[0]
assert len(tokens) > prompt_length * batch_size
prompts = []
# Reserve one token for the BOS special token that vLLM's tokenizer prepends
real_prompt_length = prompt_length - 1
for i in range(batch_size):
start_pos = i * real_prompt_length
end_pos = (i + 1) * real_prompt_length
prompt = tokenizer.decode(tokens[start_pos:end_pos])
prompts.append(prompt)
return prompts


def run_llm(
llm, prompts: list[str], sampling_params: "SamplingParams"
) -> tuple[float, list["RequestOutput"]]:
start = time.perf_counter()
outputs = llm.generate(prompts, sampling_params=sampling_params)
end = time.perf_counter()
elapsed_time = end - start
return elapsed_time, outputs


def _worker(prompts: list[str], args: Any):
llm_args = generate_llm_args(args.batch_size)
os.environ["VLLM_RBLN_METRICS"] = "1"
os.environ.pop("VLLM_PLUGINS", None)
os.environ["RBLN_KERNEL_MODE"] = "triton"
os.environ["VLLM_USE_V1"] = "0"
os.environ["USE_VLLM_MODEL"] = "1"
os.environ["VLLM_DISABLE_COMPILE_CACHE"] = "0"
# Setting VLLM_DISABLE_COMPILE_CACHE to 1 disables the compile cache
from vllm import LLM, SamplingParams
sampling_params = SamplingParams(
temperature=0.0,
top_p=1.0,
ignore_eos=True,
max_tokens=args.max_tokens,
)
total_elapsed_time = 0.0
# FIXME: In rbln, re-initializing LLM
# in each iteration triggers runtime error:
# (Runtime) code=203 INIT_ALREADY_CREATED:
# A runtime has already been created for that compiled model
# (Context failed to be created, compile_id=0).
# Try creating a runtime on a different NPU(s), or use an existing runtime.
Copilot AI (Oct 16, 2025):

This FIXME comment describes a known issue but should include a reference to a tracking issue or ticket number for resolution.

Suggested change:
# Try creating a runtime on a different NPU(s), or use an existing runtime.
# Try creating a runtime on a different NPU(s), or use an existing runtime.
# Tracking issue: https://github.com/rebellions-inc/repo/issues/123
llm = LLM(**llm_args)
for _ in range(args.num_iter):
elapsed_time, outputs = run_llm(llm, prompts, sampling_params)
total_elapsed_time += elapsed_time
return total_elapsed_time


def calculate_avg_throughput_and_latency(elapsed_time: float, batch_size: int,
max_tokens: int,
num_iter: int) -> tuple[float, float]:
avg_throughput = (batch_size * max_tokens * num_iter) / elapsed_time
avg_latency = elapsed_time / num_iter
return avg_throughput, avg_latency


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("-l", "--prompt_length", type=int, default=128)
parser.add_argument("-m", "--max_tokens", type=int, default=1)
parser.add_argument("-b", "--batch_size", type=int, default=1)
parser.add_argument("-n", "--num_iter", type=int, default=1)
args = parser.parse_args()

torch.manual_seed(42)

prompts = generate_prompts(args.prompt_length, args.batch_size)
elapsed_time = _worker(prompts, args)
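
As shown, the script computes `elapsed_time` but never passes it to `calculate_avg_throughput_and_latency` or prints a summary. A minimal sketch of the reporting step one might append after the final line above (the output format is an assumption, not part of the change):

# Sketch: summarize the benchmark, reusing the helper defined in the script above.
avg_throughput, avg_latency = calculate_avg_throughput_and_latency(
    elapsed_time, args.batch_size, args.max_tokens, args.num_iter)
# Throughput counts generated tokens across all sequences and iterations.
print(f"Average throughput: {avg_throughput:.2f} tokens/s")
# Latency is the wall-clock time of one generate() call, averaged over iterations.
print(f"Average latency: {avg_latency:.4f} s")
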
43 changes: 29 additions & 14 deletions scripts/validation/compare_logprobs.py
@@ -14,6 +14,11 @@

import os
from multiprocessing import get_context
from multiprocessing.queues import Queue as MPQueue
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from vllm import LLM, SamplingParams

# Set VOCAB_SIZE according to the model's tokenizer vocab size
# Set EPSILON to the acceptable logprob difference threshold
@@ -22,16 +27,6 @@
EPSILON = 1e-1 * 5
STEP = 1

llm_args = {
"model": "meta-llama/Llama-3.2-1B",
"max_model_len": 40 * 1024,
"block_size": 1024,
"enable_chunked_prefill": True,
"max_num_batched_tokens": 128,
"max_num_seqs": 1,
"max_logprobs": VOCAB_SIZE,
}

prompts = [
"Hello, my name is",
"The president of the United States is",
@@ -40,12 +35,32 @@
]


def run_llm(llm, sampling_params, q):
def generate_llm_args(device: str):
llm_args = {
"model": "meta-llama/Llama-3.2-1B",
"max_model_len": 40 * 1024,
"enable_chunked_prefill": True,
"max_num_seqs": 1,
"max_logprobs": VOCAB_SIZE,
}
if device == "cpu":
llm_args["block_size"] = 128
llm_args["max_num_batched_tokens"] = 128
elif device == "rbln":
llm_args["block_size"] = 1024
llm_args["max_num_batched_tokens"] = 128
else:
raise ValueError(f"Unknown device: {device}")
return llm_args


def run_llm(llm: "LLM", sampling_params: "SamplingParams", q: MPQueue):
outputs = llm.generate(prompts, sampling_params)
q.put(outputs)


def _worker(device, q, llm_args):
def _worker(device: str, q: MPQueue):
llm_args = generate_llm_args(device)
if device == "cpu":
os.environ["VLLM_PLUGINS"] = "cpu"
os.environ["VLLM_USE_V1"] = "0"
@@ -71,12 +86,12 @@ def _worker(device, q, llm_args):
if __name__ == "__main__":
ctx = get_context("spawn")
q = ctx.Queue()
p1 = ctx.Process(target=_worker, args=("cpu", q, llm_args))
p1 = ctx.Process(target=_worker, args=("cpu", q))
p1.start()
cpu_outputs = q.get()
p1.join()

p2 = ctx.Process(target=_worker, args=("rbln", q, llm_args))
p2 = ctx.Process(target=_worker, args=("rbln", q))
p2.start()
rbln_outputs = q.get()
p2.join()
147 changes: 147 additions & 0 deletions scripts/validation/compare_logprobs_advanced.py
@@ -0,0 +1,147 @@
# Copyright 2025 Rebellions Inc. All rights reserved.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at:

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import os
import urllib.request
from multiprocessing import get_context
from multiprocessing.queues import Queue as MPQueue
from typing import TYPE_CHECKING, Any

import torch
from transformers import AutoTokenizer

if TYPE_CHECKING:
from vllm import SamplingParams

MODEL_NAME = "meta-llama/Llama-3.2-1B"
PREFILL_CHUNK_SIZE = 128
VOCAB_SIZE = 128256
EPSILON = 1e-1 * 5


def get_wiki_prompt():
wiki_txt_url = "https://raw.githubusercontent.com/huggingface/optimum-neuron/refs/heads/main/benchmark/text-generation/performance/wiki.txt"
with urllib.request.urlopen(wiki_txt_url) as resp:
source_data = resp.read().decode("utf-8")
return source_data


def generate_llm_args(device: str, batch_size: int):
llm_args = {
"model": "meta-llama/Llama-3.2-1B",
"max_model_len": 40 * 1024,
"enable_chunked_prefill": True,
"max_num_seqs": batch_size,
"max_logprobs": VOCAB_SIZE,
}
if device == "cpu":
llm_args["block_size"] = 128 # 1024 is not working for long prompt
Copilot AI (Oct 16, 2025):

The comment should explain why 1024 doesn't work for long prompts on CPU to help future maintainers understand the limitation.

Suggested change:
llm_args["block_size"] = 128 # 1024 is not working for long prompt
llm_args["block_size"] = 128 # On CPU, using a block_size of 1024 can cause excessive memory usage or performance issues with long prompts, leading to failures. Reducing block_size to 128 avoids these issues.
elif device == "rbln":
llm_args["block_size"] = 1024
llm_args["max_num_batched_tokens"] = PREFILL_CHUNK_SIZE
return llm_args


def generate_prompts(prompt_length: int, batch_size: int) -> list[str]:
wiki_prompt = get_wiki_prompt()
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
tokens = tokenizer(wiki_prompt, return_tensors="pt").input_ids[0]
assert len(tokens) > prompt_length * batch_size
prompts = []
for i in range(batch_size):
start_pos = i * prompt_length
# Reserve one token for the BOS special token that vLLM's tokenizer prepends
end_pos = (i + 1) * prompt_length - 1
prompt = tokenizer.decode(tokens[start_pos:end_pos])
prompts.append(prompt)
return prompts


def run_llm(llm, prompts: list[str], sampling_params: "SamplingParams",
q: MPQueue):
outputs = llm.generate(prompts, sampling_params=sampling_params)
q.put(outputs)


def _worker(device: str, prompts: list[str], q: MPQueue, args: Any):
llm_args = generate_llm_args(device, args.batch_size)
os.environ["VLLM_LOGGING_LEVEL"] = "DEBUG"
if device == "cpu":
os.environ["VLLM_PLUGINS"] = "cpu"
os.environ["VLLM_USE_V1"] = "0"
elif device == "rbln":
os.environ.pop("VLLM_PLUGINS", None)
os.environ["RBLN_KERNEL_MODE"] = "triton"
os.environ["VLLM_USE_V1"] = "0"
os.environ["USE_VLLM_MODEL"] = "1"
os.environ["VLLM_DISABLE_COMPILE_CACHE"] = "0"
# Setting VLLM_DISABLE_COMPILE_CACHE to 1 disables the compile cache
else:
raise ValueError(f"Unknown device: {device}")

from vllm import LLM, SamplingParams
sampling_params = SamplingParams(
temperature=0.0,
top_p=1.0,
ignore_eos=True,
max_tokens=args.max_tokens,
logprobs=VOCAB_SIZE,
)
llm = LLM(**llm_args)
run_llm(llm, prompts, sampling_params, q)


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("-l", "--prompt_length", type=int, default=128)
parser.add_argument("-m", "--max_tokens", type=int, default=1)
parser.add_argument("-b", "--batch_size", type=int, default=1)
args = parser.parse_args()

torch.manual_seed(42)

prompts = generate_prompts(args.prompt_length, args.batch_size)

ctx = get_context("spawn")
q = ctx.Queue()
p1 = ctx.Process(target=_worker, args=("cpu", prompts, q, args))
p1.start()
cpu_outputs = q.get()
p1.join()

p2 = ctx.Process(target=_worker, args=("rbln", prompts, q, args))
p2.start()
rbln_outputs = q.get()
p2.join()

if p1.exitcode != 0 or p2.exitcode != 0:
raise SystemExit("One of the worker processes exited with a non-zero status.")

for cpu_output, rbln_output in zip(cpu_outputs, rbln_outputs):
print("=========" * 10)
cpu_logprobs = cpu_output.outputs[0].logprobs
rbln_logprobs = rbln_output.outputs[0].logprobs
num_outlier = 0
for cpu_lp_token_id, cpu_lp_score in cpu_logprobs[0].items():
cpu_logprob = cpu_lp_score.logprob
if cpu_lp_token_id not in rbln_logprobs[0]:
continue
rbln_logprob = rbln_logprobs[0].get(cpu_lp_token_id).logprob
if abs(cpu_logprob - rbln_logprob) >= EPSILON:
num_outlier += 1
print(f"Number of outliers: {num_outlier}")
print(f"Prompt: {cpu_output.prompt}")
print(f"Generated text (CPU): {cpu_output.outputs[0].text}")
print(f"Generated text (RBLN): {rbln_output.outputs[0].text}")
5 changes: 5 additions & 0 deletions vllm_rbln/rbln_envs.py
@@ -23,6 +23,8 @@
RBLN_SAMPLER: bool = False
RBLN_ENABLE_WARM_UP: bool = False
RBLN_USE_VLLM_MODEL: bool = False
RBLN_FLASH_CAUSAL_ATTN: bool = True
RBLN_METRICS: bool = False

# extended environments
environment_variables = {
@@ -48,6 +50,9 @@
"RBLN_FLASH_CAUSAL_ATTN":
(lambda: os.environ.get("FLASH_CAUSAL_ATTN", "True").lower() in
("true", "1")),
"RBLN_METRICS":
(lambda: os.environ.get("VLLM_RBLN_METRICS", "False").lower() in
("true", "1")),
}
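
Each entry in `environment_variables` is a lambda, so a flag is re-read from the environment when it is accessed. A minimal sketch of how the new `RBLN_METRICS` flag might be consumed elsewhere in the code base (the `rbln_envs.RBLN_METRICS` attribute-access pattern mirrors vLLM's own `envs` module and is an assumption here, not shown in this diff):

# Hypothetical consumer of the new flag; the import path matches this file's module.
import vllm_rbln.rbln_envs as rbln_envs

def maybe_log_step_time(step_time_s: float) -> None:
    # True when VLLM_RBLN_METRICS is set to "true" or "1" (see the lambda above).
    if rbln_envs.RBLN_METRICS:
        print(f"[rbln-metrics] step took {step_time_s:.4f}s")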

