33# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
44
55import re
6- from typing import Optional
6+ from typing import Optional , Any , Dict
77
88from pydantic import BaseModel , Field
99
1717 QUANT_MAPPING ,
1818 QUANT_METHODS ,
1919 RUNTIME_WEIGHTS ,
20+ EXCLUDED_MODELS
2021)
2122from ads .common .utils import parse_bool
2223
2324
2425class GeneralConfig (BaseModel ):
2526 num_hidden_layers : int = Field (
2627 ...,
27- description = "Number of transformer blocks (layers) in the model’ s neural network stack." ,
28+ description = "Number of transformer blocks (layers) in the model' s neural network stack." ,
2829 )
2930 hidden_size : int = Field (
3031 ..., description = "Embedding dimension or hidden size of each layer."
@@ -46,6 +47,27 @@ class GeneralConfig(BaseModel):
4647 description = "Parameter data type: 'float32', 'float16', etc." ,
4748 )
4849
50+ @staticmethod
51+ def _get_required_int (raw : dict [str , Any ], keys : list [str ], field_name : str ) -> int :
52+ """
53+ Helper to safely extract a required integer field from multiple possible keys.
54+ Raises AquaRecommendationError if the value is missing or None.
55+ """
56+ for key in keys :
57+ val = raw .get (key )
58+ if val is not None :
59+ try :
60+ return int (val )
61+ except (ValueError , TypeError ):
62+ pass # If value exists but isn't a number, keep looking or fail later
63+
64+ # If we reach here, no valid key was found
65+ raise AquaRecommendationError (
66+ f"Could not determine '{ field_name } ' from the model configuration. "
67+ f"Checked keys: { keys } . "
68+ "This indicates the model architecture might not be supported or uses a non-standard config structure."
69+ )
70+
4971 @classmethod
5072 def get_weight_dtype (cls , raw : dict ) -> str :
5173 # some configs use a different weight dtype at runtime
@@ -173,21 +195,26 @@ class VisionConfig(GeneralConfig):
173195 @classmethod
174196 def from_raw_config (cls , vision_section : dict ) -> "VisionConfig" :
175197 weight_dtype = cls .get_weight_dtype (vision_section )
176- num_layers = (
177- vision_section . get ( "num_layers" )
178- or vision_section . get ( "vision_layers" )
179- or vision_section . get ( " num_hidden_layers")
180- or vision_section . get ( "n_layer" )
198+
199+ num_layers = cls . _get_required_int (
200+ vision_section ,
201+ [ "num_layers" , "vision_layers" , " num_hidden_layers", "n_layer" ],
202+ "num_hidden_layers"
181203 )
182204
183- hidden_size = vision_section .get ("hidden_size" ) or vision_section .get (
184- "embed_dim"
205+ hidden_size = cls ._get_required_int (
206+ vision_section ,
207+ ["hidden_size" , "embed_dim" ],
208+ "hidden_size"
185209 )
186210
187- mlp_dim = vision_section .get ("mlp_dim" ) or vision_section .get (
188- "intermediate_size"
211+ mlp_dim = cls ._get_required_int (
212+ vision_section ,
213+ ["mlp_dim" , "intermediate_size" ],
214+ "mlp_dim"
189215 )
190216
217+ # Optional fields can use standard .get()
191218 num_attention_heads = (
192219 vision_section .get ("num_attention_heads" )
193220 or vision_section .get ("vision_num_attention_heads" )
@@ -202,10 +229,10 @@ def from_raw_config(cls, vision_section: dict) -> "VisionConfig":
202229 weight_dtype = str (cls .get_weight_dtype (vision_section ))
203230
204231 return cls (
205- num_hidden_layers = int ( num_layers ) ,
206- hidden_size = int ( hidden_size ) ,
207- mlp_dim = int ( mlp_dim ) ,
208- patch_size = int (patch_size ),
232+ num_hidden_layers = num_layers ,
233+ hidden_size = hidden_size ,
234+ mlp_dim = mlp_dim ,
235+ patch_size = int (patch_size ) if patch_size else 0 ,
209236 num_attention_heads = int (num_attention_heads )
210237 if num_attention_heads
211238 else None ,
@@ -311,18 +338,28 @@ def optimal_config(self):
311338 return configs
312339
313340 @classmethod
314- def validate_model_support (cls , raw : dict ) -> ValueError :
341+ def validate_model_support (cls , raw : dict ):
315342 """
316343 Validates if model is decoder-only. Check for text-generation model occurs at DataScienceModel level.
344+ Also explicitly checks for unsupported audio/speech models.
317345 """
318- excluded_models = {"t5" , "gemma" , "bart" , "bert" , "roberta" , "albert" }
346+ # Known unsupported model architectures or types
347+ excluded_models = EXCLUDED_MODELS
348+
349+ model_type = raw .get ("model_type" , "" ).lower ()
350+
351+ if model_type in excluded_models :
352+ raise AquaRecommendationError (
353+ f"The model type '{ model_type } ' is not supported. "
354+ "Please provide a decoder-only text-generation model (ex. Llama, Falcon, etc). "
355+ "Encoder-decoder models (ex. T5, Gemma), encoder-only (BERT), and audio models (Whisper) are not supported at this time."
356+ )
357+
319358 if (
320359 raw .get ("is_encoder_decoder" , False ) # exclude encoder-decoder models
321360 or (
322361 raw .get ("is_decoder" ) is False
323362 ) # exclude explicit encoder-only models (altho no text-generation task ones, just dbl check)
324- or raw .get ("model_type" , "" ).lower () # exclude by known model types
325- in excluded_models
326363 ):
327364 raise AquaRecommendationError (
328365 "Please provide a decoder-only text-generation model (ex. Llama, Falcon, etc). "
@@ -337,14 +374,33 @@ def from_raw_config(cls, raw: dict) -> "LLMConfig":
337374 """
338375 cls .validate_model_support (raw )
339376
340- # Field mappings with fallback
341- num_hidden_layers = (
342- raw .get ("num_hidden_layers" ) or raw .get ("n_layer" ) or raw .get ("num_layers" )
377+ # Field mappings with fallback using safe extraction
378+ num_hidden_layers = cls ._get_required_int (
379+ raw ,
380+ ["num_hidden_layers" , "n_layer" , "num_layers" ],
381+ "num_hidden_layers"
343382 )
344- weight_dtype = cls .get_weight_dtype (raw )
345383
346- hidden_size = raw .get ("hidden_size" ) or raw .get ("n_embd" ) or raw .get ("d_model" )
347- vocab_size = raw .get ("vocab_size" )
384+ hidden_size = cls ._get_required_int (
385+ raw ,
386+ ["hidden_size" , "n_embd" , "d_model" ],
387+ "hidden_size"
388+ )
389+
390+ num_attention_heads = cls ._get_required_int (
391+ raw ,
392+ ["num_attention_heads" , "n_head" , "num_heads" ],
393+ "num_attention_heads"
394+ )
395+
396+ # Vocab size might be missing in some architectures, but usually required for memory calc
397+ vocab_size = cls ._get_required_int (
398+ raw ,
399+ ["vocab_size" ],
400+ "vocab_size"
401+ )
402+
403+ weight_dtype = cls .get_weight_dtype (raw )
348404 quantization = cls .detect_quantization_bits (raw )
349405 quantization_type = cls .detect_quantization_type (raw )
350406
@@ -355,15 +411,18 @@ def from_raw_config(cls, raw: dict) -> "LLMConfig":
355411 raw .get ("num_key_value_heads" ) # GQA models (ex. Llama-type)
356412 )
357413
358- num_attention_heads = (
359- raw .get ("num_attention_heads" ) or raw .get ("n_head" ) or raw .get ("num_heads" )
360- )
361-
362414 head_dim = raw .get ("head_dim" ) or (
363415 int (hidden_size ) // int (num_attention_heads )
364416 if hidden_size and num_attention_heads
365417 else None
366418 )
419+
420+ # Ensure head_dim is not None if calculation failed
421+ if head_dim is None :
422+ raise AquaRecommendationError (
423+ "Could not determine 'head_dim' and it could not be calculated from 'hidden_size' and 'num_attention_heads'."
424+ )
425+
367426 max_seq_len = (
368427 raw .get ("max_position_embeddings" )
369428 or raw .get ("n_positions" )
@@ -388,12 +447,12 @@ def from_raw_config(cls, raw: dict) -> "LLMConfig":
388447 ) # trust-remote-code is always needed when this key is present
389448
390449 return cls (
391- num_hidden_layers = int ( num_hidden_layers ) ,
392- hidden_size = int ( hidden_size ) ,
393- num_attention_heads = int ( num_attention_heads ) ,
450+ num_hidden_layers = num_hidden_layers ,
451+ hidden_size = hidden_size ,
452+ num_attention_heads = num_attention_heads ,
394453 num_key_value_heads = num_key_value_heads ,
395454 head_dim = int (head_dim ),
396- vocab_size = int ( vocab_size ) ,
455+ vocab_size = vocab_size ,
397456 weight_dtype = weight_dtype ,
398457 quantization = quantization ,
399458 quantization_type = quantization_type ,
@@ -511,4 +570,4 @@ def get_model_config(cls, raw: dict):
511570 # Neither found -- explicit failure
512571 raise AquaRecommendationError (
513572 "Config could not be parsed as either text, vision, or multimodal model. Check your fields/structure."
514- )
573+ )
0 commit comments