Tunix supports the following models:
| Model | Sizes |
|---|---|
| Gemma | 2B, 7B, 9B |
| Gemma 2 | 2B, 9B |
| Gemma 3 | 270M, 1B, 4B, 12B, 27B |
| Llama 3 | 70B, 405B |
| Llama 3.1 | 8B, 70B, 405B |
| Llama 3.2 | 1B, 3B |
| Qwen 2.5 | 0.5B, 1.5B, 3B, 7B |
| Qwen 3 | 0.6B, 1.7B, 4B, 8B, 14B, 30B, 32B |
The model configurations and checkpoints should be accessible from Huggingface and Kaggle. For example, the following snippet shows how to load the Gemma 2 2B model from Huggingface:

```python
from huggingface_hub import snapshot_download

ignore_patterns = [
    "*.pth",  # Ignore PyTorch .pth weight files
]
MODEL_PATH = snapshot_download(repo_id="google/gemma-2-2b-it", ignore_patterns=ignore_patterns)
```

You can also store model checkpoints in GCS. If you have a GCS bucket and have uploaded the model checkpoints there, you can access them as well:

```python
MODEL_PATH = "gs://<your-bucket-dev>/your-model-checkpoints"
```

Once you have an accessible model path from one of the approaches above, you can load it through the Tunix model loading API as follows:
```python
config = model_lib.ModelConfig.gemma2_2b()
mesh = jax.make_mesh((1, 1), ("fsdp", "tp"), axis_types=(jax.sharding.AxisType.Auto,) * 2)
with mesh:
  gemma = params_lib.create_model_from_safe_tensors(MODEL_PATH, config, mesh)
```

Model optimization is critical for efficient model execution. This includes optimal sharding on TPUs, optimization with Pallas kernels, etc. Tunix provides a lightweight suite of models that is only optimized to some extent. Integrating Tunix with Maxtext enables users to run RL workloads with fully optimized models. Refer to the single-host and multi-host tutorials for how to run an optimized RL workload with Maxtext and Tunix.
You can add new models to the Tunix codebase by following the Tunix conventions.

If the new model falls into one of the existing model families (e.g. Gemma, Llama, etc.), you don't need to create new files: just add the model specs to the corresponding model family. Take a look at the Llama examples. If the new model is from a model family that Tunix doesn't support yet, you will need to follow the design and APIs of the existing model families to create the model implementation.

A new model must follow the Tunix naming convention so that AutoModel (described below) works correctly. We use the pattern `<model_family><major_version>p<minor_version>_<model_size>` to name a model. For example, the Llama 3.2 1B model is named `llama3p2_1b`, while the Qwen 2.5 1.5B model is named `qwen2p5_1p5b`.
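As an illustrative sketch (the helper below is hypothetical and not part of Tunix), the pattern can be applied mechanically by replacing dots with `p`:

```python
def tunix_model_name(family: str, version: str, size: str) -> str:
    """Hypothetical helper: build a Tunix-style model name.

    Follows the pattern <model_family><major_version>p<minor_version>_<model_size>,
    replacing dots in the version and size with "p".
    """
    return f"{family}{version.replace('.', 'p')}_{size.replace('.', 'p')}"

print(tunix_model_name("llama", "3.2", "1b"))   # llama3p2_1b
print(tunix_model_name("qwen", "2.5", "1.5b"))  # qwen2p5_1p5b
```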
AutoModel provides a unified interface for instantiating Tunix models from pretrained checkpoints, similar to the Huggingface AutoModel API. It allows you to load a model simply by providing its `model_id`, handling the download and initialization for you.

To load a model, use the `AutoModel.from_pretrained` method with the model identifier and your JAX sharding mesh. By default this will download the model from Huggingface.
```python
import jax
from tunix.models.automodel import AutoModel

# 1. Define your mesh
mesh = jax.make_mesh((1, 1), ("fsdp", "tp"), axis_types=(jax.sharding.AxisType.Auto,) * 2)

# 2. Load the model
# By default, this downloads from Huggingface.
model, model_path = AutoModel.from_pretrained(
    model_id="google/gemma-2-2b-it",  # Using the HF id as model_id
    mesh=mesh,
)
print(f"Model loaded from: {model_path}")
```

You can load models from different sources (e.g., Kaggle, GCS, etc.) using the `model_source` argument.
For Huggingface, this is the default choice (`ModelSource.HUGGINGFACE`), as shown in the example above.
For Kaggle, you must provide the `model_id` (either the Huggingface identifier or a `model_config_id`; see Naming Conventions) to determine the model configuration, and the `model_path`, which is the Kaggle Hub model identifier (used to download the model from Kaggle).
```python
model, model_path = AutoModel.from_pretrained(
    model_id="gemma2_2b_it",  # Using model_config_id as model_id
    mesh=mesh,
    model_source=ModelSource.KAGGLE,
    model_path="google/gemma-2/flax/gemma2-2b-it",
)
```

For example, the `model_path` for `google/gemma-2/flax/gemma2-2b-it` is the Kaggle Hub model identifier taken from its model page on Kaggle.
For GCS, you must provide the `model_id` (either the Huggingface identifier or a `model_config_id`; see Naming Conventions) to determine the model configuration, and the `model_path` (the actual GCS location).
```python
model, model_path = AutoModel.from_pretrained(
    model_id="gemma2_2b_it",  # Using model_config_id as model_id
    mesh=mesh,
    model_source=ModelSource.GCS,
    model_path="gs://my-bucket/gemma-2-2b-it",
)
```

Optionally, you can also provide the `model_download_path` argument, which specifies where the model should be downloaded to. The effect of this argument depends on the `model_source`:

- Huggingface: Files are downloaded directly to this directory.
- Kaggle: Sets the `KAGGLEHUB_CACHE` environment variable to this path.
- GCS: No-op.
- Internal: Files are copied to this directory. If omitted, the model is loaded directly from the `model_path`. This mode (Internal) is not supported in the OSS version.
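The per-source behavior of `model_download_path` can be summarized in a small sketch (illustrative only; this `ModelSource` enum and the function are hypothetical stand-ins, not the real Tunix implementation):

```python
import os
from enum import Enum

class ModelSource(Enum):
    # Illustrative stand-in for the sources described above.
    HUGGINGFACE = "huggingface"
    KAGGLE = "kaggle"
    GCS = "gcs"

def apply_download_path(source: ModelSource, download_path: str) -> str:
    """Describe (and, for Kaggle, apply) the effect of model_download_path."""
    if source is ModelSource.HUGGINGFACE:
        # Files are downloaded directly into this directory.
        return f"download into {download_path}"
    if source is ModelSource.KAGGLE:
        # Kaggle Hub reads the KAGGLEHUB_CACHE environment variable.
        os.environ["KAGGLEHUB_CACHE"] = download_path
        return f"KAGGLEHUB_CACHE set to {download_path}"
    # GCS: the argument is a no-op; the model is read from model_path.
    return "no-op"

print(apply_download_path(ModelSource.GCS, "/tmp/tunix-models"))  # no-op
```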
This section outlines the naming conventions used within Tunix for model identification and configuration. These conventions ensure consistency when loading models from various sources like Huggingface or Kaggle.
The `ModelNaming` dataclass handles the parsing and standardization of model names.
- `model_id`: A unique identifier used to identify the target model and to extract the family, version, and desired config from. Tunix supports two identifiers as the `model_id`:
  - Huggingface (HF) IDs: The full model name identifier (case sensitive), as it appears on Huggingface, including the parent directory. For example, `meta-llama/Llama-3.1-8B`.
  - Native Tunix model configs: The `model_config_id` representing the exact config from the model class can be used directly as the `model_id`. In this case it is also treated as the `model_name`. To find it, refer to the source code (`model.py`) for each model family and select the config id from the `ModelConfig` class, for example `llama3p1_8b` from the Llama model code.
- `model_name`: The unique full name identifier of the model. This corresponds to the full name and should match exactly the model name used in Huggingface or Kaggle. It is typically all lowercase and formatted as `<model-family>-<model-version>` (when an HF id is used as the `model_id`) or `<model-family>_<model-version>` (when a `model_config_id` is used as the `model_id`).
  - Example for HF as `model_id`: `gemma-2b`, `llama-3.1-8b`, `gemma-2-2b-it`.
  - Example for `model_config_id` as `model_id`: `gemma_2b`, `llama3p1_8b`, `gemma2_2b_it`.
- `model_family`: The standardized model family. Unnecessary hyphens are removed, and versions are standardized (e.g., replacing dots with `p`).
  - Example: `gemma`, `gemma2`, `qwen2p5`.
  - Conversion: `gemma-2` -> `gemma2`, `qwen2.5` -> `qwen2p5`.
- `model_version`: The standardized version of the model family (lowercase, hyphens to underscores, dots to `p`). This is usually the second portion of the `model_name` and includes size information or tuning variants (e.g., "it" for instruction tuned).
  - Example: `2b_it`.
  - Conversion: `2b-it` -> `2b_it`.
- `model_config_category`: The Python class name of the `ModelConfig` class. This groups models that share the same configuration structure.
  - Example: Both `gemma` and `gemma2` models fall under the `gemma` category, with the `ModelConfig` class defined in `models/gemma/model.py`.
- `model_config_id`: The standardized configuration ID used within the `ModelConfig` class. It is composed of the `model_family` and `model_version`.
  - Example: `gemma_2b_it` or `qwen2p5_0p5b`.
You can initialize `ModelNaming` by providing either the `model_id` or the `model_name`. If `model_id` is provided, the `model_name` is inferred as the last segment of the `model_id`. If `model_name` is provided, it is used directly. All other naming attributes are then automatically derived and validated.
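To make the derivation concrete, here is a hypothetical sketch of the parsing described above (the real `ModelNaming` dataclass in Tunix may differ in its details and edge-case handling):

```python
import re

def derive_naming(model_id: str) -> dict:
    """Illustrative sketch: derive naming attributes from a model_id."""
    # model_name is the last segment of the model_id, lowercased.
    model_name = model_id.split("/")[-1].lower()
    # Standardize dots to "p", then split on hyphens.
    parts = model_name.replace(".", "p").split("-")

    # Size-like tokens (e.g. "8b", "1p5b", "270m") start the version part.
    def is_size(tok: str) -> bool:
        return re.fullmatch(r"\d+(p\d+)?[bm]", tok) is not None

    split_at = next(i for i, tok in enumerate(parts) if is_size(tok))
    model_family = "".join(parts[:split_at])
    model_version = "_".join(parts[split_at:])
    return {
        "model_name": model_name,
        "model_family": model_family,
        "model_version": model_version,
        "model_config_id": f"{model_family}_{model_version}",
    }

print(derive_naming("meta-llama/Llama-3.1-8B")["model_config_id"])  # llama3p1_8b
```

For instance, `google/gemma-2-2b-it` yields `model_family` `gemma2`, `model_version` `2b_it`, and `model_config_id` `gemma2_2b_it`, matching the conventions above.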
