Skip to content

Commit 48ab052

Browse files
committed
feat: add vLLM Rust frontend
1 parent c72d6c8 commit 48ab052

33 files changed

Lines changed: 1719 additions & 925 deletions

AGENTS.md

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
mac env:
2+
python: .venv/bin/python
3+
4+
gpu env:
5+
ssh into the H200 server, then run: conda activate parallax
6+
7+
8+
start a local server:
9+
python src/parallax/launch.py --model-path <MODEL_NAME> --log-level DEBUG
10+
11+
test the server:
12+
curl --location 'http://localhost:3000/v1/chat/completions' --header 'Content-Type: application/json' --data '{
13+
"max_tokens": 1024,
14+
"messages": [
15+
{
16+
"role": "user",
17+
"content": "hello"
18+
}
19+
],
20+
"stream": true
21+
}'
22+
23+
final check:
24+
run when the user asks to git commit
25+
pre-commit run --all-files
26+
pytest

build_rust.sh

Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,142 @@
1+
#!/bin/bash
2+
# Build the vLLM Rust frontend binary and install it into Parallax's pip scripts dir.
3+
# Usage:
4+
# ./build_rust.sh [--debug]
5+
#
6+
# By default builds in release mode. Pass --debug for faster compile times
7+
# during development.
8+
9+
set -euo pipefail
10+
11+
SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
12+
13+
show_help() {
14+
cat <<'EOF'
15+
Build the vLLM Rust frontend binary and install it into Parallax's pip scripts dir.
16+
17+
Usage:
18+
./build_rust.sh [--debug]
19+
20+
Environment:
21+
VLLM_REF vLLM git branch/tag to clone. Defaults to main.
22+
PARALLAX_PYTHON Python interpreter for the Parallax installation.
23+
EOF
24+
}
25+
26+
# Print the path of the Python interpreter that belongs to the Parallax install.
#
# Lookup order:
#   1. $PARALLAX_PYTHON (must be executable, otherwise the script aborts)
#   2. $VIRTUAL_ENV/bin/python
#   3. $CONDA_PREFIX/bin/python
#   4. the first `python` found on PATH
#
# Exits with status 1 and a message on stderr when nothing is found.
resolve_parallax_python() {
    local env_root path_python

    # An explicit override wins, but it has to point at something runnable.
    if [[ -n "${PARALLAX_PYTHON:-}" ]]; then
        if [[ -x "$PARALLAX_PYTHON" ]]; then
            echo "$PARALLAX_PYTHON"
            return
        fi
        echo "PARALLAX_PYTHON is not executable: $PARALLAX_PYTHON" >&2
        exit 1
    fi

    # Active virtualenv first, then an active conda environment.
    for env_root in "${VIRTUAL_ENV:-}" "${CONDA_PREFIX:-}"; do
        if [[ -n "$env_root" && -x "$env_root/bin/python" ]]; then
            echo "$env_root/bin/python"
            return
        fi
    done

    # Last resort: whatever `python` resolves to on PATH.
    if path_python="$(command -v python)"; then
        echo "$path_python"
        return
    fi

    echo "Unable to find Python for the Parallax installation." >&2
    echo "Activate the Parallax environment or set PARALLAX_PYTHON." >&2
    exit 1
}
55+
56+
# Print the "scripts" directory (where pip installs console entry points) of the
# Python interpreter chosen by resolve_parallax_python.
#
# Exits with status 1 when no interpreter can be resolved; the resolver has
# already written the reason to stderr in that case.
resolve_parallax_bin_dir() {
    local parallax_python

    # Bash clears errexit inside command substitutions, and this function is
    # itself invoked via $(...). Without the explicit exit, a failed lookup would
    # fall through and try to execute an empty command name, adding a confusing
    # "command not found" on top of the real error.
    parallax_python="$(resolve_parallax_python)" || exit 1

    # Ask the interpreter itself so venv/conda/system layouts are all handled.
    "$parallax_python" - <<'PY'
import sysconfig

print(sysconfig.get_path("scripts"))
PY
}
65+
66+
if [[ "${1:-}" == "--help" || "${1:-}" == "-h" ]]; then
67+
show_help
68+
exit 0
69+
fi
70+
71+
if [[ $# -gt 1 || ( $# -eq 1 && "${1:-}" != "--debug" ) ]]; then
72+
show_help >&2
73+
exit 2
74+
fi
75+
76+
VLLM_REF="${VLLM_REF:-main}"
77+
CLONE_PARENT="$(mktemp -d "${TMPDIR:-/tmp}/parallax-vllm-rs.XXXXXX")"
78+
VLLM_CLONE_ROOT="$CLONE_PARENT/vllm"
79+
80+
cleanup_clone() {
81+
rm -rf "$CLONE_PARENT"
82+
}
83+
84+
trap cleanup_clone EXIT
85+
86+
if ! command -v git &>/dev/null; then
87+
echo "git not found; install git and rerun this script." >&2
88+
exit 1
89+
fi
90+
91+
echo "Cloning vLLM from https://github.com/vllm-project/vllm.git (ref: $VLLM_REF)"
92+
git clone --depth 1 --branch "$VLLM_REF" \
93+
https://github.com/vllm-project/vllm.git \
94+
"$VLLM_CLONE_ROOT"
95+
96+
RUST_DIR="$VLLM_CLONE_ROOT/rust"
97+
PARALLAX_SCRIPTS_DIR="$(resolve_parallax_bin_dir)"
98+
TARGET_PATH="$PARALLAX_SCRIPTS_DIR/vllm-rs"
99+
100+
if [[ ! -f "$RUST_DIR/Cargo.toml" || ! -f "$VLLM_CLONE_ROOT/rust-toolchain.toml" ]]; then
101+
echo "Cloned repository does not contain the expected vLLM Rust frontend sources." >&2
102+
exit 1
103+
fi
104+
105+
# Read the required toolchain from rust-toolchain.toml.
# The trailing `|| true` matters: with `set -e` and `pipefail`, a grep that finds
# no `channel` line would otherwise abort the script right here, before the
# explicit check below gets a chance to print a useful message.
TOOLCHAIN=$(grep '^channel' "$VLLM_CLONE_ROOT/rust-toolchain.toml" | sed 's/.*= *"\(.*\)"/\1/' || true)

if [[ -z "$TOOLCHAIN" ]]; then
    echo "Unable to read Rust toolchain from $VLLM_CLONE_ROOT/rust-toolchain.toml" >&2
    exit 1
fi
112+
113+
# Ensure rustup and the required toolchain are available.
114+
if ! command -v rustup &>/dev/null; then
115+
echo "rustup not found, installing..."
116+
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y --default-toolchain none
117+
# shellcheck disable=SC1091
118+
source "$HOME/.cargo/env"
119+
fi
120+
121+
if ! rustup run "$TOOLCHAIN" rustc --version &>/dev/null; then
122+
echo "Installing Rust toolchain: $TOOLCHAIN"
123+
rustup toolchain install "$TOOLCHAIN"
124+
fi
125+
126+
# Pick the cargo profile: release by default, debug when --debug was passed.
# PROFILE_DIR is the matching sub-directory of cargo's target/ output.
if [[ "${1:-}" == "--debug" ]]; then
    PROFILE_ARGS=()
    PROFILE_DIR="debug"
else
    PROFILE_ARGS=(--release)
    PROFILE_DIR="release"
fi

# ${arr[@]+"${arr[@]}"} expands to nothing for an empty array. A plain
# "${PROFILE_ARGS[@]}" is treated as an unbound variable under `set -u` by bash
# older than 4.4 (e.g. the /bin/bash shipped with macOS), which would break the
# --debug path.
cargo +"$TOOLCHAIN" build ${PROFILE_ARGS[@]+"${PROFILE_ARGS[@]}"} \
    --manifest-path "$RUST_DIR/Cargo.toml" \
    --bin vllm-rs \
    --features native-tls-vendored
138+
139+
mkdir -p "$(dirname "$TARGET_PATH")"
140+
cp "$RUST_DIR/target/$PROFILE_DIR/vllm-rs" "$TARGET_PATH"
141+
chmod +x "$TARGET_PATH"
142+
echo "Installed vllm-rs to $TARGET_PATH"

scripts/generate.py

Lines changed: 73 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -25,11 +25,13 @@
2525
import time
2626

2727
import mlx.core as mx
28+
from mlx_lm.server import convert_chat, process_message_content
2829

2930
from parallax.server.cache_manager import CacheManager
3031
from parallax.server.request import InitialRequest
3132
from parallax.server.sampling.sampler import SamplingBatchInfo
3233
from parallax.server.sampling.sampling_params import SamplingParams
34+
from parallax.server.scheduler import _normalize_token_ids
3335
from parallax.server.shard_loader import MLXModelLoader
3436
from parallax.utils.utils import create_causal_mask, get_layer_types
3537

@@ -44,6 +46,40 @@ def print_rank(message):
4446
print(f"[Rank {tp_rank}] {message}")
4547

4648

49+
def get_eos_token_ids(config, tokenizer):
    """Collect the end-of-sequence token ids from the model config and tokenizer.

    Args:
        config: Mapping-like model config; its ``"eos_token_id"`` entry is read
            with ``.get``.
        tokenizer: Object whose optional ``eos_token_id`` attribute is read.

    Returns:
        The collection produced by ``_normalize_token_ids`` for the config value
        (or for the tokenizer value when the config has none), updated in place
        with the normalized tokenizer value.
        NOTE(review): presumably a ``set`` of ints -- confirm against
        ``_normalize_token_ids``.
    """
    from_tokenizer = getattr(tokenizer, "eos_token_id", None)
    from_config = config.get("eos_token_id")

    # The config value is the primary source; use the tokenizer's only if absent.
    primary = from_tokenizer if from_config is None else from_config

    merged = _normalize_token_ids(primary)
    merged.update(_normalize_token_ids(from_tokenizer))
    return merged
58+
59+
60+
def build_prompt(messages, tokenizer):
    """Render chat ``messages`` into a prompt string and its token ids.

    When the tokenizer has a chat template, the template is applied twice with
    the generation prompt appended: once tokenized (for the ids) and once as
    text. Otherwise the messages are flattened with ``convert_chat`` and the
    resulting text is encoded with ``tokenizer.encode``.

    Args:
        messages: Chat messages to render.
        tokenizer: Tokenizer exposing ``chat_template``, ``apply_chat_template``
            and ``encode``.

    Returns:
        A ``(full_prompt, prompt_tokens)`` tuple.
    """
    # No template available: fall back to the plain-text chat conversion.
    if not tokenizer.chat_template:
        text = convert_chat(messages, None)
        return text, tokenizer.encode(text)

    # Return value is unused, so this presumably normalizes ``messages`` in
    # place -- verify against mlx_lm.server.
    process_message_content(messages)

    template_options = {"add_generation_prompt": True, "return_dict": False}
    token_ids = tokenizer.apply_chat_template(
        messages,
        None,
        tokenize=True,
        **template_options,
    )
    text = tokenizer.apply_chat_template(
        messages,
        None,
        tokenize=False,
        **template_options,
    )
    return text, token_ids
81+
82+
4783
def main():
4884
parser = argparse.ArgumentParser(description="Simple offline inference script")
4985
parser.add_argument(
@@ -76,6 +112,8 @@ def main():
76112
# 2. Initialize CacheManager
77113
num_layers = config.get("num_hidden_layers")
78114
num_kv_heads = config.get("num_key_value_heads")
115+
if num_kv_heads is None:
116+
num_kv_heads = config.get("num_attention_groups")
79117
head_dim = config.get("head_dim") or config.get("hidden_size") // config.get(
80118
"num_attention_heads"
81119
)
@@ -88,6 +126,18 @@ def main():
88126

89127
v_head_dim = config.get("v_head_dim")
90128
layer_types = get_layer_types(config, 0, num_layers)
129+
linear_key_head_dim = config.get("linear_key_head_dim")
130+
linear_value_head_dim = config.get("linear_value_head_dim")
131+
linear_conv_kernel_dim = config.get("linear_conv_kernel_dim")
132+
linear_num_key_heads = config.get("linear_num_key_heads")
133+
linear_num_value_heads = config.get("linear_num_value_heads")
134+
key_dim, value_dim, conv_dim = None, None, None
135+
if linear_key_head_dim is not None and linear_num_key_heads is not None:
136+
key_dim = linear_key_head_dim * linear_num_key_heads
137+
if linear_value_head_dim is not None and linear_num_value_heads is not None:
138+
value_dim = linear_value_head_dim * linear_num_value_heads
139+
if key_dim is not None and value_dim is not None:
140+
conv_dim = key_dim * 2 + value_dim
91141

92142
cache_manager = CacheManager(
93143
num_layers=num_layers,
@@ -98,19 +148,17 @@ def main():
98148
cache_memory_fraction=0.1,
99149
head_dim_v=v_head_dim,
100150
layer_types=layer_types,
151+
conv_dim=conv_dim,
152+
conv_kernel_size=linear_conv_kernel_dim,
153+
linear_k_dim=linear_key_head_dim,
154+
linear_v_dim=linear_value_head_dim,
155+
linear_num_k_heads=linear_num_key_heads,
156+
linear_num_v_heads=linear_num_value_heads,
101157
)
102158

103159
# 3. Tokenize and Create Request
104160
messages = [{"role": "user", "content": args.prompt}]
105-
106-
if hasattr(tokenizer, "apply_chat_template") and tokenizer.chat_template is not None:
107-
full_prompt = tokenizer.apply_chat_template(
108-
messages, tokenize=False, add_generation_prompt=True
109-
)
110-
else:
111-
full_prompt = args.prompt
112-
113-
prompt_tokens = tokenizer.encode(full_prompt)
161+
full_prompt, prompt_tokens = build_prompt(messages, tokenizer)
114162
sampling_params = SamplingParams(temperature=args.temp, top_k=args.topk)
115163
request = InitialRequest(
116164
prompt=full_prompt,
@@ -119,22 +167,9 @@ def main():
119167
max_new_tokens=args.max_tokens,
120168
)
121169

122-
eos_token_ids = []
123-
if tokenizer.eos_token_id is not None:
124-
if isinstance(tokenizer.eos_token_id, list):
125-
eos_token_ids.extend(tokenizer.eos_token_id)
126-
else:
127-
eos_token_ids.append(tokenizer.eos_token_id)
128-
config_eos = config.get("eos_token_id")
129-
if config_eos is not None:
130-
if isinstance(config_eos, list):
131-
for e in config_eos:
132-
if e not in eos_token_ids:
133-
eos_token_ids.append(e)
134-
elif config_eos not in eos_token_ids:
135-
eos_token_ids.append(config_eos)
136-
137-
eos_token_ids = set(eos_token_ids)
170+
eos_token_ids = get_eos_token_ids(config, tokenizer)
171+
if not eos_token_ids:
172+
raise ValueError("EOS token ID must be set for generation.")
138173

139174
# 4. Prefill
140175
print_rank(f"Full prompt:\n {full_prompt}")
@@ -151,6 +186,9 @@ def main():
151186
input_ids = mx.array([request.input_ids])
152187
block_table = mx.array([cache_manager.get_block_table(request.request_id)], dtype=mx.int32)
153188
context_lengths = mx.array([request.prompt_len], dtype=mx.int32)
189+
state_slot_mapping = None
190+
if cache_manager.needs_slots:
191+
state_slot_mapping = mx.array([cache_manager.get_slot(request.request_id)], dtype=mx.int32)
154192

155193
block_size = cache_manager.block_size
156194
slot_mapping = []
@@ -172,21 +210,27 @@ def main():
172210
block_tables=block_table,
173211
context_lengths=context_lengths,
174212
slot_mapping=slot_mapping,
213+
state_slot_mapping=state_slot_mapping,
175214
)
176215

177216
sampling_info = SamplingBatchInfo.from_reqs([request])
178217

179218
next_token_id = model.logits_to_tokens(logits, context_lengths, sampling_info)
180219

181220
token_id = int(next_token_id[0])
182-
request.commit_new_token(token_id)
221+
is_finished = token_id in eos_token_ids
222+
if not is_finished:
223+
request.commit_new_token(token_id)
183224

184225
prefill_time = time.perf_counter() - prefill_start
185226
print_rank(f"Token 1 (Prefill) time: {prefill_time * 1000:.2f} ms")
186227

187228
# 5. Decode Loop
188229
total_decode_time = 0
189230
for i in range(args.max_tokens - 1):
231+
if is_finished:
232+
break
233+
190234
decode_step_start = time.perf_counter()
191235

192236
success = cache_manager.append_slot(request.request_id)
@@ -204,12 +248,14 @@ def main():
204248
mask=None,
205249
block_tables=block_table,
206250
context_lengths=context_lengths,
251+
state_slot_mapping=state_slot_mapping,
207252
)
208253

209254
next_token_id = model.logits_to_tokens(logits, mx.array([1]), sampling_info)
210255

211256
token_id = int(next_token_id[0])
212-
if token_id in eos_token_ids:
257+
is_finished = token_id in eos_token_ids
258+
if is_finished:
213259
break
214260
request.commit_new_token(token_id)
215261

0 commit comments

Comments
 (0)