
Commit 071ed00

add codonfm 5b arch params

Signed-off-by: Bruno Alvisio <balvisio@nvidia.com>
1 parent 6209b35

File tree

5 files changed: +57 -3 lines


bionemo-recipes/recipes/codonfm_ptl_te/codonfm_ckpt_te_conversion.py

Lines changed: 41 additions & 2 deletions
@@ -26,14 +26,39 @@

 import argparse
 import logging
+import os

 import torch
+from safetensors.torch import save_file as safetensors_save_file

 from src.utils.load_checkpoint import load_checkpoint


 logger = logging.getLogger(__name__)

+ALLOWED_HYPERPARAMETER_KEYS = (
+    "vocab_size",
+    "hidden_size",
+    "num_hidden_layers",
+    "num_attention_heads",
+    "intermediate_size",
+    "hidden_act",
+    "hidden_dropout_prob",
+    "attention_probs_dropout_prob",
+    "initializer_range",
+    "layer_norm_eps",
+    "pad_token_id",
+    "position_embedding_type",
+    "classifier_dropout",
+    "rotary_theta",
+    "ignore_index",
+    "loss_type",
+    "lora",
+    "lora_alpha",
+    "lora_r",
+    "lora_dropout",
+)
+
 # PYTorch -> TE keymap
 PYTORCH_TO_TE_KEYMAP = {
     "model.layers.*.pre_attn_layer_norm.weight": "model.layers.*.self_attention.layernorm_qkv.layer_norm_weight",
@@ -300,6 +325,11 @@ def convert_state_dict(src: dict, keymap: dict):
     return dst_state_dict


+def filter_hyper_parameters(hyper_parameters: dict) -> dict:
+    """Keep only conversion-compatible hyperparameter keys."""
+    return {key: value for key, value in hyper_parameters.items() if key in ALLOWED_HYPERPARAMETER_KEYS}
+
+
 def main():
     """Main function."""
     logging.basicConfig(level=logging.INFO)
@@ -325,6 +355,7 @@ def main():
     # Load source checkpoint (automatically detects format)
     logger.info(f"Loading checkpoint from {args.src}")
     src_checkpoint = load_checkpoint(args.src, map_location="cpu")
+    src_checkpoint["hyper_parameters"] = filter_hyper_parameters(src_checkpoint["hyper_parameters"])

     # Perform conversion based on direction
     if args.direction == "pytorch2te":
@@ -341,11 +372,19 @@ def main():
         dst_state_dict = split_qkv(converted_state_dict, src_checkpoint["hyper_parameters"])

     # Prepare final checkpoint
-    dst_checkpoint = {"state_dict": dst_state_dict, "hyper_parameters": src_checkpoint["hyper_parameters"]}
+    dst_checkpoint = {
+        "state_dict": dst_state_dict,
+        "hyper_parameters": src_checkpoint["hyper_parameters"],
+    }

     # Save the converted checkpoint in pickled format
     torch.save(dst_checkpoint, args.dst)
-    logger.info(f"Successfully converted checkpoint from {args.src} to {args.dst}")
+    logger.info(f"Successfully converted checkpoint saved to {args.dst}")
+
+    # Save the state_dict in safetensors format alongside the .ckpt file
+    safetensors_path = os.path.splitext(args.dst)[0] + ".safetensors"
+    safetensors_save_file(dst_state_dict, safetensors_path)
+    logger.info(f"Successfully saved safetensors checkpoint to {safetensors_path}")


 if __name__ == "__main__":
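
The conversion script now writes two artifacts: the pickled Lightning checkpoint at `--dst` and a sibling `.safetensors` file holding only the state dict. A minimal round-trip sketch of loading both back, where the path is hypothetical and not part of this commit:

```python
# Sketch: verify the two artifacts the script writes stay in sync.
# "encodon_te.ckpt" is a hypothetical --dst value.
import os

import torch
from safetensors.torch import load_file

dst = "encodon_te.ckpt"

ckpt = torch.load(dst, map_location="cpu", weights_only=False)
tensors = load_file(os.path.splitext(dst)[0] + ".safetensors")

# Same tensor names in both files; the filtered hyper_parameters
# survive only in the pickled .ckpt.
assert ckpt["state_dict"].keys() == tensors.keys()
```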

bionemo-recipes/recipes/codonfm_ptl_te/data_scripts/check_codon_frequency.py

Lines changed: 1 addition & 1 deletion
@@ -23,7 +23,7 @@
 from tqdm import tqdm


-sys.path.append("/workspace/codon_fm")
+sys.path.append("/workspace/codonfm")
 from src.tokenizer import Tokenizer


bionemo-recipes/recipes/codonfm_ptl_te/notebooks/4-EnCodon-Downstream-Task-riboNN.ipynb

Lines changed: 8 additions & 0 deletions
@@ -108,6 +108,12 @@
 ")\n",
 "download_checkpoint(\n",
 "    repo_id=\"nvidia/NV-CodonFM-Encodon-TE-1B-v1\", local_dir=\"/data/checkpoints/NV-CodonFM-Encodon-TE-1B-v1\"\n",
+")\n",
+"download_checkpoint(\n",
+"    repo_id=\"nvidia/NV-CodonFM-Encodon-TE-5B-v1\", local_dir=\"/data/checkpoints/NV-CodonFM-Encodon-TE-5B-v1\"\n",
+")\n",
+"download_checkpoint(\n",
+"    repo_id=\"nvidia/NV-CodonFM-Encodon-TE-Cdwt-5B-v1\", local_dir=\"/data/checkpoints/NV-CodonFM-Encodon-TE-Cdwt-5B-v1\"\n",
 ")"
 ]
 },
@@ -123,6 +129,8 @@
 "    \"/data/checkpoints/NV-CodonFM-Encodon-TE-80M-v1\",\n",
 "    \"/data/checkpoints/NV-CodonFM-Encodon-TE-600M-v1\",\n",
 "    \"/data/checkpoints/NV-CodonFM-Encodon-TE-Cdwt-1B-v1\",\n",
+"    \"/data/checkpoints/NV-CodonFM-Encodon-TE-5B-v1\",\n",
+"    \"/data/checkpoints/NV-CodonFM-Encodon-TE-Cdwt-5B-v1\",\n",
 "]\n",
 "\n",
 "checkpoint_path = checkpoint_paths[0]\n",

bionemo-recipes/recipes/codonfm_ptl_te/src/config.py

Lines changed: 6 additions & 0 deletions
@@ -251,6 +251,12 @@ def get_logger_config(args: Any) -> fdl.Config:
         "num_attention_heads": 16,
         "num_hidden_layers": 18,
     },
+    "encodon_5b": {
+        "hidden_size": 4096,
+        "intermediate_size": 16384,
+        "num_attention_heads": 32,
+        "num_hidden_layers": 24,
+    },
     "encodon_10b": {
         "hidden_size": 5120,
         "intermediate_size": 20480,

bionemo-recipes/recipes/codonfm_ptl_te/src/runner.py

Lines changed: 1 addition & 0 deletions
@@ -113,6 +113,7 @@ def get_parser(): # noqa: D103
         "encodon_80m",
         "encodon_600m",
         "encodon_1b",
+        "encodon_5b",
         "encodon_10b",
     ],
 )
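
With the choice registered, `encodon_5b` is selectable wherever the parser built by `get_parser()` is consumed. A quick sketch; the option string owning these choices is outside this hunk, so `--model` is a guess:

```python
# Hypothetical invocation: "--model" is an assumption, since the flag name
# holding these choices is not visible in this diff.
from src.runner import get_parser

args = get_parser().parse_args(["--model", "encodon_5b"])
print(args)
```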
