Commit 3684d4a

Reformatted pixtral/loader.py and fixed the load_inputs function
1 parent e195c7f commit 3684d4a

File tree

1 file changed (+68, -35 lines)


mistral/pixtral/pytorch/loader.py

Lines changed: 68 additions & 35 deletions
@@ -4,38 +4,55 @@
 """
 Mistral Pixtral model loader implementation
 """
-
-
 import torch
-from transformers import LlavaForConditionalGeneration  # , AutoProcessor
+from transformers import LlavaForConditionalGeneration, AutoProcessor
 from ....config import (
     ModelInfo,
     ModelGroup,
     ModelTask,
     ModelSource,
     Framework,
+    StrEnum,
+    ModelConfig,
 )
+from typing import Optional
 from ....base import ForgeModel
+from ....tools.utils import cast_input_to_type
+
+
+class ModelVariant(StrEnum):
+    """Available Pixtral model variants."""
+
+    PIXTRAL_12B = "pixtral-12b"
 
 
 class ModelLoader(ForgeModel):
     """Pixtral model loader implementation."""
 
-    def __init__(self, variant=None):
+    # Dictionary of available model variants
+    _VARIANTS = {
+        ModelVariant.PIXTRAL_12B: ModelConfig(
+            pretrained_model_name="mistral-community/pixtral-12b",
+        ),
+    }
+
+    # Default variant to use
+    DEFAULT_VARIANT = ModelVariant.PIXTRAL_12B
+
+    def __init__(self, variant: Optional[ModelVariant] = None):
         """Initialize ModelLoader with specified variant.
 
         Args:
             variant: Optional string specifying which variant to use.
                      If None, DEFAULT_VARIANT is used.
         """
         super().__init__(variant)
-
-        # Configuration parameters
-        self.model_name = "mistral-community/pixtral-12b"
-        # self.processor = None
+        self.processor = None
+        self.model = None
+        self.config = None
 
     @classmethod
-    def _get_model_info(cls, variant_name: str = None):
+    def _get_model_info(cls, variant: Optional[ModelVariant] = None) -> ModelInfo:
         """Get model information for dashboard and metrics reporting.
 
         Args:
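As an aside, the variant registry introduced in this hunk is essentially a StrEnum-keyed lookup table that __init__ resolves into self._variant_config. Below is a minimal, self-contained sketch of that pattern; the StrEnum/ModelConfig stand-ins and the resolve_variant helper are illustrative assumptions, not the repo's actual config or ForgeModel classes.

# Standalone sketch of the variant-registry pattern above. ModelConfig and the
# resolution step are stand-ins (assumptions) for the repo's config/ForgeModel code.
from dataclasses import dataclass
from enum import Enum
from typing import Optional


class ModelVariant(str, Enum):  # the repo uses its own StrEnum; str + Enum is a stand-in
    PIXTRAL_12B = "pixtral-12b"


@dataclass
class ModelConfig:  # stand-in for the repo's ModelConfig
    pretrained_model_name: str


_VARIANTS = {
    ModelVariant.PIXTRAL_12B: ModelConfig(
        pretrained_model_name="mistral-community/pixtral-12b"
    ),
}
DEFAULT_VARIANT = ModelVariant.PIXTRAL_12B


def resolve_variant(variant: Optional[ModelVariant] = None) -> ModelConfig:
    # Assumed to mirror what super().__init__(variant) does before exposing
    # self._variant_config to _load_processor and load_model.
    return _VARIANTS[variant or DEFAULT_VARIANT]


print(resolve_variant().pretrained_model_name)  # -> mistral-community/pixtral-12b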
@@ -44,17 +61,25 @@ def _get_model_info(cls, variant_name: str = None):
         Returns:
             ModelInfo: Information about the model and variant
         """
-        if variant_name is None:
-            variant_name = "base"
         return ModelInfo(
-            model="mistral-community/pixtral-12b",
-            variant=variant_name,
+            model="mistral-community",
+            variant=variant,
             group=ModelGroup.RED,
             task=ModelTask.MM_VISUAL_QA,
             source=ModelSource.HUGGING_FACE,
             framework=Framework.TORCH,
         )
 
+    def _load_processor(self, dtype_override=None):
+        """Load processor for the current variant."""
+        kwargs = {}
+        if dtype_override is not None:
+            kwargs["torch_dtype"] = dtype_override
+        self.processor = AutoProcessor.from_pretrained(
+            self._variant_config.pretrained_model_name, **kwargs
+        )
+        return self.processor
+
     def load_model(self, dtype_override=None):
         """Load and return the Mistral Pixtral model instance with default settings.
 
@@ -66,21 +91,22 @@ def load_model(self, dtype_override=None):
             torch.nn.Module: The Mistral Pixtral model instance.
 
         """
-        # self.processor = AutoProcessor.from_pretrained(self.model_name)
+        pretrained_model_name = self._variant_config.pretrained_model_name
+        if self.processor is None:
+            self._load_processor(dtype_override)
 
-        # Load pre-trained model from HuggingFace
         model_kwargs = {}
         if dtype_override is not None:
             model_kwargs["torch_dtype"] = dtype_override
 
         model = LlavaForConditionalGeneration.from_pretrained(
-            self.model_name, **model_kwargs
+            pretrained_model_name, **model_kwargs
         )
         self.model = model
         self.config = model.config
         return model
 
-    def load_inputs(self, batch_size=1):
+    def load_inputs(self, batch_size=1, dtype_override=None):
         """Load and return sample inputs for the Mistral Pixtral model with default settings.
 
         Args:
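Stripped of the ForgeModel scaffolding, the new _load_processor/load_model pair boils down to the sketch below. The checkpoint name and the torch_dtype pass-through come straight from the diff; torch.bfloat16 is only an example override, and running this downloads the full 12B checkpoint.

# Standalone sketch of what _load_processor + load_model do with a dtype override.
import torch
from transformers import AutoProcessor, LlavaForConditionalGeneration

name = "mistral-community/pixtral-12b"
dtype_override = torch.bfloat16  # example value for the dtype_override argument

# Mirrors _load_processor: the override is forwarded to the processor kwargs
processor = AutoProcessor.from_pretrained(name, torch_dtype=dtype_override)
# Mirrors load_model: the same override is applied to the model weights
model = LlavaForConditionalGeneration.from_pretrained(name, torch_dtype=dtype_override)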
@@ -89,26 +115,33 @@ def load_inputs(self, batch_size=1):
         Returns:
             dict: Input tensors that can be fed to the model.
         """
+        if self.processor is None:
+            self._load_processor(dtype_override)
+        url_dog = "https://picsum.photos/id/237/200/300"
+        url_mountain = "https://picsum.photos/seed/picsum/200/300"
+        chat = [
+            {
+                "role": "user",
+                "content": [
+                    {"type": "text", "content": "Can this animal"},
+                    {"type": "image"},
+                    {"type": "text", "content": "live here?"},
+                    {"type": "image"},
+                ],
+            }
+        ]
+        prompt = self.processor.apply_chat_template(chat)
+        inputs = self.processor(
+            text=prompt, images=[url_dog, url_mountain], return_tensors="pt"
+        )
 
-        # https://github.com/tenstorrent/tt-torch/issues/904
-        inputs = {
-            "input_ids": torch.tensor(
-                [[1, 3, 12483, 1593, 11386, 10, 51883, 3226, 1063, 10, 4]],
-                dtype=torch.long,
-            ),
-            "attention_mask": torch.tensor(
-                [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.long
-            ),
-        }
-
-        # Use repeat_interleave to expand batch dimension
-        inputs = {
-            "input_ids": inputs["input_ids"].repeat_interleave(batch_size, dim=0),
-            "attention_mask": inputs["attention_mask"].repeat_interleave(
-                batch_size, dim=0
-            ),
-        }
+        for key in inputs:
+            if torch.is_tensor(inputs[key]):
+                inputs[key] = inputs[key].repeat_interleave(batch_size, dim=0)
 
+        if dtype_override is not None:
+            for key in inputs:
+                inputs[key] = cast_input_to_type(inputs[key], dtype_override)
         return inputs
 
     def get_mesh_config(self, num_devices: int):
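Putting the pieces together, a typical call sequence against the reworked loader would look roughly like this. The module path is inferred from the file location above, and the surrounding repo must be importable, so treat it as a usage sketch rather than a verified snippet.

# Hypothetical end-to-end use of the reworked loader.
import torch
from mistral.pixtral.pytorch.loader import ModelLoader, ModelVariant  # assumed import path

loader = ModelLoader(ModelVariant.PIXTRAL_12B)  # or ModelLoader() for DEFAULT_VARIANT
model = loader.load_model(dtype_override=torch.bfloat16)
inputs = loader.load_inputs(batch_size=2, dtype_override=torch.bfloat16)

with torch.no_grad():
    outputs = model(**inputs)  # LlavaForConditionalGeneration forward pass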
