# PyTorch Implementation of Jamba

A PyTorch implementation of the model described in ["Jamba: A Hybrid Transformer-Mamba Language Model"](https://arxiv.org/abs/2403.19887).
## Installation

This project uses the [uv](https://github.com/astral-sh/uv) package manager. You can install uv as follows, depending on your operating system.
### macOS and Linux
```bash
curl -LsSf https://astral.sh/uv/install.sh | sh
```

or

```bash
wget -qO- https://astral.sh/uv/install.sh | sh
```

Note: For more installation options, please refer to the official uv documentation.
To initialize a `pyproject.toml` file, use
```bash
uv init --bare
```

Python 3.13 is recommended; pin it via
```bash
uv python pin 3.13
```

Import runtime dependencies from your requirements file with `uv add -r requirements.txt` to record them in `pyproject.toml` and update the lock file and environment.
```bash
uv add -r requirements.txt
```

Create the environment and lock everything with `uv sync`, which resolves dependencies, writes `uv.lock`, and installs them into a project-local `.venv`:
```bash
uv sync
```

You can install new packages that are not specified in `pyproject.toml` via `uv add`, for example:
```bash
uv add packaging
```

And you can remove packages via `uv remove`:

```bash
uv remove packaging
```

To run scripts inside the project environment, prefix them with `uv run`. On Windows (PowerShell):

```powershell
uv run -- python .\check_cuda.py
```

On macOS and Linux:

```bash
uv run python train.py
```
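The `check_cuda.py` script referenced above is part of the repository and is not reproduced here; as a rough sketch, a script like it might simply verify that PyTorch can see a GPU (the actual file may differ):

```python
import torch

# Report the installed PyTorch version and whether a CUDA GPU is visible.
print("PyTorch version:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())

if torch.cuda.is_available():
    # Name of the first visible CUDA device
    print("Device:", torch.cuda.get_device_name(0))
```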
### Skipping the uv run command

If you find typing `uv run` cumbersome, you can manually activate the virtual environment as described below.
On macOS/Linux:
```bash
source .venv/bin/activate
```

Then, you can run scripts via
```bash
python script.py
```

and launch JupyterLab via
```bash
jupyter lab
```

## Usage

```python
# Import the torch library, which provides tools for machine learning
import torch
# Import the Jamba model from the jamba.model module
from jamba.model import Jamba
# Create a tensor of random integers between 0 and 100, with shape (1, 100)
# This simulates a batch of tokens that we will pass through the model
x = torch.randint(0, 100, (1, 100))
# Initialize the Jamba model with the specified parameters
# dim: dimensionality of the input data
# depth: number of layers in the model
# num_tokens: number of unique tokens in the input data
# d_state: dimensionality of the SSM hidden state in the Mamba blocks
# d_conv: size of the convolution kernel in the Mamba blocks
# heads: number of attention heads in the model
# num_experts: number of expert networks in the model
# num_experts_per_token: number of experts used for each token in the input data
model = Jamba(
dim=512,
depth=6,
num_tokens=100,
d_state=256,
d_conv=128,
heads=8,
num_experts=8,
num_experts_per_token=2,
)
# Perform a forward pass through the model with the input data
# This will return the model's predictions for each token in the input data
output = model(x)
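
# --- Illustrative addition (not part of the original example) ---
# Assuming the forward pass returns next-token logits of shape
# (batch, seq_len, num_tokens), a minimal training step with dummy
# targets could look like this:
targets = torch.randint(0, 100, (1, 100))  # dummy target token ids
loss = torch.nn.functional.cross_entropy(
    output.reshape(-1, output.size(-1)),  # flatten to (batch * seq_len, num_tokens)
    targets.reshape(-1),                  # flatten to (batch * seq_len,)
)
loss.backward()  # backpropagate through the hybrid Transformer-Mamba stack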
# Print the model's predictions
print(output)
```

## License

MIT