@@ -132,11 +132,11 @@ def __init__(
132132 )
133133
134134 if layer_precision is not None :
135- assert len (layer_precision ) == self .num_hidden_layers , (
136- f"layer_precision must be a list of length { self .num_hidden_layers } "
137- )
135+ if len (layer_precision ) != self .num_hidden_layers :
136+ raise ValueError (f"layer_precision must be a list of length { self .num_hidden_layers } " )
138137 for precision in layer_precision :
139- assert precision in {"fp8" , "fp4" , None }, 'layer_precision element must be "fp8", "fp4", or None'
138+ if precision not in {"fp8" , "fp4" , None }:
139+ raise ValueError (f'layer_precision element must be "fp8", "fp4", or None, got { precision !r} ' )
140140
141141
142142class NVEsmEncoder (nn .Module ):
@@ -160,12 +160,20 @@ def __init__(
160160 self ._fp8_recipe : transformer_engine .common .recipe .Recipe | None = fp8_recipe
161161 self ._fp4_recipe : transformer_engine .common .recipe .Recipe | None = fp4_recipe
162162
163- if fp8_recipe is not None and self .config .layer_precision is None :
164- if fp4_recipe is not None :
163+ if self .config .layer_precision is None :
164+ if fp8_recipe is not None and fp4_recipe is not None :
165165 raise RuntimeError ("Both FP8 and FP4 recipes provided, but no layer precision provided." )
166-
167- warnings .warn ("No layer precision provided, using FP8 recipe for all layers." , UserWarning )
168- self .config .layer_precision = ["fp8" ] * self .config .num_hidden_layers
166+ if fp8_recipe is not None :
167+ warnings .warn ("No layer precision provided, using FP8 recipe for all layers." , UserWarning )
168+ self .config .layer_precision = ["fp8" ] * self .config .num_hidden_layers
169+ elif fp4_recipe is not None :
170+ raise RuntimeError (
171+ "FP4 recipe provided but no layer_precision configured. "
172+ "Set layer_precision explicitly when using FP4."
173+ )
174+
175+ if self .config .layer_precision is not None and "fp4" in self .config .layer_precision and fp4_recipe is None :
176+ raise RuntimeError ("layer_precision contains 'fp4' entries but no fp4_recipe was provided." )
169177
170178 def _init_method (x ):
171179 torch .nn .init .normal_ (x , mean = 0.0 , std = config .initializer_range )
@@ -234,7 +242,7 @@ def forward(
234242 with torch .autocast (device_type = "cuda" , enabled = False ):
235243 te_rope_emb = self .rotary_embeddings (max_seq_len = self .config .max_position_embeddings )
236244 te_rope_emb = te_rope_emb .to (hidden_states .device , non_blocking = True )
237- if te_rope_emb .dtype = = torch .float32 :
245+ if te_rope_emb .dtype ! = torch .float32 :
238246 warnings .warn ("Rotary embeddings should be in float32 for optimal performance." , UserWarning )
239247
240248 with self .get_autocast_context (None , outer = True ):
@@ -295,6 +303,8 @@ def get_autocast_context(
295303 recipe = {"fp8" : self ._fp8_recipe , "fp4" : self ._fp4_recipe }.get (precision )
296304
297305 if init and self .config .use_quantized_model_init :
306+ if precision == "fp4" and recipe is None :
307+ raise RuntimeError ("No FP4 recipe provided, but layer precision is set to FP4." )
298308 if precision in ("fp8" , "fp4" ):
299309 return transformer_engine .pytorch .quantized_model_init (recipe = recipe )
300310 return nullcontext ()