Skip to content

Commit 4986482

Browse files
committed
addressing coderabbit review
Signed-off-by: Peter St. John <pstjohn@nvidia.com>
1 parent 8cbb55f commit 4986482

File tree

14 files changed

+187
-92
lines changed

14 files changed

+187
-92
lines changed

bionemo-recipes/models/esm2/README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ For FP4 (NVFP4) quantization, pass an `fp4_recipe` instead and set the correspon
127127
`"fp4"` in `layer_precision`:
128128

129129
```python
130-
fp4_recipe = te_recipe.NVFP4()
130+
fp4_recipe = te_recipe.NVFP4BlockScaling()
131131

132132
config = NVEsmConfig.from_pretrained(
133133
"nvidia/esm2_t6_8M_UR50D",
@@ -151,7 +151,7 @@ config = NVEsmConfig.from_pretrained(
151151
layer_precision=["fp4"] * 6,
152152
use_quantized_model_init=True,
153153
)
154-
model = NVEsmForMaskedLM(config, fp4_recipe=te_recipe.NVFP4())
154+
model = NVEsmForMaskedLM(config, fp4_recipe=te_recipe.NVFP4BlockScaling())
155155
```
156156

157157
### Notes

bionemo-recipes/models/esm2/modeling_esm_te.py

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -132,11 +132,11 @@ def __init__(
132132
)
133133

134134
if layer_precision is not None:
135-
assert len(layer_precision) == self.num_hidden_layers, (
136-
f"layer_precision must be a list of length {self.num_hidden_layers}"
137-
)
135+
if len(layer_precision) != self.num_hidden_layers:
136+
raise ValueError(f"layer_precision must be a list of length {self.num_hidden_layers}")
138137
for precision in layer_precision:
139-
assert precision in {"fp8", "fp4", None}, 'layer_precision element must be "fp8", "fp4", or None'
138+
if precision not in {"fp8", "fp4", None}:
139+
raise ValueError(f'layer_precision element must be "fp8", "fp4", or None, got {precision!r}')
140140

141141

