DistilP is a Python library for MILP-based layer/expert assignment in distributed inference across heterogeneous devices. It profiles device and model characteristics, then solves a mixed-integer linear program to determine the optimal layer distribution for distributed Large Language Model (LLM) inference.
- Device Profiling: Measure CPU/GPU throughput, memory capacity, disk I/O, and communication characteristics
- Model Profiling: Analyze layer-wise memory requirements and computational costs for LLMs
- HALDA Solver: Mixed Integer Linear Programming solver for optimal layer assignment across heterogeneous devices
- Pydantic-based: Type-safe profile schemas with automatic validation and serialization
DistilP requires:
- Python 3.12 or higher
- uv package manager
```bash
# Install solver only
uv add distilp[solver]

# Install profiler only (macOS with MLX support)
uv add distilp[profiler]

# Install both solver and profiler
uv add distilp[solver,profiler]

# Install with plotting support
uv add distilp[solver,solver-plotting]
```

Clone the repository and install with development dependencies:
```bash
git clone https://github.com/firstbatchxyz/distilp.git
cd distilp

# Install all dependencies for development
uv sync --extra dev
```

This installs all optional dependencies, including pytest, matplotlib, and the profiling tools.
The typical workflow involves three steps (sketched end-to-end below):
1. Profile your devices - generate device performance profiles
2. Profile your model - generate model layer characteristics
3. Run the solver - compute the optimal layer distribution
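Assuming two devices and the model used in the examples below (all file names are illustrative), an end-to-end run looks roughly like this:

```bash
# 1. On each device: profile the hardware against the target model
uv run profiler device -r "Qwen/Qwen3-4B-MLX-8bit" -o device_profile_1.json

# 2. Once per model: profile layer-wise characteristics
uv run profiler model -r "Qwen/Qwen3-4B-MLX-8bit" -o model_profile.json

# 3. Solve for the optimal layer distribution
uv run solver \
  --devices device_profile_1.json device_profile_2.json \
  --model model_profile.json
```

Each step is detailed in the sections that follow.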
Profile the current device to measure its computational capabilities:
```bash
# Profile device for a specific model
uv run profiler device -r "Qwen/Qwen3-4B-MLX-8bit" -o device_profile.json

# Profile with higher batch sizes (default max_batch_exp is 2, i.e. 2^2 = 4)
uv run profiler device -r "Qwen/Qwen3-4B-MLX-8bit" -o device_profile.json --max-batch-exp 6
```

This writes a JSON file to the path given with `-o`; if `-o` is omitted, the profile is printed to the console.
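The exact batch sweep is an assumption from the flag's description rather than documented behavior, but `--max-batch-exp` appears to control a power-of-two sweep:

```python
# Hypothetical illustration of the batch sizes implied by --max-batch-exp,
# assuming the profiler sweeps powers of two up to the cap.
def swept_batch_sizes(max_batch_exp: int) -> list[int]:
    return [2**e for e in range(max_batch_exp + 1)]

print(swept_batch_sizes(2))  # default: [1, 2, 4]
print(swept_batch_sizes(6))  # --max-batch-exp 6: [1, 2, 4, 8, 16, 32, 64]
```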
Profile a model to measure layer-wise memory and computational requirements:
```bash
# Profile model with default batch sizes (1, 2, 4)
uv run profiler model -r "Qwen/Qwen3-4B-MLX-8bit" -o model_profile.json

# Profile with custom batch sizes and sequence length
uv run profiler model -r "Qwen/Qwen3-4B-MLX-8bit" \
  -o model_profile.json \
  --batches "1,2,4,8" \
  --sequence 1024
```

This also writes a JSON file.
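The written profile can be loaded back with the Pydantic schema shown later in this README; a minimal sanity check (only the `L` and `V` fields are shown here, the full schema has more):

```python
from distilp.common import ModelProfileSplit

# Load the profile written by the CLI and inspect two fields.
with open("model_profile.json", "r") as f:
    model_profile = ModelProfileSplit.model_validate_json(f.read())

print(f"{model_profile.L} layers, vocab size {model_profile.V}")
```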
Use the profiles to compute the optimal layer distribution:
```bash
# Run solver with a profile folder
uv run solver --profile test/profiles/hermes_70b

# Run solver with separate device and model files
uv run solver \
  --devices device_profile_1.json device_profile_2.json \
  --model model_profile.json

# Run with custom solver parameters
uv run solver --profile hermes_70b \
  --time-limit 10 \
  --max-iters 20 \
  --mip-gap 0.001

# Save solution to JSON
uv run solver --profile hermes_70b --save-solution solution.json

# Disable plotting
uv run solver --profile hermes_70b --no-plot
```

Output: the optimal layer distribution, showing:
- Number of pipeline stages (k)
- Layer assignment per device (w)
- Objective value (latency estimate)
- Device grouping for optimization
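A solution saved with `--save-solution` is plain JSON; its exact schema is not documented here, so a safe first look is simply to list the top-level keys:

```python
import json

# Inspect a saved solution without assuming its schema.
with open("solution.json") as f:
    solution = json.load(f)

print(sorted(solution.keys()))
```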
DistilP can be used programmatically in Python scripts. This is useful for integrating profiling and solving into automated workflows or custom applications.
```python
from distilp.profiler import profile_device, profile_model

# Profile the current device
device_profile = profile_device(
    repo_id="Qwen/Qwen3-4B-MLX-8bit",
    max_batch_exp=6,  # Profile up to batch size 2^6 = 64
    debug=0,
)

# Profile a model
model_profile = profile_model(
    repo_id="Qwen/Qwen3-4B-MLX-8bit",
    batch_sizes=[1, 2, 4, 8],
    sequence_length=512,
    debug=0,
)

print(f"Device: {device_profile.name}")
print(f"Model: {model_profile.L} layers, {model_profile.V} vocab size")
```

Profiles are Pydantic models with built-in serialization:
```python
from distilp.common import DeviceProfile, ModelProfileSplit

# Save profiles to JSON
with open("device_profile.json", "w") as f:
    f.write(device_profile.model_dump_json(indent=2))

with open("model_profile.json", "w") as f:
    f.write(model_profile.model_dump_json(indent=2))

# Load profiles from JSON
with open("device_profile.json", "r") as f:
    device_profile = DeviceProfile.model_validate_json(f.read())

with open("model_profile.json", "r") as f:
    model_profile = ModelProfileSplit.model_validate_json(f.read())
```

With profiles in hand, run the HALDA solver programmatically:
```python
from distilp.solver import halda_solve
from distilp.common import DeviceProfile, ModelProfile

# Load or create device and model profiles
devices = [device_profile_1, device_profile_2]  # list of DeviceProfile objects
model = model_profile  # ModelProfile or ModelProfileSplit object

# Run the HALDA solver
result = halda_solve(
    devs=devices,
    model=model,
    k_candidates=None,  # None = try all factors of L
    mip_gap=1e-4,       # MIP gap tolerance
    plot=True,          # Show k vs. objective plot
    kv_bits="4bit",     # KV cache quantization
)

# Access results
print(f"Optimal k: {result.k}")
print(f"Objective value: {result.obj_value:.6f}")
print(f"Layer distribution: {result.w}")
print(f"Device grouping: {result.sets}")

# Access per-device assignments
for i, (device, w_i, n_i) in enumerate(zip(devices, result.w, result.n)):
    print(f"Device {i} ({device.name}): {w_i} layer groups, {n_i} layers total")
```
We use pytest:

```bash
uv run pytest

# Run with verbose output
uv run pytest -v

# Run a specific test file
uv run pytest test/test_integration.py -v

# Run a specific test function (a file path can be given as well)
uv run pytest -k test_method_name
```

You can find the license here.