5 changes: 5 additions & 0 deletions mamba/causal_lm/jax/__init__.py
@@ -0,0 +1,5 @@
# SPDX-FileCopyrightText: (c) 2025 Tenstorrent AI ULC
#
# SPDX-License-Identifier: Apache-2.0

from .loader import ModelVariant, ModelLoader
245 changes: 245 additions & 0 deletions mamba/causal_lm/jax/loader.py
@@ -0,0 +1,245 @@
# SPDX-FileCopyrightText: (c) 2025 Tenstorrent AI ULC
#
# SPDX-License-Identifier: Apache-2.0

"""
Mamba model loader implementation for causal language modeling.
"""

from typing import Optional

from ....base import ForgeModel
from ....config import (
LLMModelConfig,
ModelInfo,
ModelGroup,
ModelTask,
ModelSource,
Framework,
StrEnum,
Parallelism,
)
from ....tools.jax_utils import cast_hf_model_to_type
import flax.nnx as nnx
from jax.sharding import PartitionSpec
import jax.numpy as jnp
import numpy as np


class ModelVariant(StrEnum):
"""Available MAMBA model variants."""

MAMBA_370M = "mamba-370m-hf"
MAMBA_790M = "mamba-790m-hf"
MAMBA_1_4B = "mamba-1.4b-hf"
MAMBA_2_8B = "mamba-2.8b-hf"


class ModelLoader(ForgeModel):
"""Mamba model loader implementation for causal language modeling."""

# Dictionary of available model variants
_VARIANTS = {
ModelVariant.MAMBA_370M: LLMModelConfig(
pretrained_model_name="state-spaces/mamba-370m-hf",
),
ModelVariant.MAMBA_790M: LLMModelConfig(
pretrained_model_name="state-spaces/mamba-790m-hf",
),
ModelVariant.MAMBA_1_4B: LLMModelConfig(
pretrained_model_name="state-spaces/mamba-1.4b-hf",
),
ModelVariant.MAMBA_2_8B: LLMModelConfig(
pretrained_model_name="state-spaces/mamba-2.8b-hf",
),
}

# Default variant to use
DEFAULT_VARIANT = ModelVariant.MAMBA_790M

sample_text = "Hello there fellow traveller"

def __init__(self, variant: Optional[ModelVariant] = None):
"""Initialize ModelLoader with specified variant.

Args:
variant: Optional ModelVariant specifying which variant to use.
If None, DEFAULT_VARIANT is used.
"""
super().__init__(variant)
self._tokenizer = None
self._model_name = self._variant_config.pretrained_model_name

@classmethod
def _get_model_info(cls, variant: Optional[ModelVariant] = None) -> ModelInfo:
"""Implementation method for getting model info with validated variant.

Args:
variant: Optional ModelVariant specifying which variant to use.
If None, DEFAULT_VARIANT is used.

Returns:
ModelInfo: Information about the model and variant
"""

return ModelInfo(
model="mamba",
variant=variant,
group=ModelGroup.GENERALITY,
task=ModelTask.NLP_CAUSAL_LM,
source=ModelSource.EASYDEL,
framework=Framework.JAX,
)

def _load_tokenizer(self, dtype_override=None):
"""Load tokenizer for the current variant.

Args:
dtype_override: Optional dtype to override the tokenizer's default dtype.

Returns:
tokenizer: The loaded tokenizer instance
"""

from transformers import AutoTokenizer

# Initialize tokenizer with dtype_override if provided
tokenizer_kwargs = {}
if dtype_override is not None:
tokenizer_kwargs["dtype"] = dtype_override

# Load the tokenizer
self._tokenizer = AutoTokenizer.from_pretrained(
self._model_name, **tokenizer_kwargs
)

return self._tokenizer

def load_model(self, dtype_override=None):
"""Load and return the Mamba model instance for this instance's variant.

Args:
dtype_override: Optional dtype to override the model's default dtype.
If not provided, the model will use its default dtype (typically float32).

Returns:
model: The loaded model instance
"""

from easydel import AutoEasyDeLModelForCausalLM
import jax

# Ensure tokenizer is loaded
if self._tokenizer is None:
self._load_tokenizer(dtype_override)

# Initialize model kwargs
model_kwargs = {}

# Determine the target dtype
if dtype_override is not None:
model_kwargs["dtype"] = dtype_override

partition_rules = ((r".*", PartitionSpec()),)

# Temporarily disable x64 mode during model loading to prevent a dtype mismatch in lax.clamp.
# The model is still loaded with the dtype passed via model_kwargs above.
original_x64_state = jax.config.jax_enable_x64
jax.config.update("jax_enable_x64", False)

try:
model = AutoEasyDeLModelForCausalLM.from_pretrained(
self._model_name, partition_rules=partition_rules, **model_kwargs
)
finally:
# Restore original x64 state
jax.config.update("jax_enable_x64", original_x64_state)

# Cast the model to the dtype_override if provided
if dtype_override is not None:
model = cast_hf_model_to_type(model, dtype_override)

return model

def load_inputs(self, dtype_override=None, mesh=None):
"""Load and return sample inputs for the Mamba model with this instance's variant settings.

Args:
dtype_override: Optional dtype to override the model's default dtype.
mesh: Optional device mesh for sharding (DataParallel mode).

Returns:
inputs: Input tensors that can be fed to the model.
"""

if mesh is not None:
# For multi-device, start from batch_size=8 (matching the original test)
# and round it up to a multiple of the device count below.
num_devices = np.prod(list(mesh.shape.values())) if mesh.shape else 1
batch_size = 8 # Base batch size, will be sharded across devices
# Ensure batch size is divisible by number of devices
if batch_size % num_devices != 0:
batch_size = num_devices * (batch_size // num_devices + 1)
else:
# Default to 8 for single device too, for consistency
batch_size = 8

# Ensure tokenizer is initialized
if self._tokenizer is None:
self._load_tokenizer(dtype_override)

# Create tokenized inputs
inputs = self._tokenizer(self.sample_text, return_tensors="jax")

input_ids = jnp.repeat(inputs.input_ids, batch_size, axis=0)
return {"input_ids": input_ids}

def get_input_activations_partition_spec(self, mesh, parallelism, axis_name="X"):
"""Get partition specification for input activations.

Args:
mesh: The device mesh for sharding.
parallelism: The level of parallelism for sharding.
axis_name: The name of the mesh axis to use for sharding.

Returns:
PartitionSpec for input activations (sharded on batch dimension)
"""
if (
parallelism.name == Parallelism.TENSOR_PARALLEL.name
or np.prod(list(mesh.shape.values())) == 1
):
return (PartitionSpec(),)

return (PartitionSpec(axis_name),)

def load_parameters_partition_spec(
self,
model_for_multichip,
parallelism,
axis_name="X",
cpu_mesh=None,
input_activations_partition_specs=None,
inputs=None,
dtype_override=None,
):
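"""Get partition specs for the model parameters.

Data-parallel and single-device runs keep the parameters fully replicated;
otherwise EasyDeL's Mamba partition rules are used.

Args:
model_for_multichip: Model whose parameters will be sharded.
parallelism: The level of parallelism for sharding.
axis_name: The name of the mesh axis to use for sharding.
cpu_mesh: Unused in this implementation.
input_activations_partition_specs: Unused in this implementation.
inputs: Unused in this implementation.
dtype_override: Unused in this implementation.

Returns:
Partition specs for the model parameters.
"""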
# Get the model state
state = nnx.split(model_for_multichip)[1]

if (
parallelism.name == Parallelism.DATA_PARALLEL.name
or parallelism.name == Parallelism.SINGLE_DEVICE.name
):
# In data parallel mode, use fully replicated partitioning
partition_rules = ((r".*", PartitionSpec()),)
else:
# Use EasyDel's MambaConfig to get proper partition rules
from easydel.modules.mamba import MambaConfig

mamba_config = MambaConfig()
partition_rules = mamba_config.get_partition_rules()

from infra.utilities import make_easydel_parameters_partition_specs

return make_easydel_parameters_partition_specs(
model_state=state, partition_rules=partition_rules, axis_name=axis_name
)
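A minimal usage sketch of the new loader. Assumptions: the enclosing models package is importable (the relative imports in loader.py imply a package prefix above mamba/, omitted below), and the EasyDeL causal-LM module is callable HF-style with input_ids and returns an output carrying .logits.

import jax.numpy as jnp

# Prepend your checkout's top-level package name to this import; the prefix is
# omitted here because it does not appear in the diff.
from mamba.causal_lm.jax import ModelLoader, ModelVariant

# DEFAULT_VARIANT (MAMBA_790M) is used when no variant is given.
loader = ModelLoader(variant=ModelVariant.MAMBA_370M)

# dtype_override is optional; when given, the loaded weights are cast to it.
model = loader.load_model(dtype_override=jnp.bfloat16)

# load_inputs() tokenizes the sample prompt and repeats it to a batch of 8.
inputs = loader.load_inputs()

# Assumed HF-style call convention for the EasyDeL module.
outputs = model(**inputs)
print(outputs.logits.shape)  # (8, sequence_length, vocab_size)

Note that load_model() only disables jax_enable_x64 for the duration of from_pretrained and then restores the caller's setting.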
3 changes: 3 additions & 0 deletions mamba/causal_lm/jax/requirements.txt
@@ -0,0 +1,3 @@
git+https://github.com/erfanzar/EasyDeL.git@77ced9d2f2ab6a3d705936d26112eb97d9f9e64a

eformer==0.0.62
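A companion sketch of the data-parallel path through the sharding helpers. The "X" axis name matches the loader's default; Parallelism lives in the repo's config module, which loader.py only reaches via a relative import, so the absolute import below is a placeholder.

import jax
import numpy as np
from jax.sharding import Mesh

# Placeholder import paths; adjust both to your checkout layout.
from mamba.causal_lm.jax import ModelLoader
from config import Parallelism

loader = ModelLoader()

# 1-D mesh over all local devices with the loader's default axis name "X".
mesh = Mesh(np.array(jax.devices()), axis_names=("X",))

# The sample prompt is repeated to a batch of 8, rounded up to a multiple of the device count.
inputs = loader.load_inputs(mesh=mesh)

# Data parallel: activations are sharded on the batch axis; on a single device the spec stays replicated.
act_specs = loader.get_input_activations_partition_spec(mesh, Parallelism.DATA_PARALLEL, axis_name="X")

# Parameters remain fully replicated in data-parallel mode (see load_parameters_partition_spec).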