Jamba

PyTorch Implementation of Jamba: "Jamba: A Hybrid Transformer-Mamba Language Model"

Native uv Python and package management

1. Install uv

Depending on your operating system, uv can be installed as follows.

macOS and Linux

curl -LsSf https://astral.sh/uv/install.sh | sh

or

wget -qO- https://astral.sh/uv/install.sh | sh

Note: For more installation options, please refer to the official uv documentation.

2. Install Python packages and dependencies

To initialize a pyproject.toml file, use

uv init --bare

Python 3.13 is recommended; pin it with

uv python pin 3.13

Import runtime dependencies from your requirements file with uv add -r requirements.txt; this records them in pyproject.toml and updates the lock file and environment.

uv add -r requirements.txt

Create the environment and lock everything with uv sync, which resolves dependencies, writes uv.lock, and installs them into a project .venv.

uv sync

You can install new packages that are not specified in pyproject.toml via uv add, for example:

uv add packaging

And you can remove packages via uv remove, for example:

uv remove packaging

Check CUDA

uv run -- python check_cuda.py
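
The check_cuda.py script in this repository performs the actual check; as a point of reference, a minimal sketch of such a check using only standard PyTorch calls (not necessarily identical to the repository's script) could look like this:

import torch

# Report whether a CUDA-capable GPU is visible to PyTorch
if torch.cuda.is_available():
    device_count = torch.cuda.device_count()
    print(f"CUDA available with {device_count} device(s)")
    print(f"Current device: {torch.cuda.get_device_name(torch.cuda.current_device())}")
else:
    print("CUDA not available; computations will fall back to the CPU")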

Train

uv run python train.py

Skipping the uv run command

If you find typing uv run cumbersome, you can manually activate the virtual environment as described below.

On macOS/Linux:

source .venv/bin/activate

Then, you can run scripts via

python script.py

and launch JupyterLab via

jupyter lab

Usage

# Import the torch library, which provides tools for machine learning
import torch

# Import the Jamba model from the jamba.model module
from jamba.model import Jamba

# Create a tensor of random token IDs in the range [0, 100), with shape (1, 100)
# This simulates a batch containing one sequence of 100 tokens to pass through the model
x = torch.randint(0, 100, (1, 100))

# Initialize the Jamba model with the specified parameters
# dim: embedding (model) dimension
# depth: number of layers in the model
# num_tokens: vocabulary size (number of unique token IDs)
# d_state: state dimension of the Mamba (SSM) layers
# d_conv: convolution dimension of the Mamba layers
# heads: number of attention heads
# num_experts: number of expert networks in the mixture-of-experts layers
# num_experts_per_token: number of experts each token is routed to
model = Jamba(
    dim=512,
    depth=6,
    num_tokens=100,
    d_state=256,
    d_conv=128,
    heads=8,
    num_experts=8,
    num_experts_per_token=2,
)

# Perform a forward pass through the model with the input data
# This will return the model's predictions for each token in the input data
output = model(x)

# Print the model's predictions
print(output)
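
The forward pass above is inference only; train.py in this repository handles actual training. As a rough sketch of a single optimization step, assuming the model returns per-token logits of shape (batch, sequence_length, num_tokens) (verify this against the implementation before relying on it):

import torch.nn.functional as F

# Hypothetical next-token targets; in real training these come from a dataset
targets = torch.randint(0, 100, (1, 100))

# AdamW with an illustrative learning rate; not taken from this repository's train.py
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

# Flatten the (batch, seq, vocab) logits and (batch, seq) targets for cross-entropy
loss = F.cross_entropy(output.reshape(-1, output.size(-1)), targets.reshape(-1))

optimizer.zero_grad()
loss.backward()  # backpropagate through the Jamba model
optimizer.step()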

License

MIT
