44"""
55Mistral Pixtral model loader implementation
66"""
7-
8-
97import torch
10- from transformers import LlavaForConditionalGeneration # , AutoProcessor
8+ from transformers import LlavaForConditionalGeneration , AutoProcessor
119from ....config import (
1210 ModelInfo ,
1311 ModelGroup ,
1412 ModelTask ,
1513 ModelSource ,
1614 Framework ,
15+ StrEnum ,
16+ ModelConfig ,
1717)
18+ from typing import Optional
1819from ....base import ForgeModel
20+ from ....tools .utils import cast_input_to_type
21+
22+
23+ class ModelVariant (StrEnum ):
24+ """Available Pixtral model variants."""
25+
26+ PIXTRAL_12B = "pixtral-12b"
1927
2028
2129class ModelLoader (ForgeModel ):
2230 """Pixtral model loader implementation."""
2331
24- def __init__ (self , variant = None ):
32+ # Dictionary of available model variants
33+ _VARIANTS = {
34+ ModelVariant .PIXTRAL_12B : ModelConfig (
35+ pretrained_model_name = "mistral-community/pixtral-12b" ,
36+ ),
37+ }
38+
39+ # Default variant to use
40+ DEFAULT_VARIANT = ModelVariant .PIXTRAL_12B
41+
42+ def __init__ (self , variant : Optional [ModelVariant ] = None ):
2543 """Initialize ModelLoader with specified variant.
2644
2745 Args:
2846 variant: Optional string specifying which variant to use.
2947 If None, DEFAULT_VARIANT is used.
3048 """
3149 super ().__init__ (variant )
32-
33- # Configuration parameters
34- self .model_name = "mistral-community/pixtral-12b"
35- # self.processor = None
50+ self .processor = None
51+ self .model = None
52+ self .config = None
3653
3754 @classmethod
38- def _get_model_info (cls , variant_name : str = None ):
55+ def _get_model_info (cls , variant : Optional [ ModelVariant ] = None ) -> ModelInfo :
3956 """Get model information for dashboard and metrics reporting.
4057
4158 Args:
@@ -44,17 +61,25 @@ def _get_model_info(cls, variant_name: str = None):
         Returns:
             ModelInfo: Information about the model and variant
         """
-        if variant_name is None:
-            variant_name = "base"
         return ModelInfo(
-            model="mistral-community/pixtral-12b",
-            variant=variant_name,
+            model="mistral-community",
+            variant=variant,
             group=ModelGroup.RED,
             task=ModelTask.MM_VISUAL_QA,
             source=ModelSource.HUGGING_FACE,
             framework=Framework.TORCH,
         )
 
+    def _load_processor(self, dtype_override=None):
+        """Load processor for the current variant."""
+        kwargs = {}
+        if dtype_override is not None:
+            kwargs["torch_dtype"] = dtype_override
+        self.processor = AutoProcessor.from_pretrained(
+            self._variant_config.pretrained_model_name, **kwargs
+        )
+        return self.processor
+
     def load_model(self, dtype_override=None):
         """Load and return the Mistral Pixtral model instance with default settings.
 
@@ -66,21 +91,22 @@ def load_model(self, dtype_override=None):
             torch.nn.Module: The Mistral Pixtral model instance.
 
         """
-        # self.processor = AutoProcessor.from_pretrained(self.model_name)
+        pretrained_model_name = self._variant_config.pretrained_model_name
+        if self.processor is None:
+            self._load_processor(dtype_override)
 
-        # Load pre-trained model from HuggingFace
         model_kwargs = {}
         if dtype_override is not None:
             model_kwargs["torch_dtype"] = dtype_override
 
         model = LlavaForConditionalGeneration.from_pretrained(
-            self.model_name, **model_kwargs
+            pretrained_model_name, **model_kwargs
         )
         self.model = model
         self.config = model.config
         return model
 
-    def load_inputs(self, batch_size=1):
+    def load_inputs(self, batch_size=1, dtype_override=None):
         """Load and return sample inputs for the Mistral Pixtral model with default settings.
 
         Args:
@@ -89,26 +115,33 @@ def load_inputs(self, batch_size=1):
         Returns:
             dict: Input tensors that can be fed to the model.
         """
+        if self.processor is None:
+            self._load_processor(dtype_override)
+        url_dog = "https://picsum.photos/id/237/200/300"
+        url_mountain = "https://picsum.photos/seed/picsum/200/300"
+        chat = [
+            {
+                "role": "user",
+                "content": [
+                    {"type": "text", "content": "Can this animal"},
+                    {"type": "image"},
+                    {"type": "text", "content": "live here?"},
+                    {"type": "image"},
+                ],
+            }
+        ]
+        prompt = self.processor.apply_chat_template(chat)
+        inputs = self.processor(
+            text=prompt, images=[url_dog, url_mountain], return_tensors="pt"
+        )
 
-        # https://github.com/tenstorrent/tt-torch/issues/904
-        inputs = {
-            "input_ids": torch.tensor(
-                [[1, 3, 12483, 1593, 11386, 10, 51883, 3226, 1063, 10, 4]],
-                dtype=torch.long,
-            ),
-            "attention_mask": torch.tensor(
-                [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.long
-            ),
-        }
-
-        # Use repeat_interleave to expand batch dimension
-        inputs = {
-            "input_ids": inputs["input_ids"].repeat_interleave(batch_size, dim=0),
-            "attention_mask": inputs["attention_mask"].repeat_interleave(
-                batch_size, dim=0
-            ),
-        }
+        for key in inputs:
+            if torch.is_tensor(inputs[key]):
+                inputs[key] = inputs[key].repeat_interleave(batch_size, dim=0)
 
+        if dtype_override is not None:
+            for key in inputs:
+                inputs[key] = cast_input_to_type(inputs[key], dtype_override)
         return inputs
 
     def get_mesh_config(self, num_devices: int):
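
For context, a minimal usage sketch of the loader after this change. The import path is hypothetical (the diff does not show where this module lives in the package), and bfloat16 is just an example dtype override:

    # Hypothetical import path -- adjust to the actual package layout.
    import torch
    from pixtral.pytorch.loader import ModelLoader, ModelVariant

    loader = ModelLoader(variant=ModelVariant.PIXTRAL_12B)
    model = loader.load_model(dtype_override=torch.bfloat16)
    inputs = loader.load_inputs(batch_size=2, dtype_override=torch.bfloat16)
    with torch.no_grad():
        outputs = model(**inputs)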
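
The helper cast_input_to_type is imported from ....tools.utils but its implementation is not part of this diff. A plausible sketch of such a helper, purely as an assumption, would cast only floating-point tensors and leave integral ones untouched:

    # Hypothetical sketch -- the real tools/utils.py is not in this commit.
    import torch

    def cast_input_to_type(tensor, dtype):
        # Integral tensors (input_ids, attention_mask) must stay integral;
        # only floating-point tensors such as pixel_values are safe to cast.
        if torch.is_tensor(tensor) and tensor.is_floating_point():
            return tensor.to(dtype)
        return tensor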