206 changes: 206 additions & 0 deletions roberta/sequence_classification/jax/loader.py
@@ -0,0 +1,206 @@
# SPDX-FileCopyrightText: (c) 2025 Tenstorrent AI ULC
#
# SPDX-License-Identifier: Apache-2.0

"""
RoBERTa model loader implementation for sequence classification.
"""

from typing import Optional

from ....base import ForgeModel
from ....config import (
LLMModelConfig,
ModelInfo,
ModelGroup,
ModelTask,
ModelSource,
Framework,
StrEnum,
)
from ....tools.jax_utils import cast_hf_model_to_type
import flax.nnx as nnx
from jax.sharding import PartitionSpec
import jax.numpy as jnp
import numpy as np


class ModelVariant(StrEnum):
"""Available RoBERTa model variants."""

BASE = "base"
LARGE = "large"


class ModelLoader(ForgeModel):
"""RoBERTa model loader implementation for masked language modeling."""

_VARIANTS = {
ModelVariant.BASE: LLMModelConfig(
pretrained_model_name="FacebookAI/roberta-base",
),
ModelVariant.LARGE: LLMModelConfig(
pretrained_model_name="FacebookAI/roberta-large",
),
}

DEFAULT_VARIANT = ModelVariant.BASE

sample_text = "Hello, my dog is cute"

def __init__(self, variant: Optional[ModelVariant] = None):
super().__init__(variant)
self._tokenizer = None
self._model_name = self._variant_config.pretrained_model_name

@classmethod
def _get_model_info(cls, variant: Optional[ModelVariant] = None) -> ModelInfo:
"""Implementation method for getting model info with validated variant.

Args:
variant: Optional ModelVariant specifying which variant to use.
If None, DEFAULT_VARIANT is used.

Returns:
ModelInfo: Information about the model and variant
"""

return ModelInfo(
model="roberta",
variant=variant,
group=ModelGroup.GENERALITY,
task=ModelTask.NLP_TEXT_CLS,
source=ModelSource.EASYDEL,
framework=Framework.JAX,
)

def _load_tokenizer(self, dtype_override=None):
"""Load the tokenizer for the model.

Args:
dtype_override: Optional dtype to override the default dtype.

Returns:
Tokenizer: The tokenizer for the model
"""

from transformers import AutoTokenizer

# Initialize tokenizer with dtype_override if provided
tokenizer_kwargs = {}
if dtype_override is not None:
tokenizer_kwargs["dtype"] = dtype_override

self._tokenizer = AutoTokenizer.from_pretrained(
self._model_name, **tokenizer_kwargs
)

return self._tokenizer

def load_model(self, dtype_override=None):
"""Load and return the RoBERTa model instance for this instance's variant.

Args:
dtype_override: Optional dtype to override the default dtype.

Returns:
model: The loaded model instance
"""

from easydel import AutoEasyDeLModelForSequenceClassification

# Ensure tokenizer is loaded
if self._tokenizer is None:
self._load_tokenizer(dtype_override)

model_kwargs = {}
if dtype_override is not None:
model_kwargs["dtype"] = dtype_override

        # Replicate every parameter (no sharding) at load time
        partition_rules = ((r".*", PartitionSpec()),)

# Load model
model = AutoEasyDeLModelForSequenceClassification.from_pretrained(
self._model_name,
partition_rules=partition_rules,
**model_kwargs,
)

# Cast the model to the dtype_override if provided
if dtype_override is not None:
model = cast_hf_model_to_type(model, dtype_override)

return model

def load_inputs(self, dtype_override=None, mesh=None):
"""Load and return the inputs for the model.

Args:
dtype_override: Optional dtype to override the default dtype.
            mesh: Optional device mesh for sharding (DataParallel mode).

        Returns:
inputs: Input tensors that can be fed to the model.
"""

if mesh is not None:
# For multi-device, use a fixed batch size that's divisible by device count
# This matches the original test which used batch_size=8
num_devices = np.prod(list(mesh.shape.values())) if mesh.shape else 1
batch_size = 8 # Fixed batch size, will be sharded across devices
# Ensure batch size is divisible by number of devices
if batch_size % num_devices != 0:
batch_size = num_devices * (batch_size // num_devices + 1)
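                # e.g. with 3 devices, batch_size 8 is rounded up to 9; with 2, 4, or 8 it stays 8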
else:
# Default to 8 for single device too, for consistency
batch_size = 8

# Ensure tokenizer is initialized
if self._tokenizer is None:
self._load_tokenizer(dtype_override=dtype_override)

        # Create tokenized inputs for the sequence classification task
inputs = self._tokenizer(
self.sample_text,
return_tensors="jax",
)

inputs["input_ids"] = jnp.repeat(inputs["input_ids"], batch_size, axis=0)
inputs["attention_mask"] = jnp.repeat(
inputs["attention_mask"], batch_size, axis=0
)

return inputs

def get_input_activations_partition_spec(self, mesh, axis_name="X"):
"""Get partition specification for input activations.

Args:
mesh: The device mesh for sharding.
axis_name: Name of the sharding axis.

Returns:
PartitionSpec for input activations (sharded on batch dimension)
"""
        # A single-device mesh has nothing to shard; replicate instead.
        if np.prod(list(mesh.shape.values())) == 1:
            return PartitionSpec()

        # Shard the leading (batch) dimension across the given mesh axis.
        return PartitionSpec(axis_name)

    def load_parameters_partition_spec(
        self,
        model_for_multichip=None,
        cpu_mesh=None,
        input_activations_partition_specs=None,
        inputs=None,
        dtype_override=None,
    ):
        """Get partition specifications for the model parameters.

        Args:
            model_for_multichip: Model instance whose parameters are to be sharded.
            cpu_mesh: Optional CPU-side device mesh (unused here).
            input_activations_partition_specs: Optional input specs (unused here).
            inputs: Optional sample inputs (unused here).
            dtype_override: Optional dtype override (unused here).

        Returns:
            Partition specs that replicate every parameter across devices.
        """
        # Get the model state (parameters) from the nnx module
        state = nnx.split(model_for_multichip)[1]

        partition_rules = ((r".*", PartitionSpec()),)  # Everything replicated

from infra.utilities import make_easydel_parameters_partition_specs

return make_easydel_parameters_partition_specs(
model_state=state, partition_rules=partition_rules
)
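
For context, a minimal single-device usage sketch of the loader above (it assumes this module is importable as `loader`, and treating `outputs.logits` as following the Hugging Face output convention is an assumption, not something this diff guarantees):

    import jax.numpy as jnp
    from loader import ModelLoader, ModelVariant

    ldr = ModelLoader(ModelVariant.BASE)
    model = ldr.load_model(dtype_override=jnp.bfloat16)
    inputs = ldr.load_inputs()  # batch of 8 copies of the sample sentence
    outputs = model(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"])
    # outputs.logits is assumed to have shape (batch_size, num_labels)
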
3 changes: 3 additions & 0 deletions roberta/sequence_classification/jax/requirements.txt
@@ -0,0 +1,3 @@
git+https://github.com/erfanzar/EasyDeL.git@77ced9d2f2ab6a3d705936d26112eb97d9f9e64a

eformer==0.0.62
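
For the multi-device path, a sketch of how the partition-spec helpers above might be wired together; the `Mesh`/`NamedSharding` usage is standard JAX, and the only assumptions are that `loader` is importable and at least one device is visible:

    import jax
    import numpy as np
    from jax.sharding import Mesh, NamedSharding
    from loader import ModelLoader, ModelVariant

    # Build a 1-D mesh over all visible devices, named "X" to match the default axis_name.
    mesh = Mesh(np.array(jax.devices()), axis_names=("X",))

    ldr = ModelLoader(ModelVariant.BASE)
    inputs = ldr.load_inputs(mesh=mesh)  # batch size padded to a multiple of the device count

    # Shard the batch dimension across "X" (or replicate on a single-device mesh).
    spec = ldr.get_input_activations_partition_spec(mesh, axis_name="X")
    sharding = NamedSharding(mesh, spec)
    input_ids = jax.device_put(inputs["input_ids"], sharding)
    attention_mask = jax.device_put(inputs["attention_mask"], sharding)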