Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 72 additions & 43 deletions Dockerfile
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
# Rust builder

FROM lukemathwalker/cargo-chef:latest-rust-1.83 AS chef

WORKDIR /usr/src

ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse

FROM chef as planner

COPY Cargo.toml Cargo.toml
COPY rust-toolchain.toml rust-toolchain.toml
COPY proto proto
Expand Down Expand Up @@ -49,116 +52,140 @@ ARG TARGETPLATFORM

ENV PATH /opt/conda/bin:$PATH

# For build-time CUDA memory resilience
ENV PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True"

RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
build-essential \
ca-certificates \
ccache \
curl \
git && \
rm -rf /var/lib/apt/lists/*
git \
ninja-build \
wget \
&& rm -rf /var/lib/apt/lists/*

# Add these lines to install a *newer* CMake version
RUN apt-get update && \
DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends gnupg lsb-release software-properties-common && \
wget -O - https://apt.kitware.com/keys/kitware-archive-latest.asc 2>/dev/null | gpg --dearmor - | tee /usr/share/keyrings/kitware-archive-keyring.gpg >/dev/null && \
echo "deb [signed-by=/usr/share/keyrings/kitware-archive-keyring.gpg] https://apt.kitware.com/ubuntu/ $(lsb_release -cs) main" | tee /etc/apt/sources.list.d/kitware.list >/dev/null && \
apt-get update && \
rm -f /etc/apt/sources.list.d/cmake.list && \
apt-get install -y --no-install-recommends cmake && \
rm -rf /var/lib/apt/lists/*

# Install conda
# translating Docker's TARGETPLATFORM into mamba arches
RUN case ${TARGETPLATFORM} in \
"linux/arm64") MAMBA_ARCH=aarch64 ;; \
*) MAMBA_ARCH=x86_64 ;; \
"linux/arm64") MAMBA_ARCH=aarch64 ;; \
*) MAMBA_ARCH=x86_64 ;; \
esac && \
curl -fsSL -v -o ~/mambaforge.sh -O "https://github.com/conda-forge/miniforge/releases/download/${MAMBA_VERSION}/Mambaforge-${MAMBA_VERSION}-Linux-${MAMBA_ARCH}.sh"
curl -fsSL -v -o ~/mambaforge.sh -O "https://github.com/conda-forge/miniforge/releases/download/${MAMBA_VERSION}/Mambaforge-${MAMBA_VERSION}-Linux-${MAMBA_ARCH}.sh"
RUN chmod +x ~/mambaforge.sh && \
bash ~/mambaforge.sh -b -p /opt/conda && \
rm ~/mambaforge.sh

# Install pytorch
# On arm64 we exit with an error code
RUN case ${TARGETPLATFORM} in \
"linux/arm64") exit 1 ;; \
*) /opt/conda/bin/conda update -y conda && \
/opt/conda/bin/conda install -c "${INSTALL_CHANNEL}" -c "${CUDA_CHANNEL}" -y "python=${PYTHON_VERSION}" "pytorch=$PYTORCH_VERSION" "pytorch-cuda=$(echo $CUDA_VERSION | cut -d'.' -f 1-2)" ;; \
"linux/arm64") exit 1 ;; \
*) /opt/conda/bin/conda update -y conda && \
/opt/conda/bin/conda install -c "${INSTALL_CHANNEL}" -c "${CUDA_CHANNEL}" -y "python=${PYTHON_VERSION}" "pytorch=$PYTORCH_VERSION" "pytorch-cuda=$(echo $CUDA_VERSION | cut -d'.' -f 1-2)" ;; \
esac && \
/opt/conda/bin/conda clean -ya

# CUDA kernels builder image
FROM pytorch-install as kernel-builder

ARG MAX_JOBS=2
# This environment variable controls the number of parallel compilation jobs for CUDA kernels.
# It is set to a conservative value (2) by default for stability on machines
# with limited RAM relative to CPU cores, preventing Out-Of-Memory (OOM) crashes during build.
#
# You can adjust this value to optimize build speed based on your system's RAM:
# - If you have more RAM (e.g., 96GB+), you can increase this value (e.g., to 16, 24, or 32)
# to significantly speed up the build. Always monitor RAM usage (htop) to avoid OOM crashes.
# - If you encounter OOM errors even with this value, try reducing it further to 1.
ENV MAX_JOBS=2
# If you encounter OOM errors even with this value, try reducing it to 1.

RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
ninja-build cmake \
&& rm -rf /var/lib/apt/lists/*
RUN pip install setuptools_scm --no-cache-dir

# Build Flash Attention CUDA kernels
FROM kernel-builder as flash-att-builder

WORKDIR /usr/src
COPY server/Makefile-flash-att Makefile
RUN make build-flash-attention
RUN make build-flash-attention -j${MAX_JOBS}

# Build Flash Attention v2 CUDA kernels
FROM kernel-builder as flash-att-v2-builder

WORKDIR /usr/src
COPY server/Makefile-flash-att-v2 Makefile
RUN make build-flash-attention-v2-cuda
RUN make build-flash-attention-v2-cuda -j${MAX_JOBS}

# Build Transformers exllama kernels
FROM kernel-builder as exllama-kernels-builder

WORKDIR /usr/src
COPY server/exllama_kernels/ .
RUN TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" python setup.py build
RUN MAX_JOBS=${MAX_JOBS} TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" python setup.py build

# Build Transformers exllama kernels
FROM kernel-builder as exllamav2-kernels-builder

WORKDIR /usr/src
COPY server/exllamav2_kernels/ .
RUN TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" python setup.py build
RUN MAX_JOBS=${MAX_JOBS} TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" python setup.py build

# Build Transformers awq kernels
FROM kernel-builder as awq-kernels-builder

WORKDIR /usr/src
COPY server/Makefile-awq Makefile
RUN TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" make build-awq
RUN TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" make build-awq -j${MAX_JOBS}

# Build Transformers CUDA kernels
FROM kernel-builder as custom-kernels-builder

WORKDIR /usr/src
COPY server/custom_kernels/ .
# Build specific version of transformers
RUN python setup.py build
RUN MAX_JOBS=${MAX_JOBS} python setup.py build

# Build vllm CUDA kernels
FROM kernel-builder as vllm-builder

WORKDIR /usr/src
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
wget \
&& rm -rf /var/lib/apt/lists/*
RUN DEBIAN_FRONTEND=noninteractive apt-get purge -y --auto-remove cmake
RUN wget 'https://github.com/Kitware/CMake/releases/download/v3.30.0/cmake-3.30.0-linux-x86_64.tar.gz'
RUN tar xzvf 'cmake-3.30.0-linux-x86_64.tar.gz'
RUN ln -s "$(pwd)/cmake-3.30.0-linux-x86_64/bin/cmake" /usr/local/bin/cmake
ENV TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 8.9 9.0+PTX"
COPY server/Makefile-vllm Makefile
# Build specific version of vllm
RUN make build-vllm-cuda
RUN make build-vllm-cuda -j${MAX_JOBS}

# Build megablocks kernels
FROM kernel-builder as megablocks-kernels-builder

WORKDIR /usr/src
COPY server/Makefile-megablocks Makefile
ENV TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX"
RUN make build-megablocks
RUN make build-megablocks -j${MAX_JOBS}

# Build punica CUDA kernels
FROM kernel-builder as punica-builder
WORKDIR /usr/src

COPY server/punica_kernels/ .
# Build specific version of punica
ENV TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX"
RUN python setup.py build
RUN MAX_JOBS=${MAX_JOBS} python setup.py build

# Build eetq kernels
FROM kernel-builder as eetq-kernels-builder

WORKDIR /usr/src
COPY server/Makefile-eetq Makefile
# Build specific version of transformers
RUN TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" make build-eetq
RUN TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" make build-eetq -j${MAX_JOBS}

# LoRAX base image
FROM nvidia/cuda:12.4.0-base-ubuntu22.04 as base
Expand All @@ -185,32 +212,36 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-ins
COPY --from=pytorch-install /opt/conda /opt/conda

# Copy build artifacts from flash attention builder
COPY --from=flash-att-builder /usr/src/flash-attention/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
COPY --from=flash-att-builder /usr/src/flash-attention/csrc/layer_norm/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
COPY --from=flash-att-builder /usr/src/flash-attention/csrc/rotary/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
COPY --from=flash-att-builder /usr/src/flash-attention/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
COPY --from=flash-att-builder /usr/src/flash-attention/csrc/layer_norm/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
COPY --from=flash-att-builder /usr/src/flash-attention/csrc/rotary/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages

# Copy build artifacts from flash attention v2 builder
COPY --from=flash-att-v2-builder /usr/src/flash-attention-v2/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
COPY --from=flash-att-v2-builder /usr/src/flash-attention-v2/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages

# Copy build artifacts from custom kernels builder
COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages

# Copy build artifacts from exllama kernels builder
COPY --from=exllama-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
COPY --from=exllama-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages

# Copy build artifacts from exllamav2 kernels builder
COPY --from=exllamav2-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
COPY --from=exllamav2-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages

# Copy build artifacts from awq kernels builder
COPY --from=awq-kernels-builder /usr/src/llm-awq/awq/kernels/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
COPY --from=awq-kernels-builder /usr/src/llm-awq/awq/kernels/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages

# Copy builds artifacts from vllm builder
COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages

# Copy builds artifacts from punica builder
COPY --from=punica-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
COPY --from=punica-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages

# Copy build artifacts from megablocks builder
COPY --from=megablocks-kernels-builder /usr/src/megablocks/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
COPY --from=megablocks-kernels-builder /usr/src/megablocks/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages

# Copy build artifacts from eetq builder
COPY --from=eetq-kernels-builder /usr/src/eetq/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
COPY --from=eetq-kernels-builder /usr/src/eetq/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages

# Install flash-attention dependencies
RUN pip install einops --no-cache-dir
Expand Down Expand Up @@ -238,7 +269,6 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-ins
g++ \
&& rm -rf /var/lib/apt/lists/*


# Final image
FROM base
LABEL source="https://github.com/predibase/lorax"
Expand All @@ -250,7 +280,6 @@ RUN chmod +x entrypoint.sh
COPY sync.sh sync.sh
RUN chmod +x sync.sh


RUN curl "https://awscli.amazonaws.com/awscli-exe-linux-x86_64.zip" -o "awscliv2.zip" && \
unzip awscliv2.zip && \
sudo ./aws/install && \
Expand Down
20 changes: 20 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,19 @@ _LoRAX: Multi-LoRA inference server that scales to 1000s of fine-tuned LLMs_

LoRAX (LoRA eXchange) is a framework that allows users to serve thousands of fine-tuned models on a single GPU, dramatically reducing the cost of serving without compromising on throughput or latency.

---

**🚀 Start Here: For a Robust & Reliable LoRAX Deployment**

While this `README.md` provides a general overview, setting up a performant LoRAX server involves specific hardware, software, and environment configurations. For a smooth, reliable deployment experience, we strongly recommend consulting our detailed **[LoRAX Deployment Playbook](lorax_deployment_playbook.md)**. This guide covers:

* **[Bulletproof Host System Setup](lorax_deployment_playbook.md#phase-1-host-setup):** NVIDIA drivers, Docker, `nvidia-container-toolkit`, and crucial user permissions.
* **[GPU VRAM Considerations](lorax_deployment_playbook.md#phase-2-deploy-lorax):** Understanding LLM memory requirements and selecting compatible models for your hardware.
* **[Pre-Built vs. Source Deployment](lorax_deployment_playbook.md#phase-2-deploy-lorax):** Choosing the fastest path or building from source with all CUDA kernels.
* **[Common Pitfalls & Troubleshooting](lorax_deployment_playbook.md#troubleshooting-guide):** Solutions for Hugging Face authentication, model download stalls, and more.

---

## 📖 Table of contents

- [📖 Table of contents](#-table-of-contents)
Expand Down Expand Up @@ -59,6 +72,9 @@ Base models can be loaded in fp16 or quantized with `bitsandbytes`, [GPT-Q](http

Supported adapters include LoRA adapters trained using the [PEFT](https://github.com/huggingface/peft) and [Ludwig](https://ludwig.ai/) libraries. Any of the linear layers in the model can be adapted via LoRA and loaded in LoRAX.

**⚙️ Model Compatibility & VRAM:** Selecting the right model for your GPU's VRAM is crucial. Not all quantized models are plug-and-play due to varying toolchains. For detailed guidance on VRAM limitations and troubleshooting quantized model errors (e.g., `CUDA out of memory`, `RuntimeError`), refer to **[Phase 2: Deploy LoRAX](lorax_deployment_playbook.md#phase-2-deploy-lorax)** in the LoRAX Deployment Playbook.


## 🏃‍♂️ Getting Started

We recommend starting with our pre-built Docker image to avoid compiling custom CUDA kernels and other dependencies.
Expand All @@ -72,6 +88,8 @@ The minimum system requirements need to run LoRAX include:
- Linux OS
- Docker (for this guide)

**🚨 Critical Setup Note:** Meeting these requirements can be complex. For a step-by-step, verified guide on installing GPU drivers, Docker Engine, and `nvidia-container-toolkit` (including essential user permissions), please follow **[Phase 1: Host Setup](lorax_deployment_playbook.md#phase-1-host-setup)** in the LoRAX Deployment Playbook. Incorrect setup here is the most common cause of deployment failures.

### Launch LoRAX Server

#### Prerequisites
Expand All @@ -80,6 +98,8 @@ Then
- `sudo systemctl daemon-reload`
- `sudo systemctl restart docker`

**💡 For the most reliable and fully explained `docker run` command, including critical flags (`-e HUGGING_FACE_HUB_TOKEN`, `--user`), model selection based on GPU VRAM, and troubleshooting common issues like model download stalls or quantized model compatibility, refer to our comprehensive guide: [Phase 2: Deploy LoRAX](lorax_deployment_playbook.md#phase-2-deploy-lorax) and [Phase 3: Test the API](lorax_deployment_playbook.md#phase-3-test-the-api) in the LoRAX Deployment Playbook.**

```shell
model=mistralai/Mistral-7B-Instruct-v0.1
volume=$PWD/data
Expand Down
Loading