thehighnotes/vllm-jetson-orin
vLLM + Marlin Tensor Core Inference on Jetson Orin

Pre-built vLLM wheel with Marlin GPTQ kernels compiled for SM 8.7 (Jetson Orin family), plus server configuration, benchmark scripts, and an optional scheduler fast-path patch.

The Problem

Pre-built vLLM wheels ship Marlin kernels for SM 8.0, 8.6, 8.9, and 9.0 — but not SM 8.7 (Jetson AGX Orin, Orin Nano, Orin NX). Without SM 8.7 support, --quantization gptq_marlin falls back to generic CUDA core kernels, leaving Orin's tensor cores idle. The result: prefill is 8x slower than it should be.

The Solution

Build vLLM from source with TORCH_CUDA_ARCH_LIST="8.7" to compile Marlin kernels targeting Orin's SM 8.7 tensor cores. This unlocks fused INT4 dequantization + FP16 tensor core matrix multiply for any GPTQ-Int4 quantized model.
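Conceptually, GPTQ-Int4 stores each weight as a 4-bit index plus a per-group scale and zero point; what Marlin fuses is this dequantization arithmetic directly into the tensor core GEMM instead of materializing fp16 weights in memory. A minimal sketch of the per-weight math (illustrative only, not vLLM code):

```python
def dequant_int4(q: int, scale: float, zero: int) -> float:
    """Recover an fp16-ish weight from a 4-bit GPTQ value.

    GPTQ stores q in [0, 15] plus a per-group scale and zero point.
    Marlin performs this arithmetic inside the tensor core matmul,
    so the dequantized weights never touch global memory.
    """
    assert 0 <= q <= 15, "4-bit quantized values occupy [0, 15]"
    return scale * (q - zero)
```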

Or just install our pre-built wheel (saves a 75-minute build):

pip install https://huggingface.co/thehighnotes/vllm-jetson-orin/resolve/main/vllm-0.17.0+cu126-cp310-cp310-linux_aarch64.whl

Benchmarks

Tested on Jetson AGX Orin 64GB with Qwen3.5-35B-A3B-GPTQ-Int4 (35B parameters, 3.5B active via MoE).

Prefill Throughput

| Engine | Prefill (tok/s) | Speedup |
|---|---|---|
| llama.cpp Q4_K_M | 523 | 1.0x |
| vLLM GPTQ (no Marlin) | 241 | 0.5x |
| vLLM GPTQ Marlin | 2,001 | 3.8x |

Decode Throughput (streaming, first-to-last token)

| Context Length | vLLM Marlin | llama.cpp | Speedup |
|---|---|---|---|
| ~38 tokens | 31.4 tok/s | 22.5 tok/s | +40% |
| ~4,000 tokens | 30.9 tok/s | 22.5 tok/s | +37% |
| ~20,000 tokens | 29.2 tok/s | 22.5 tok/s | +30% |

End-to-End (20k context + 200 output tokens)

| Engine | Total Time | Speedup |
|---|---|---|
| llama.cpp | 47s | 1.0x |
| vLLM Marlin | 17s | 2.8x |

Note: vLLM's built-in throughput logger averages over 10-second windows mixing prefill, decode, and idle time. Always measure decode rate via streaming (first-to-last token) for accurate numbers. See scripts/benchmark_decode.py.
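The first-to-last-token measurement the note recommends reduces to a small calculation; a minimal sketch (illustrative, not the repo's benchmark_decode.py):

```python
import time

def decode_rate(token_timestamps: list[float]) -> float:
    """Decode throughput from per-token arrival times.

    The clock starts at the FIRST streamed token, so prefill time
    is excluded and only the decode phase is measured.
    """
    if len(token_timestamps) < 2:
        raise ValueError("need at least two tokens to measure decode rate")
    elapsed = token_timestamps[-1] - token_timestamps[0]
    # Tokens generated after the first one, divided by elapsed time.
    return (len(token_timestamps) - 1) / elapsed

# With a streaming client you would append time.monotonic() as each
# SSE chunk arrives:
#   for chunk in stream: timestamps.append(time.monotonic())
#   print(f"{decode_rate(timestamps):.1f} tok/s")
```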

Prerequisites

| Component | Required Version |
|---|---|
| Device | Jetson AGX Orin, Orin Nano, or Orin NX (SM 8.7) |
| JetPack | 6.x (tested on 6.2.1) |
| CUDA | 12.6 |
| Python | 3.10 |

Important constraints:

  • NumPy must be 1.x (PyTorch on Jetson is compiled against NumPy 1.x)
  • This wheel is not compatible with x86_64 systems or non-Orin Jetsons (different SM)
  • Unified memory: ensure at least ~20 GB free for model loading (varies by model size)

Quick Start

1. Create a virtual environment

python3 -m venv ~/vllm-venv
source ~/vllm-venv/bin/activate
pip install --upgrade pip

2. Install PyTorch for Jetson

pip install torch --index-url https://pypi.jetson-ai-lab.io/jp6/cu126

3. Install the vLLM wheel

pip install https://huggingface.co/thehighnotes/vllm-jetson-orin/resolve/main/vllm-0.17.0+cu126-cp310-cp310-linux_aarch64.whl

4. Install Triton (required for torch.compile)

pip install triton
export CUDA_HOME=/usr/local/cuda
export CPATH=/usr/local/cuda/include:${CPATH:-}

Triton compiles kernels on first run and needs cuda.h accessible via CPATH.

5. Set environment variables

export LD_LIBRARY_PATH="$(python -c 'import nvidia.cu12; print(nvidia.cu12.__path__[0])')/lib:${LD_LIBRARY_PATH:-}"
export LD_LIBRARY_PATH="/usr/local/cuda/lib64:/usr/local/cuda/targets/aarch64-linux/lib:${LD_LIBRARY_PATH}"
export HF_HOME="${HF_HOME:-$HOME/.cache/huggingface}"

6. Launch

Using the included start script (defaults to Qwen3.5-35B-A3B-GPTQ-Int4):

./scripts/start_server.sh

Or launch directly with any GPTQ-Int4 model:

VLLM_ENABLE_FUSED_MOE_ACTIVATION_CHUNKING=0 \
vllm serve Qwen/Qwen3.5-35B-A3B-GPTQ-Int4 \
    --host 0.0.0.0 --port 8000 \
    --quantization gptq_marlin \
    --dtype half \
    --gpu-memory-utilization 0.8 \
    --max-model-len 4096 \
    --max-num-batched-tokens 4096 \
    --max-num-seqs 1 \
    --no-enable-log-requests

Startup takes 3-5 minutes (model load + torch.compile + CUDA graph capture).

7. Test

curl http://localhost:8000/v1/chat/completions \
    -H "Content-Type: application/json" \
    -d '{"model": "Qwen/Qwen3.5-35B-A3B-GPTQ-Int4", "messages": [{"role": "user", "content": "Hello!"}]}'

Server Configuration

Key Flags Explained

| Flag | Purpose |
|---|---|
| `--quantization gptq_marlin` | Use Marlin tensor core kernels for INT4 GPTQ models |
| `--dtype half` | FP16; Orin tensor cores prefer fp16 over bf16. Models stored as bf16 are cast automatically (harmless warning) |
| `--max-num-seqs 1` | Single sequence; enables CUDA graph capture. Required for some models to avoid `causal_conv1d_update` assertion failures |
| `--max-num-batched-tokens N` | Set equal to `--max-model-len` to avoid chunked-prefill overhead |
| `--gpu-memory-utilization` | Fraction of GPU memory reserved for the KV cache. Orin's unified memory is shared with the CPU; lower this if other processes need memory |
| `--no-enable-log-requests` | Reduce log noise (renamed from `--disable-log-requests` in vLLM 0.17) |

Environment Variables

| Variable | Purpose |
|---|---|
| `VLLM_ENABLE_FUSED_MOE_ACTIVATION_CHUNKING=0` | Disables MoE activation chunking, which is unnecessary for batch=1 decode and adds overhead. Set in both the start script and the systemd template. |
| `VLLM_ATTENTION_BACKEND` | Overrides attention backend selection (see the Attention Backend section). Not set by default. |

systemd Service

A template service file is provided in service/vllm-server.service. To install:

# Edit the template: search for CHANGEME and replace all placeholders
sudo cp service/vllm-server.service /etc/systemd/system/
sudo systemctl daemon-reload
sudo systemctl start vllm-server

Critical: LD_LIBRARY_PATH in the service file must include both pip nvidia libs AND system CUDA paths. Without system paths (/usr/local/cuda/lib64), cuBLAS cannot create handles and you'll get CUBLAS_STATUS_ALLOC_FAILED during CUDA graph capture. This is because systemd replaces LD_LIBRARY_PATH entirely instead of appending to the system default.
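As an illustration, the Environment= line in the unit's [Service] section might look like the following (the venv path is a CHANGEME placeholder, as in the shipped template; match it to your install):

```ini
# Illustrative [Service] entry: pip nvidia libs first, then system CUDA paths.
Environment=LD_LIBRARY_PATH=/home/CHANGEME/vllm-venv/lib/python3.10/site-packages/nvidia/cu12/lib:/usr/local/cuda/lib64:/usr/local/cuda/targets/aarch64-linux/lib
```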

Launch Script

The start script defaults to Qwen/Qwen3.5-35B-A3B-GPTQ-Int4 but accepts any model and supports flag overrides:

# Default model (Qwen3.5-35B-A3B-GPTQ-Int4)
./scripts/start_server.sh

# Custom model
./scripts/start_server.sh TheBloke/Llama-2-7B-GPTQ --gpu-memory-utilization 0.8

# Override quantization for non-Marlin models
./scripts/start_server.sh some/model --quantization gptq --dtype float16

Benchmarking

Two benchmark scripts are included:

# Streaming benchmark (accurate decode throughput, reports TTFT)
python scripts/benchmark_decode.py --context-lengths 100,4000,20000 --max-tokens 200

# Quick non-streaming benchmark (auto-detects model from server)
./scripts/benchmark.sh
./scripts/benchmark.sh Qwen/Qwen3.5-35B-A3B-GPTQ-Int4 200 4000

Attention Backend

What's Used

vLLM selects attention backends by priority based on compute capability. For SM 8.7 (Orin), the priority is:

  1. FLASH_ATTN — vLLM's bundled flash attention (vllm_flash_attn/), always available on CUDA
  2. FLASHINFER — FlashInfer (JIT-compiled), installed as a dependency but never reached

Since the bundled flash attention is always available, FlashInfer is installed but not used by default.

The bundled vllm_flash_attn kernels in this wheel include native SM 8.7 SASS code (compiled via TORCH_CUDA_ARCH_LIST="8.7"). Both FA2 (_vllm_fa2_C.abi3.so) and FA3 (_vllm_fa3_C.abi3.so) contain SM 8.0, 8.6, 8.7, 8.9, and 9.0 targets.

On Blackwell (SM 10.0) the priority is reversed and FlashInfer is preferred, but that reordering does not apply here: on SM 8.x the bundled flash attention always wins.

FlashInfer as an Alternative (Untested)

FlashInfer 0.6.4 is installed as a vLLM dependency and can be forced via:

export VLLM_ATTENTION_BACKEND=FLASHINFER

Community reports suggest FlashInfer may yield ~25% higher throughput on SM 8.7 for both prefill and decode, due to differences in kernel scheduling and memory access patterns. A FlashInfer PR fixed an issue where SM 8.7 was ignored by the build system — this fix may or may not be included in the pip-installed FlashInfer 0.6.4 (which uses JIT compilation rather than pre-built kernels).

This is an untested optimization opportunity. If you try it, we'd love to hear your results — please open an issue with before/after benchmark numbers.

Scheduler Fast Path Patch (Optional)

An optional patch is included that adds a fast decode path to vLLM's V1 scheduler. It reduces per-token Python overhead for single-user batch=1 decode by short-circuiting the full ~590-line schedule() method when conditions are trivial.

How It Works

The patch adds two methods:

  • _try_fast_decode_path() — Handles the schedule() call. When there's exactly 1 running request, no waiting requests, no spec decode, no LoRA, and no encoder inputs, it builds the SchedulerOutput in ~40 lines instead of 590.
  • _try_fast_update_from_output() — Handles the update_from_output() call. Same conditions: fast-paths the output processing for single-request decode.

Both methods return None when conditions aren't met, falling through to the full path. This means:

  • First token (prefill) uses the full path
  • Request finishing uses the full path
  • Multiple concurrent requests use the full path
  • Any advanced feature (spec decode, LoRA, etc.) uses the full path
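A toy sketch of that guard-and-fall-through shape (names and conditions are illustrative simplifications, not the actual patch code):

```python
from typing import Optional

class Scheduler:
    """Toy sketch of the fast-path guard pattern used by the patch.

    Guards and field names are illustrative, not the actual vLLM code.
    """

    def __init__(self):
        self.running = []   # currently decoding requests
        self.waiting = []   # queued requests

    def schedule(self) -> dict:
        fast = self._try_fast_decode_path()
        if fast is not None:
            return fast                # trivial batch=1 decode step
        return self._full_schedule()   # everything else

    def _try_fast_decode_path(self) -> Optional[dict]:
        # Every guard must hold; otherwise fall through to the full path.
        if len(self.running) != 1 or self.waiting:
            return None
        req = self.running[0]
        if req.get("lora") or req.get("spec_decode") or req.get("encoder"):
            return None
        # Build a minimal one-token decode step for the single request.
        return {"req_id": req["id"], "num_tokens": 1}

    def _full_schedule(self) -> dict:
        # Stand-in for the full ~590-line scheduling logic.
        return {"full_path": True}
```

Because `None` simply defers to the full method, the fast path can never change behavior for requests it does not recognize; it can only skip work it has proven unnecessary.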

Applying the Patch

The patch targets vLLM 0.17.0 (commit bf397ad).

# From a vLLM 0.17.0 source checkout:
git apply patches/scheduler_fast_path.patch

# Or apply directly to an installed package:
cd /path/to/venv/lib/python3.10/site-packages
patch -p1 < /path/to/patches/scheduler_fast_path.patch
# Then delete the cached bytecode:
find vllm/v1/core/sched/__pycache__ -name 'scheduler*.pyc' -delete

The pre-built wheel already includes this patch.

Build From Source

If you prefer to build yourself (or need a different vLLM version):

Prerequisites

sudo apt install -y cmake ninja-build gcc g++

Requires CUDA toolkit (nvcc) from JetPack, git, and ~5 GB disk space for source + build artifacts.

Build

# Clone vLLM
git clone https://github.com/vllm-project/vllm.git
cd vllm
git checkout v0.17.0

# Create venv and install PyTorch
python3 -m venv ~/vllm-build
source ~/vllm-build/bin/activate
pip install --upgrade pip
pip install torch --index-url https://pypi.jetson-ai-lab.io/jp6/cu126

# Build with SM 8.7 targeting Orin tensor cores
export TORCH_CUDA_ARCH_LIST="8.7"
export MAX_JOBS=9                    # AGX Orin has 12 CPU cores; leave some headroom
export VLLM_TARGET_DEVICE=cuda
export CUDA_HOME=/usr/local/cuda

# --no-build-isolation: uses the Jetson PyTorch you just installed
# instead of pip downloading a generic one into an isolated build env
pip install --no-build-isolation .

Build takes approximately 60-75 minutes (CPU-bound nvcc compilation).

Building the Wheel

To create a distributable wheel from your build:

pip install wheel

# Create a staging directory with just the vLLM package files
SITE_PACKAGES="$(pip show vllm | grep '^Location' | cut -d' ' -f2)"
mkdir -p /tmp/vllm-staging
cp -r "$SITE_PACKAGES/vllm" /tmp/vllm-staging/
cp -r "$SITE_PACKAGES"/vllm-*.dist-info /tmp/vllm-staging/

# Remove local path reference (contains your filesystem path)
rm -f /tmp/vllm-staging/vllm-*.dist-info/direct_url.json

# Pack into wheel
wheel pack /tmp/vllm-staging --dest-dir /tmp/wheels
rm -rf /tmp/vllm-staging

Known Issues

| Issue | Workaround |
|---|---|
| `causal_conv1d_update` assertion with `--max-num-seqs > 1` | Use `--max-num-seqs 1` |
| `CUBLAS_STATUS_ALLOC_FAILED` under systemd | Add system CUDA paths to `LD_LIBRARY_PATH` (see the systemd section) |
| "Not enough SMs for max_autotune_gemm" warning | Harmless; Inductor cannot auto-tune for Orin's 16 SMs |
| bfloat16 cast warning | Harmless; model weights are cast from bf16 to fp16 (see `--dtype half`) |
| OOM on startup | Lower `--gpu-memory-utilization` or stop competing GPU processes |
| `--disable-log-requests` not recognized | Renamed to `--no-enable-log-requests` in vLLM 0.17 |

Contributing

Issues and pull requests are welcome. When reporting problems, please include:

  • JetPack version (cat /etc/nv_tegra_release)
  • Device (AGX Orin 64GB, Orin Nano, etc.)
  • Model name and quantization
  • Full error output

Part of the AIquest Research Lab

This project is part of the AIquest research ecosystem; the vLLM Marlin wheel powers the inference backend for other AIquest projects.

Pre-built wheel hosted on HuggingFace.

Explore all projects at aiquest.info/research.

Changelog

2026-03-15

  • Documentation: Added Attention Backend section documenting that vLLM uses its bundled vllm_flash_attn (not FlashInfer) on SM 8.7, with native SASS confirmed in the wheel. Documented FlashInfer as an untested alternative with potential +25% throughput improvement.
  • Documentation: Added --max-num-batched-tokens to the launch example (was in the flags table but missing from the example command).
  • Documentation: Added Environment Variables subsection documenting VLLM_ENABLE_FUSED_MOE_ACTIVATION_CHUNKING and VLLM_ATTENTION_BACKEND.

2026-03-11

  • Fix: Rewrote benchmark.sh to use single-process timing (previous version compared time.perf_counter() across separate Python processes, producing garbage values).
  • Fix: Eliminated shell injection vulnerabilities in benchmark scripts by passing arguments via environment variables instead of string interpolation.
  • Fix: Added set -euo pipefail and input validation to all shell scripts.

2026-03-11 (Initial Release)

  • Pre-built vLLM 0.17.0 wheel with Marlin GPTQ kernels for SM 8.7 (aarch64).
  • Scheduler fast-path patch for single-user batch=1 decode.
  • Systemd service template with security hardening.
  • Streaming and non-streaming benchmark scripts.
  • Published wheel on HuggingFace.

License

MIT