142142
class NVEsmEncoder(nn.Module):
@@ -160,12 +160,20 @@ def __init__(
160160
self._fp8_recipe: transformer_engine.common.recipe.Recipe | None = fp8_recipe
161161
self._fp4_recipe: transformer_engine.common.recipe.Recipe | None = fp4_recipe
162162

163-
if fp8_recipe is not None and self.config.layer_precision is None:
164-
if fp4_recipe is not None:
163+
if self.config.layer_precision is None:
164+
if fp8_recipe is not None and fp4_recipe is not None:
165165
raise RuntimeError("Both FP8 and FP4 recipes provided, but no layer precision provided.")
166-
167-
warnings.warn("No layer precision provided, using FP8 recipe for all layers.", UserWarning)
168-
self.config.layer_precision = ["fp8"] * self.config.num_hidden_layers
166+
if fp8_recipe is not None:
167+
warnings.warn("No layer precision provided, using FP8 recipe for all layers.", UserWarning)
168+
self.config.layer_precision = ["fp8"] * self.config.num_hidden_layers
169+
elif fp4_recipe is not None:
170+
raise RuntimeError(
171+
"FP4 recipe provided but no layer_precision configured. "
172+
"Set layer_precision explicitly when using FP4."
173+
)
174+
175+
if self.config.layer_precision is not None and "fp4" in self.config.layer_precision and fp4_recipe is None:
176+
raise RuntimeError("layer_precision contains 'fp4' entries but no fp4_recipe was provided.")
169177

170178
def _init_method(x):
171179
torch.nn.init.normal_(x, mean=0.0, std=config.initializer_range)
@@ -234,7 +242,7 @@ def forward(
234242
with torch.autocast(device_type="cuda", enabled=False):
235243
te_rope_emb = self.rotary_embeddings(max_seq_len=self.config.max_position_embeddings)
236244
te_rope_emb = te_rope_emb.to(hidden_states.device, non_blocking=True)
237-
if te_rope_emb.dtype == torch.float32:
245+
if te_rope_emb.dtype != torch.float32:
238246
warnings.warn("Rotary embeddings should be in float32 for optimal performance.", UserWarning)
239247

240248
with self.get_autocast_context(None, outer=True):
@@ -295,6 +303,8 @@ def get_autocast_context(
295303
recipe = {"fp8": self._fp8_recipe, "fp4": self._fp4_recipe}.get(precision)
296304

297305
if init and self.config.use_quantized_model_init:
306+
if precision == "fp4" and recipe is None:
307+
raise RuntimeError("No FP4 recipe provided, but layer precision is set to FP4.")
298308
if precision in ("fp8", "fp4"):
299309
return transformer_engine.pytorch.quantized_model_init(recipe=recipe)
300310
return nullcontext()

bionemo-recipes/models/llama3/modeling_llama_te.py

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -73,11 +73,11 @@ def __init__(
7373
self.use_quantized_model_init = use_quantized_model_init
7474

7575
if layer_precision is not None:
76-
assert len(layer_precision) == self.num_hidden_layers, (
77-
f"layer_precision must be a list of length {self.num_hidden_layers}"
78-
)
76+
if len(layer_precision) != self.num_hidden_layers:
77+
raise ValueError(f"layer_precision must be a list of length {self.num_hidden_layers}")
7978
for precision in layer_precision:
80-
assert precision in {"fp8", "fp4", None}, 'layer_precision element must be "fp8", "fp4", or None'
79+
if precision not in {"fp8", "fp4", None}:
80+
raise ValueError(f'layer_precision element must be "fp8", "fp4", or None, got {precision!r}')
8181

8282

8383
class NVLlamaPreTrainedModel(PreTrainedModel):
@@ -157,12 +157,20 @@ def __init__(
157157
self._fp8_recipe: transformer_engine.common.recipe.Recipe | None = fp8_recipe
158158
self._fp4_recipe: transformer_engine.common.recipe.Recipe | None = fp4_recipe
159159

160-
if fp8_recipe is not None and self.config.layer_precision is None:
161-
if fp4_recipe is not None:
160+
if self.config.layer_precision is None:
161+
if fp8_recipe is not None and fp4_recipe is not None:
162162
raise RuntimeError("Both FP8 and FP4 recipes provided, but no layer precision provided.")
163-
164-
warnings.warn("No layer precision provided, using FP8 recipe for all layers.", UserWarning)
165-
self.config.layer_precision = ["fp8"] * self.config.num_hidden_layers
163+
if fp8_recipe is not None:
164+
warnings.warn("No layer precision provided, using FP8 recipe for all layers.", UserWarning)
165+
self.config.layer_precision = ["fp8"] * self.config.num_hidden_layers
166+
elif fp4_recipe is not None:
167+
raise RuntimeError(
168+
"FP4 recipe provided but no layer_precision configured. "
169+
"Set layer_precision explicitly when using FP4."
170+
)
171+
172+
if self.config.layer_precision is not None and "fp4" in self.config.layer_precision and fp4_recipe is None:
173+
raise RuntimeError("layer_precision contains 'fp4' entries but no fp4_recipe was provided.")
166174

167175
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx, dtype=config.dtype)
168176

@@ -287,7 +295,7 @@ def forward(
287295
# Ensure that rotary embeddings are computed at a higher precision
288296
with torch.autocast(device_type="cuda", enabled=False):
289297
te_rope_emb = self.rotary_emb(max_seq_len=self.config.max_position_embeddings)
290-
if te_rope_emb.dtype == torch.float32:
298+
if te_rope_emb.dtype != torch.float32:
291299
warnings.warn("Rotary embeddings should be in float32 for optimal performance.", UserWarning)
292300

293301
with self.get_autocast_context(None, outer=True):

bionemo-recipes/models/mixtral/modeling_mixtral_te.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import torch.nn as nn
2525
import transformer_engine.common.recipe
2626
import transformer_engine.pytorch
27+
import transformers
2728
from transformer_engine.pytorch.attention import InferenceParams
2829
from transformer_engine.pytorch.attention.inference import PagedKVCacheManager
2930
from transformer_engine.pytorch.attention.rope import RotaryPositionEmbedding
@@ -69,11 +70,11 @@ def __init__(
6970
self.use_quantized_model_init = use_quantized_model_init
7071

7172
if layer_precision is not None:
72-
assert len(layer_precision) == self.num_hidden_layers, (
73-
f"layer_precision must be a list of length {self.num_hidden_layers}"
74-
)
73+
if len(layer_precision) != self.num_hidden_layers:
74+
raise ValueError(f"layer_precision must be a list of length {self.num_hidden_layers}")
7575
for precision in layer_precision:
76-
assert precision in {"fp8", "fp4", None}, 'layer_precision element must be "fp8", "fp4", or None'
76+
if precision not in {"fp8", "fp4", None}:
77+
raise ValueError(f'layer_precision element must be "fp8", "fp4", or None, got {precision!r}')
7778

7879

7980
class NVMixtralPreTrainedModel(PreTrainedModel):
@@ -486,7 +487,7 @@ def get_autocast_context(
486487
return transformer_engine.pytorch.autocast(enabled=False)
487488

488489

489-
class NVMixtralForCausalLM(NVMixtralPreTrainedModel, __import__("transformers").GenerationMixin):
490+
class NVMixtralForCausalLM(NVMixtralPreTrainedModel, transformers.GenerationMixin):
490491
"""Mixtral model with causal language head."""
491492

492493
_tied_weights_keys: ClassVar[list[str]] = []

bionemo-recipes/models/qwen/modeling_qwen2_te.py

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -70,11 +70,11 @@ def __init__(
7070
self.use_quantized_model_init = use_quantized_model_init
7171

7272
if layer_precision is not None:
73-
assert len(layer_precision) == self.num_hidden_layers, (
74-
f"layer_precision must be a list of length {self.num_hidden_layers}"
75-
)
73+
if len(layer_precision) != self.num_hidden_layers:
74+
raise ValueError(f"layer_precision must be a list of length {self.num_hidden_layers}")
7675
for precision in layer_precision:
77-
assert precision in {"fp8", "fp4", None}, 'layer_precision element must be "fp8", "fp4", or None'
76+
if precision not in {"fp8", "fp4", None}:
77+
raise ValueError(f'layer_precision element must be "fp8", "fp4", or None, got {precision!r}')
7878

7979

8080
class NVQwen2PreTrainedModel(PreTrainedModel):
@@ -154,12 +154,20 @@ def __init__(
154154
self._fp8_recipe: transformer_engine.common.recipe.Recipe | None = fp8_recipe
155155
self._fp4_recipe: transformer_engine.common.recipe.Recipe | None = fp4_recipe
156156

157-
if fp8_recipe is not None and self.config.layer_precision is None:
158-
if fp4_recipe is not None:
157+
if self.config.layer_precision is None:
158+
if fp8_recipe is not None and fp4_recipe is not None:
159159
raise RuntimeError("Both FP8 and FP4 recipes provided, but no layer precision provided.")
160-
161-
warnings.warn("No layer precision provided, using FP8 recipe for all layers.", UserWarning)
162-
self.config.layer_precision = ["fp8"] * self.config.num_hidden_layers
160+
if fp8_recipe is not None:
161+
warnings.warn("No layer precision provided, using FP8 recipe for all layers.", UserWarning)
162+
self.config.layer_precision = ["fp8"] * self.config.num_hidden_layers
163+
elif fp4_recipe is not None:
164+
raise RuntimeError(
165+
"FP4 recipe provided but no layer_precision configured. "
166+
"Set layer_precision explicitly when using FP4."
167+
)
168+
169+
if self.config.layer_precision is not None and "fp4" in self.config.layer_precision and fp4_recipe is None:
170+
raise RuntimeError("layer_precision contains 'fp4' entries but no fp4_recipe was provided.")
163171

164172
head_dim = config.hidden_size // config.num_attention_heads
165173

@@ -290,6 +298,8 @@ def forward(
290298
# Ensure that rotary embeddings are computed at a higher precision
291299
with torch.autocast(device_type="cuda", enabled=False):
292300
te_rope_emb = self.rotary_emb(max_seq_len=self.config.max_position_embeddings)
301+
if te_rope_emb.dtype != torch.float32:
302+
warnings.warn("Rotary embeddings should be in float32 for optimal performance.", UserWarning)
293303

294304
with self.get_autocast_context(None, outer=True):
295305
for layer_idx, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]):

bionemo-recipes/models/qwen/modeling_qwen3_te.py

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -70,11 +70,11 @@ def __init__(
7070
self.use_quantized_model_init = use_quantized_model_init
7171

7272
if layer_precision is not None:
73-
assert len(layer_precision) == self.num_hidden_layers, (
74-
f"layer_precision must be a list of length {self.num_hidden_layers}"
75-
)
73+
if len(layer_precision) != self.num_hidden_layers:
74+
raise ValueError(f"layer_precision must be a list of length {self.num_hidden_layers}")
7675
for precision in layer_precision:
77-
assert precision in {"fp8", "fp4", None}, 'layer_precision element must be "fp8", "fp4", or None'
76+
if precision not in {"fp8", "fp4", None}:
77+
raise ValueError(f'layer_precision element must be "fp8", "fp4", or None, got {precision!r}')
7878

7979

8080
class NVQwen3PreTrainedModel(PreTrainedModel):
@@ -154,12 +154,20 @@ def __init__(
154154
self._fp8_recipe: transformer_engine.common.recipe.Recipe | None = fp8_recipe
155155
self._fp4_recipe: transformer_engine.common.recipe.Recipe | None = fp4_recipe
156156

157-
if fp8_recipe is not None and self.config.layer_precision is None:
158-
if fp4_recipe is not None:
157+
if self.config.layer_precision is None:
158+
if fp8_recipe is not None and fp4_recipe is not None:
159159
raise RuntimeError("Both FP8 and FP4 recipes provided, but no layer precision provided.")
160-
161-
warnings.warn("No layer precision provided, using FP8 recipe for all layers.", UserWarning)
162-
self.config.layer_precision = ["fp8"] * self.config.num_hidden_layers
160+
if fp8_recipe is not None:
161+
warnings.warn("No layer precision provided, using FP8 recipe for all layers.", UserWarning)
162+
self.config.layer_precision = ["fp8"] * self.config.num_hidden_layers
163+
elif fp4_recipe is not None:
164+
raise RuntimeError(
165+
"FP4 recipe provided but no layer_precision configured. "
166+
"Set layer_precision explicitly when using FP4."
167+
)
168+
169+
if self.config.layer_precision is not None and "fp4" in self.config.layer_precision and fp4_recipe is None:
170+
raise RuntimeError("layer_precision contains 'fp4' entries but no fp4_recipe was provided.")
163171

164172
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx, dtype=config.dtype)
165173

@@ -300,6 +308,8 @@ def forward(
300308
# Ensure that rotary embeddings are computed at a higher precision
301309
with torch.autocast(device_type="cuda", enabled=False):
302310
te_rope_emb = self.rotary_emb(max_seq_len=self.config.max_position_embeddings)
311+
if te_rope_emb.dtype != torch.float32:
312+
warnings.warn("Rotary embeddings should be in float32 for optimal performance.", UserWarning)
303313

304314
with self.get_autocast_context(None, outer=True):
305315
for layer_idx, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]):

bionemo-recipes/recipes/esm2_accelerate_te/example_8m_checkpoint/esm_nv.py

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -132,11 +132,11 @@ def __init__(
132132
)
133133

134134
if layer_precision is not None:
135-
assert len(layer_precision) == self.num_hidden_layers, (
136-
f"layer_precision must be a list of length {self.num_hidden_layers}"
137-
)
135+
if len(layer_precision) != self.num_hidden_layers:
136+
raise ValueError(f"layer_precision must be a list of length {self.num_hidden_layers}")
138137
for precision in layer_precision:
139-
assert precision in {"fp8", "fp4", None}, 'layer_precision element must be "fp8", "fp4", or None'
138+
if precision not in {"fp8", "fp4", None}:
139+
raise ValueError(f'layer_precision element must be "fp8", "fp4", or None, got {precision!r}')
140140

141141

142142
class NVEsmEncoder(nn.Module):
@@ -160,12 +160,20 @@ def __init__(
160160
self._fp8_recipe: transformer_engine.common.recipe.Recipe | None = fp8_recipe
161161
self._fp4_recipe: transformer_engine.common.recipe.Recipe | None = fp4_recipe
162162

163-
if fp8_recipe is not None and self.config.layer_precision is None:
164-
if fp4_recipe is not None:
163+
if self.config.layer_precision is None:
164+
if fp8_recipe is not None and fp4_recipe is not None:
165165
raise RuntimeError("Both FP8 and FP4 recipes provided, but no layer precision provided.")
166-
167-
warnings.warn("No layer precision provided, using FP8 recipe for all layers.", UserWarning)
168-
self.config.layer_precision = ["fp8"] * self.config.num_hidden_layers
166+
if fp8_recipe is not None:
167+
warnings.warn("No layer precision provided, using FP8 recipe for all layers.", UserWarning)
168+
self.config.layer_precision = ["fp8"] * self.config.num_hidden_layers
169+
elif fp4_recipe is not None:
170+
raise RuntimeError(
171+
"FP4 recipe provided but no layer_precision configured. "
172+
"Set layer_precision explicitly when using FP4."
173+
)
174+
175+
if self.config.layer_precision is not None and "fp4" in self.config.layer_precision and fp4_recipe is None:
176+
raise RuntimeError("layer_precision contains 'fp4' entries but no fp4_recipe was provided.")
169177

170178
def _init_method(x):
171179
torch.nn.init.normal_(x, mean=0.0, std=config.initializer_range)
@@ -234,7 +242,7 @@ def forward(
234242
with torch.autocast(device_type="cuda", enabled=False):
235243
te_rope_emb = self.rotary_embeddings(max_seq_len=self.config.max_position_embeddings)
236244
te_rope_emb = te_rope_emb.to(hidden_states.device, non_blocking=True)
237-
if te_rope_emb.dtype == torch.float32:
245+
if te_rope_emb.dtype != torch.float32:
238246
warnings.warn("Rotary embeddings should be in float32 for optimal performance.", UserWarning)
239247

240248
with self.get_autocast_context(None, outer=True):
@@ -295,6 +303,8 @@ def get_autocast_context(
295303
recipe = {"fp8": self._fp8_recipe, "fp4": self._fp4_recipe}.get(precision)
296304

297305
if init and self.config.use_quantized_model_init:
306+
if precision == "fp4" and recipe is None:
307+
raise RuntimeError("No FP4 recipe provided, but layer precision is set to FP4.")
298308
if precision in ("fp8", "fp4"):
299309
return transformer_engine.pytorch.quantized_model_init(recipe=recipe)
300310
return nullcontext()

0 commit comments

Comments
 (0)