
Commit d8e6e42

janhilgard and claude committed
feat: Add --gpu-memory-utilization, repetition_penalty, fix OOM on large models
- Add --gpu-memory-utilization CLI flag (default 0.90) to control Metal soft allocation limit and emergency cache clear threshold
- Fix OOM SIGKILL on large models (200GB+): clear traceback references between strict=True/False retry in load_model_with_fallback() to free memory from the failed first load before retrying
- Add repetition_penalty/frequency_penalty/presence_penalty support for chat and completion endpoints

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent bf486e6 commit d8e6e42

File tree

7 files changed, +126 -27 lines changed


vllm_mlx/api/models.py

Lines changed: 8 additions & 0 deletions
@@ -172,6 +172,10 @@ class ChatCompletionRequest(BaseModel):
     # MLLM-specific parameters
     video_fps: float | None = None
     video_max_frames: int | None = None
+    # Sampling penalties
+    repetition_penalty: float | None = None  # mlx-lm style (>1.0 penalizes)
+    frequency_penalty: float | None = None  # OpenAI style (0-2)
+    presence_penalty: float | None = None  # OpenAI style (0-2)
     # Request timeout in seconds (None = use server default)
     timeout: float | None = None

@@ -235,6 +239,10 @@ class CompletionRequest(BaseModel):
     max_tokens: int | None = None
     stream: bool = False
     stop: list[str] | None = None
+    # Sampling penalties
+    repetition_penalty: float | None = None  # mlx-lm style (>1.0 penalizes)
+    frequency_penalty: float | None = None  # OpenAI style (0-2)
+    presence_penalty: float | None = None  # OpenAI style (0-2)
     # Request timeout in seconds (None = use server default)
     timeout: float | None = None
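Since the new penalties are ordinary optional fields on the request models, any OpenAI-style client can send them in the JSON body. A minimal client-side sketch, assuming the server runs locally on port 8000 and exposes the usual /v1/chat/completions route (host, port, and model name are placeholders, not part of this commit):

    import requests

    payload = {
        "model": "my-model",                      # placeholder model name
        "messages": [{"role": "user", "content": "Tell me a story."}],
        "max_tokens": 128,
        "repetition_penalty": 1.1,                # mlx-lm style: >1.0 penalizes repeats
        # or, OpenAI style: "frequency_penalty": 0.5 (mapped to a repetition penalty server-side)
    }
    resp = requests.post("http://localhost:8000/v1/chat/completions", json=payload)
    print(resp.json()["choices"][0]["message"]["content"])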

vllm_mlx/cli.py

Lines changed: 16 additions & 0 deletions
@@ -37,6 +37,13 @@ def serve_command(args):
         print("Example: --enable-auto-tool-choice --tool-call-parser mistral")
         sys.exit(1)

+    # Validate gpu-memory-utilization range
+    if not (0.0 < args.gpu_memory_utilization <= 1.0):
+        print(
+            "Error: --gpu-memory-utilization must be between 0.0 (exclusive) and 1.0 (inclusive)"
+        )
+        sys.exit(1)
+
     # Configure server security settings
     server._api_key = args.api_key
     server._default_timeout = args.timeout

@@ -186,6 +193,7 @@ def serve_command(args):
         scheduler_config=scheduler_config,
         stream_interval=args.stream_interval if args.continuous_batching else 1,
         max_tokens=args.max_tokens,
+        gpu_memory_utilization=args.gpu_memory_utilization,
     )

     # Start server

@@ -680,6 +688,14 @@ def main():
         action="store_true",
         help="Enable continuous batching for multiple concurrent users (slower for single user)",
     )
+    serve_parser.add_argument(
+        "--gpu-memory-utilization",
+        type=float,
+        default=0.90,
+        help="Fraction of device memory for Metal allocation limit and emergency "
+        "cache clear threshold (0.0-1.0, default: 0.90). Increase to 0.95 for "
+        "large models (200GB+) that need more memory headroom.",
+    )
     # Paged cache options (experimental)
     serve_parser.add_argument(
         "--use-paged-cache",

vllm_mlx/engine/batched.py

Lines changed: 8 additions & 2 deletions
@@ -137,6 +137,7 @@ def __init__(
         scheduler_config: Any | None = None,
         stream_interval: int = 1,
         force_mllm: bool = False,
+        gpu_memory_utilization: float = 0.90,
     ):
         """
         Initialize the batched engine.

@@ -147,11 +148,14 @@ def __init__(
             scheduler_config: Optional scheduler configuration
             stream_interval: Tokens to batch before streaming (1=every token)
             force_mllm: Force loading as MLLM even if not auto-detected
+            gpu_memory_utilization: Fraction of device memory for Metal allocation
+                limit and emergency threshold (0.0-1.0, default 0.90)
         """
         self._model_name = model_name
         self._trust_remote_code = trust_remote_code
         self._scheduler_config = scheduler_config
         self._stream_interval = stream_interval
+        self._gpu_memory_utilization = gpu_memory_utilization
         self._is_mllm = force_mllm or is_mllm_model(model_name)

         self._model = None

@@ -283,13 +287,14 @@ async def _start_llm(self) -> None:
                 device_info.get("memory_size", 0),
             )
             if max_recommended > 0:
-                soft_limit = int(max_recommended * 0.95)
+                soft_limit = int(max_recommended * self._gpu_memory_utilization)
                 mx.set_memory_limit(soft_limit)
                 mx.set_cache_limit(32 * 1024 * 1024 * 1024)  # 32GB
+                pct = self._gpu_memory_utilization * 100
                 logger.info(
                     f"Metal memory limits set: "
                     f"allocation_limit={soft_limit / 1e9:.1f}GB "
-                    f"(90% of {max_recommended / 1e9:.1f}GB), "
+                    f"({pct:.0f}% of {max_recommended / 1e9:.1f}GB), "
                     f"cache_limit=32GB"
                 )
         except Exception as e:

@@ -301,6 +306,7 @@ async def _start_llm(self) -> None:
             model_name=self._model_name,
             scheduler_config=scheduler_config,
             stream_interval=self._stream_interval,
+            gpu_memory_utilization=self._gpu_memory_utilization,
         )

         # Create async engine
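A small worked example of the new soft-limit arithmetic, assuming a hypothetical device whose recommended working set is 192 GiB (the numbers are illustrative, not measured on real hardware):

    max_recommended = 192 * 1024**3               # hypothetical recommended working set, in bytes
    for util in (0.90, 0.95):                     # default vs. the suggested large-model setting
        soft_limit = int(max_recommended * util)
        print(f"gpu_memory_utilization={util}: allocation_limit={soft_limit / 1e9:.1f}GB")
    # 0.90 -> ~185.5GB, 0.95 -> ~195.9GB of Metal allocation headroom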

vllm_mlx/engine_core.py

Lines changed: 12 additions & 2 deletions
@@ -36,6 +36,7 @@ class EngineConfig:
     scheduler_config: Optional[SchedulerConfig] = None
     step_interval: float = 0.001  # 1ms between steps
     stream_interval: int = 1  # Tokens to batch before streaming (1=every token)
+    gpu_memory_utilization: float = 0.90  # Fraction of device memory for allocation


 class EngineCore:

@@ -150,8 +151,17 @@ async def _engine_loop(self) -> None:
         stream_interval = self.config.stream_interval
         use_simple_streaming = stream_interval == 1

-        # Emergency memory pressure threshold (245GB — raised for large models)
-        _memory_pressure_threshold = 245 * 1024 * 1024 * 1024
+        # Emergency memory pressure threshold — dynamic based on gpu_memory_utilization
+        _gpu_mem_util = self.config.gpu_memory_utilization
+        try:
+            _device_mem = mx.device_info().get(
+                "memory_size", 200 * 1024 * 1024 * 1024
+            )
+            _memory_pressure_threshold = int(
+                _device_mem * min(_gpu_mem_util + 0.05, 0.99)
+            )
+        except Exception:
+            _memory_pressure_threshold = 200 * 1024 * 1024 * 1024
         _memory_check_interval = 64

         while self._running:
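A quick sketch of what the dynamic pressure threshold evaluates to, assuming a hypothetical machine that reports 256 GiB of unified memory (illustrative numbers only):

    device_mem = 256 * 1024**3                    # hypothetical memory_size, in bytes
    for util in (0.90, 0.95):
        threshold = int(device_mem * min(util + 0.05, 0.99))
        print(f"gpu_memory_utilization={util}: pressure_threshold~{threshold / 1e9:.0f}GB")
    # 0.90 -> ~261GB (util + 0.05), 0.95 -> ~272GB (capped at 0.99 of device memory)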

vllm_mlx/scheduler.py

Lines changed: 31 additions & 7 deletions
@@ -19,7 +19,7 @@

 import mlx.core as mx
 from mlx_lm.generate import BatchGenerator
-from mlx_lm.sample_utils import make_sampler
+from mlx_lm.sample_utils import make_logits_processors, make_sampler

 from .memory_cache import MemoryAwarePrefixCache, MemoryCacheConfig
 from .paged_cache import PagedCacheManager

@@ -403,7 +403,7 @@ def _chunked_next(self=batch_gen):  # noqa: C901

         if not is_cached:
             padded = _left_pad_prompts(inputs_raw, max_length=max_length)
-            prompt_cache = _make_cache(self.model, padding)
+            prompt_cache = _make_cache(self.model, padding, self.max_kv_size)
         else:
             last_inputs = mx.array([p[-1:] for p in inputs_raw])
             padded = _right_pad_prompts(inputs_raw, max_length=max_length)

@@ -644,6 +644,10 @@ def _mtp_step(

         # --- Apply logits processors + sample primary ---
         if any(logits_processors):
+            logger.debug(
+                f"[logits_proc] applying {sum(len(lp) for lp in logits_processors)} "
+                f"processors to batch_size={batch_size}"
+            )
             processed_logits = []
             for e in range(batch_size):
                 sample_logits = logits[e : e + 1]

@@ -1760,15 +1764,30 @@ def _schedule_waiting(self) -> List[Request]:
                 request.remaining_tokens = request.prompt_token_ids
                 tokens_to_process = request.prompt_token_ids

+            # Build per-request logits_processors from repetition_penalty
+            rep_penalty = request.sampling_params.repetition_penalty
+            lp = None
+            if rep_penalty and rep_penalty != 1.0:
+                lp = make_logits_processors(repetition_penalty=rep_penalty)
+                logger.info(
+                    f"[rep_penalty] request={request.request_id[:12]} "
+                    f"penalty={rep_penalty} processors={len(lp)}"
+                )
+
             # Insert into BatchGenerator with optional cache.
             # Wrap in try/except: if cache shapes are incompatible
             # (e.g. stale entry after BatchGenerator recreation),
             # fall back to no-cache insert instead of crashing.
+            insert_kwargs = {
+                "max_tokens": [request.sampling_params.max_tokens],
+                "caches": [cache_to_use] if cache_to_use else None,
+            }
+            if lp:
+                insert_kwargs["logits_processors"] = [lp]
             try:
                 uids = self.batch_generator.insert(
                     [tokens_to_process],
-                    max_tokens=[request.sampling_params.max_tokens],
-                    caches=[cache_to_use] if cache_to_use else None,
+                    **insert_kwargs,
                 )
             except Exception as e:
                 if cache_to_use is not None:

@@ -1781,10 +1800,10 @@ def _schedule_waiting(self) -> List[Request]:
                     request.cached_tokens = 0
                     request.remaining_tokens = request.prompt_token_ids
                     tokens_to_process = request.prompt_token_ids
+                    insert_kwargs["caches"] = None
                     uids = self.batch_generator.insert(
                         [tokens_to_process],
-                        max_tokens=[request.sampling_params.max_tokens],
-                        caches=None,
+                        **insert_kwargs,
                     )
                 else:
                     raise

@@ -1805,11 +1824,16 @@ def _schedule_waiting(self) -> List[Request]:
                 else ""
             )
             tokens_to_prefill = len(tokens_to_process)
+            rep_info = (
+                f" rep_penalty={rep_penalty}"
+                if rep_penalty and rep_penalty != 1.0
+                else ""
+            )
             logger.info(
                 f"[schedule] request={request.request_id[:12]} uid={uid} "
                 f"prompt_tokens={request.num_prompt_tokens} "
                 f"tokens_to_prefill={tokens_to_prefill}{cache_info} "
-                f"max_tokens={request.sampling_params.max_tokens} "
+                f"max_tokens={request.sampling_params.max_tokens}{rep_info} "
                 f"running={len(self.running)} waiting={len(self.waiting)}"
             )
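The per-request processors come from mlx-lm's make_logits_processors helper; a standalone sketch of how such a processor list is built and applied, assuming mlx-lm's behavior that each processor is a callable taking the generated token history and the current logits:

    import mlx.core as mx
    from mlx_lm.sample_utils import make_logits_processors

    processors = make_logits_processors(repetition_penalty=1.3)  # mlx-lm style, >1.0 penalizes
    history = mx.array([5, 7, 5, 9])       # hypothetical token ids generated so far
    logits = mx.zeros((1, 32))             # hypothetical logits over a toy 32-token vocab
    for proc in processors:
        logits = proc(history, logits)     # logits of already-seen ids are scaled down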

vllm_mlx/server.py

Lines changed: 43 additions & 15 deletions
@@ -517,6 +517,7 @@ def load_model(
     stream_interval: int = 1,
     max_tokens: int = 32768,
     force_mllm: bool = False,
+    gpu_memory_utilization: float = 0.90,
 ):
     """
     Load a model (auto-detects MLLM vs LLM).

@@ -546,6 +547,7 @@ def load_model(
             scheduler_config=scheduler_config,
             stream_interval=stream_interval,
             force_mllm=force_mllm,
+            gpu_memory_utilization=gpu_memory_utilization,
         )
         # BatchedEngine will be started in lifespan (uvicorn's event loop)
         # Just log for now

@@ -1231,10 +1233,22 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
        f"prompt_chars={prompt_len} prompt_preview={prompt_preview!r}"
    )

+    # Resolve repetition penalty for completions
+    comp_rep_penalty = request.repetition_penalty
+    if comp_rep_penalty is None and request.frequency_penalty:
+        comp_rep_penalty = 1.0 + request.frequency_penalty
+    if comp_rep_penalty is None and request.presence_penalty:
+        comp_rep_penalty = 1.0 + request.presence_penalty
+
    if request.stream:
        return StreamingResponse(
            _disconnect_guard(
-                stream_completion(engine, prompts[0], request),
+                stream_completion(
+                    engine,
+                    prompts[0],
+                    request,
+                    repetition_penalty=comp_rep_penalty,
+                ),
                raw_request,
            ),
            media_type="text/event-stream",

@@ -1248,14 +1262,16 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
    total_prompt_tokens = 0

    for i, prompt in enumerate(prompts):
+        gen_kwargs = {
+            "max_tokens": request.max_tokens or _default_max_tokens,
+            "temperature": _resolve_temperature(request.temperature),
+            "top_p": _resolve_top_p(request.top_p),
+            "stop": request.stop,
+        }
+        if comp_rep_penalty is not None:
+            gen_kwargs["repetition_penalty"] = comp_rep_penalty
        output = await _wait_with_disconnect(
-            engine.generate(
-                prompt=prompt,
-                max_tokens=request.max_tokens or _default_max_tokens,
-                temperature=_resolve_temperature(request.temperature),
-                top_p=_resolve_top_p(request.top_p),
-                stop=request.stop,
-            ),
+            engine.generate(prompt=prompt, **gen_kwargs),
            raw_request,
            timeout=timeout,
        )

@@ -1387,12 +1403,21 @@ async def create_chat_completion(request: ChatCompletionRequest, raw_request: Request):
    # Inject JSON instruction into messages
    messages = _inject_json_instruction(messages, json_instruction)

+    # Resolve repetition penalty: explicit > frequency_penalty > presence_penalty
+    rep_penalty = request.repetition_penalty
+    if rep_penalty is None and request.frequency_penalty:
+        rep_penalty = 1.0 + request.frequency_penalty
+    if rep_penalty is None and request.presence_penalty:
+        rep_penalty = 1.0 + request.presence_penalty
+
    # Prepare kwargs
    chat_kwargs = {
        "max_tokens": request.max_tokens or _default_max_tokens,
        "temperature": _resolve_temperature(request.temperature),
        "top_p": _resolve_top_p(request.top_p),
    }
+    if rep_penalty is not None:
+        chat_kwargs["repetition_penalty"] = rep_penalty

    # Add multimodal content
    if has_media:

@@ -1862,15 +1887,18 @@ async def stream_completion(
    engine: BaseEngine,
    prompt: str,
    request: CompletionRequest,
+    repetition_penalty: float | None = None,
) -> AsyncIterator[str]:
    """Stream completion response."""
-    async for output in engine.stream_generate(
-        prompt=prompt,
-        max_tokens=request.max_tokens or _default_max_tokens,
-        temperature=_resolve_temperature(request.temperature),
-        top_p=_resolve_top_p(request.top_p),
-        stop=request.stop,
-    ):
+    gen_kwargs = {
+        "max_tokens": request.max_tokens or _default_max_tokens,
+        "temperature": _resolve_temperature(request.temperature),
+        "top_p": _resolve_top_p(request.top_p),
+        "stop": request.stop,
+    }
+    if repetition_penalty is not None:
+        gen_kwargs["repetition_penalty"] = repetition_penalty
+    async for output in engine.stream_generate(prompt=prompt, **gen_kwargs):
        data = {
            "id": f"cmpl-{uuid.uuid4().hex[:8]}",
            "object": "text_completion",

vllm_mlx/utils/tokenizer.py

Lines changed: 8 additions & 1 deletion
@@ -59,7 +59,14 @@ def load_model_with_fallback(model_name: str, tokenizer_config: dict = None):
            return _load_with_tokenizer_fallback(model_name)
        # Fallback for models with extra weights (e.g., MTP layers)
        elif "parameters not in model" in str(e):
-            logger.warning(f"Extra parameters found (e.g., MTP weights), retrying with strict=False: {e}")
+            logger.warning(f"Extra parameters found (e.g., MTP weights), retrying with strict=False")
+            # Clear traceback references to free memory from the failed first load.
+            # Without this, large models (200GB+) cause OOM during retry because
+            # the traceback holds references to the first load's weight tensors.
+            e.__traceback__ = None
+            del e
+            import gc
+            gc.collect()
            from mlx_lm.utils import _download, load_model, load_tokenizer
            model_path = _download(model_name)
            model, config = load_model(model_path, strict=False)
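The OOM fix relies on a general CPython behavior: an exception's __traceback__ keeps every frame in the failed call chain alive, including locals such as partially loaded weights, so they cannot be garbage collected while the exception is still referenced. A minimal, self-contained illustration of the pattern with a toy allocation (not the actual model loader):

    import gc

    def failing_load():
        weights = bytearray(50 * 1024 * 1024)   # stand-in for a large weight buffer
        raise RuntimeError("parameters not in model")

    try:
        failing_load()
    except RuntimeError as e:
        # While e.__traceback__ is alive, the frame that holds `weights` is alive too.
        # Drop the traceback and the exception before retrying so the buffer can be freed.
        e.__traceback__ = None
        del e
        gc.collect()
        # ... retry the load (e.g. with strict=False) here ...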
