Skip to content

Commit 4709e95

Browse files
Committed by John St. John
Cleanup unused commented out code
Signed-off-by: John St. John <jstjohn@nvidia.com>
1 parent 0cacf51 commit 4709e95

File tree

1 file changed

+0
-328
lines changed

1 file changed

+0
-328
lines changed

bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/evo2_provider.py

Lines changed: 0 additions & 328 deletions
Original file line numberDiff line numberDiff line change
@@ -764,334 +764,6 @@ class HyenaNV1b2ModelProvider(HyenaNV1bModelProvider):
764764
# glu_linear_offset: float = 1.0
765765

766766

767-
# FIXME use the following as a starting point for the new megatron bridge style model importer/exporter.
768-
# @io.model_importer(HyenaModel, "pytorch")
769-
# class PyTorchHyenaImporter(io.ModelConnector["HyenaModel", HyenaModel]):
770-
# """Importer class for converting PyTorch Hyena models to NeMo format."""
771-
772-
# def __new__(cls, path: str, model_config=None):
773-
# """Creates a new importer instance.
774-
775-
# Args:
776-
# path: Path to the PyTorch model
777-
# model_config: Optional model configuration
778-
779-
# Returns:
780-
# PyTorchHyenaImporter instance
781-
# """
782-
# instance = super().__new__(cls, path)
783-
# instance.model_config = model_config
784-
# return instance
785-
786-
# def init(self) -> HyenaModel:
787-
# """Initializes a new HyenaModel instance.
788-
789-
# Returns:
790-
# HyenaModel: Initialized model
791-
# """
792-
# return HyenaModel(self.config, tokenizer=self.tokenizer)
793-
794-
# def get_source_model(self):
795-
# """Returns the source model."""
796-
# return torch.load(str(self), map_location="cpu")
797-
798-
# def apply(self, output_path: Path, checkpoint_format: str = "torch_dist") -> Path:
799-
# """Applies the model conversion from PyTorch to NeMo format.
800-
801-
# Args:
802-
# output_path: Path to save the converted model
803-
# checkpoint_format: Format for saving checkpoints
804-
805-
# Returns:
806-
# Path: Path to the saved NeMo model
807-
# """
808-
# source = self.get_source_model()
809-
810-
# if "model" in source:
811-
# source = source["model"]
812-
813-
# class ModelState:
814-
# """Wrapper around the source model state dictionary that also handles some weight transformations."""
815-
816-
# def __init__(self, state_dict, num_layers, fp32_suffixes):
817-
# """Wrapper around the source model state dictionary that also handles some weight transformations.
818-
819-
# Args:
820-
# state_dict: original state dictionary from the source model
821-
# num_layers: number of layers in the source model
822-
# fp32_suffixes: suffixes of the weights that should be converted to float32
823-
# """
824-
# self.num_layers = num_layers
825-
# state_dict = self.transform_source_dict(state_dict)
826-
# self._state_dict = state_dict
827-
# self.fp32_suffixes = fp32_suffixes
828-
829-
# def state_dict(self):
830-
# """Return the state dictionary."""
831-
# return self._state_dict
832-
833-
# def to(self, dtype):
834-
# """Convert the state dictionary to the target dtype."""
835-
# for k, v in self._state_dict.items():
836-
# if "_extra" not in k:
837-
# if v.dtype != dtype:
838-
# logging.warning(f"Converting {k} from {v.dtype} (source model) to {dtype} (target model)")
839-
# k_suffix = k.split(".")[-1]
840-
# if k_suffix in self.fp32_suffixes:
841-
# _dtype = torch.float32
842-
# else:
843-
# _dtype = dtype
844-
# self._state_dict[k] = v.to(_dtype)
845-
846-
# def adjust_medium_filter(self, updated_data):
847-
# """Adjust the medium filter."""
848-
# from nemo.collections.llm.gpt.model.megatron.hyena.hyena_config import HyenaConfig
849-
850-
# for k, v in updated_data.items():
851-
# if "filter.h" in k or "filter.decay" in k:
852-
# updated_data[k] = v[:, : HyenaConfig().hyena_medium_conv_len]
853-
# return updated_data
854-
855-
# def transform_source_dict(self, source):
856-
# """Transform the source state dictionary.
857-
858-
# This function works by applying some challenging layer name re-mappings and
859-
# removing extra keys, as well as truncating a filter that didn't need to extend to the full
860-
# sequence length dim.
861-
# """
862-
# import re
863-
864-
# layer_map = {i + 2: i for i in range(self.num_layers)}
865-
# layer_map[self.num_layers + 3] = self.num_layers
866-
# updated_data = {}
867-
868-
# for key in list(source["module"].keys()):
869-
# if "_extra" in key:
870-
# source["module"].pop(key)
871-
# else:
872-
# match = re.search(r"sequential\.(\d+)", key)
873-
# if match:
874-
# original_layer_num = int(match.group(1))
875-
# if original_layer_num in layer_map:
876-
# # Create the updated key by replacing the layer number
877-
# new_key = re.sub(rf"\b{original_layer_num}\b", str(layer_map[original_layer_num]), key)
878-
# updated_data[new_key] = source["module"][key]
879-
# else:
880-
# # Keep the key unchanged if no mapping exists
881-
# updated_data[key] = source["module"][key]
882-
# else:
883-
# updated_data[key] = source["module"][key]
884-
# updated_data = self.adjust_medium_filter(updated_data)
885-
# return updated_data
886-
887-
# target = self.init()
888-
# trainer = self.nemo_setup(target, ckpt_async_save=False, save_ckpt_format=checkpoint_format)
889-
# target.to(self.config.params_dtype)
890-
# fp32_suffixes = {n.split(".")[-1] for n, p in target.named_parameters() if p.dtype == torch.float32}
891-
# source = ModelState(source, self.config.num_layers, fp32_suffixes)
892-
# source.to(self.config.params_dtype)
893-
# self.convert_state(source, target)
894-
# self.nemo_save(output_path, trainer)
895-
896-
# logging.info(f"Converted Hyena model to Nemo, model saved to {output_path}")
897-
898-
# teardown(trainer, target)
899-
# del trainer, target
900-
901-
# return output_path
902-
903-
# def convert_state(self, source, target):
904-
# """Converts the state dictionary from source format to target format.
905-
906-
# Args:
907-
# source: Source model state
908-
# target: Target model
909-
910-
# Returns:
911-
# Result of applying state transforms
912-
# """
913-
# mapping = {}
914-
# mapping["sequential.0.word_embeddings.weight"] = "embedding.word_embeddings.weight"
915-
# mapping[f"sequential.{len(self.config.hybrid_override_pattern)}.norm.weight"] = "decoder.final_norm.weight"
916-
# te_enabled = self.config.use_te
917-
# for i, symbol in enumerate(self.config.hybrid_override_pattern):
918-
# if te_enabled:
919-
# mapping[f"sequential.{i}.pre_mlp_layernorm.weight"] = (
920-
# f"decoder.layers.{i}.mlp.linear_fc1.layer_norm_weight"
921-
# )
922-
# else:
923-
# mapping[f"sequential.{i}.pre_mlp_layernorm.weight"] = f"decoder.layers.{i}.pre_mlp_layernorm.weight"
924-
# mapping[f"sequential.{i}.mlp.w3.weight"] = f"decoder.layers.{i}.mlp.linear_fc2.weight"
925-
926-
# if symbol != "*":
927-
# if te_enabled:
928-
# mapping[f"sequential.{i}.input_layernorm.weight"] = (
929-
# f"decoder.layers.{i}.mixer.dense_projection.layer_norm_weight"
930-
# )
931-
# else:
932-
# mapping[f"sequential.{i}.input_layernorm.weight"] = f"decoder.layers.{i}.norm.weight"
933-
934-
# mapping[f"sequential.{i}.mixer.dense_projection.weight"] = (
935-
# f"decoder.layers.{i}.mixer.dense_projection.weight"
936-
# )
937-
# mapping[f"sequential.{i}.mixer.hyena_proj_conv.short_conv_weight"] = (
938-
# f"decoder.layers.{i}.mixer.hyena_proj_conv.short_conv_weight"
939-
# )
940-
# mapping[f"sequential.{i}.mixer.dense.weight"] = f"decoder.layers.{i}.mixer.dense.weight"
941-
# mapping[f"sequential.{i}.mixer.dense.bias"] = f"decoder.layers.{i}.mixer.dense.bias"
942-
943-
# if symbol == "S":
944-
# mapping[f"sequential.{i}.mixer.mixer.short_conv.short_conv_weight"] = (
945-
# f"decoder.layers.{i}.mixer.mixer.short_conv.short_conv_weight"
946-
# )
947-
948-
# elif symbol == "D":
949-
# mapping[f"sequential.{i}.mixer.mixer.conv_bias"] = f"decoder.layers.{i}.mixer.mixer.conv_bias"
950-
# mapping[f"sequential.{i}.mixer.mixer.filter.h"] = f"decoder.layers.{i}.mixer.mixer.filter.h"
951-
# mapping[f"sequential.{i}.mixer.mixer.filter.decay"] = (
952-
# f"decoder.layers.{i}.mixer.mixer.filter.decay"
953-
# )
954-
955-
# elif symbol == "H":
956-
# mapping[f"sequential.{i}.mixer.mixer.conv_bias"] = f"decoder.layers.{i}.mixer.mixer.conv_bias"
957-
# mapping[f"sequential.{i}.mixer.mixer.filter.gamma"] = (
958-
# f"decoder.layers.{i}.mixer.mixer.filter.gamma"
959-
# )
960-
# mapping[f"sequential.{i}.mixer.mixer.filter.R"] = f"decoder.layers.{i}.mixer.mixer.filter.R"
961-
# mapping[f"sequential.{i}.mixer.mixer.filter.p"] = f"decoder.layers.{i}.mixer.mixer.filter.p"
962-
963-
# elif symbol == "*":
964-
# if te_enabled:
965-
# mapping[f"sequential.{i}.input_layernorm.weight"] = (
966-
# f"decoder.layers.{i}.self_attention.linear_qkv.layer_norm_weight"
967-
# )
968-
# else:
969-
# mapping[f"sequential.{i}.input_layernorm.weight"] = f"decoder.layers.{i}.input_layernorm.weight"
970-
971-
# mapping[f"sequential.{i}.mixer.dense_projection.weight"] = (
972-
# f"decoder.layers.{i}.self_attention.linear_qkv.weight"
973-
# )
974-
# mapping[f"sequential.{i}.mixer.dense.weight"] = f"decoder.layers.{i}.self_attention.linear_proj.weight"
975-
# mapping[f"sequential.{i}.mixer.dense.bias"] = f"decoder.layers.{i}.self_attention.linear_proj.bias"
976-
# else:
977-
# raise ValueError(f"Unknown symbol: {symbol}")
978-
979-
# return io.apply_transforms(
980-
# source,
981-
# target,
982-
# mapping=mapping,
983-
# transforms=[
984-
# # Transforms that are more complicated than a simple mapping of an old key name to a new one:
985-
# io.state_transform(
986-
# source_key=("sequential.*.mlp.w1.weight", "sequential.*.mlp.w2.weight"),
987-
# target_key="decoder.layers.*.mlp.linear_fc1.weight",
988-
# fn=TransformFns.merge_fc1,
989-
# )
990-
# ],
991-
# )
992-
993-
# @property
994-
# def tokenizer(self):
995-
# """Gets the tokenizer for the model.
996-
997-
# Returns:
998-
# Tokenizer instance
999-
# """
1000-
# from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer
1001-
1002-
# tokenizer = get_nmt_tokenizer(
1003-
# library=self.model_config.tokenizer_library,
1004-
# )
1005-
1006-
# return tokenizer
1007-
1008-
# @property
1009-
# def config(self) -> HyenaConfig:
1010-
# """Gets the model configuration.
1011-
1012-
# Returns:
1013-
# HyenaConfig: Model configuration
1014-
# """
1015-
# return self.model_config
1016-
1017-
1018-
# @io.model_importer(HyenaModel, "hf")
1019-
# class HuggingFaceSavannaHyenaImporter(PyTorchHyenaImporter):
1020-
# """Importer class for converting HuggingFace Savanna Hyena models to NeMo format.
1021-
1022-
# See: https://huggingface.co/arcinstitute/savanna_evo2_7b for an example of a savanna model that this can
1023-
# import and convert to NeMo format. Any of the Arc models that start with "savanna_" should work.
1024-
# """
1025-
1026-
# def get_source_model(self):
1027-
# """Returns the source model."""
1028-
# import huggingface_hub.errors
1029-
# from huggingface_hub import hf_hub_download
1030-
1031-
# if os.path.exists(str(self)):
1032-
# logging.info(f"Loading model from local path {self!s}")
1033-
# return torch.load(str(self), map_location="cpu", weights_only=False)
1034-
# else:
1035-
# if ":" in str(self):
1036-
# repo_id, revision = str(self).split(":")
1037-
# else:
1038-
# repo_id = str(self)
1039-
# revision = None
1040-
# # See HF download logic here:
1041-
# # https://github.com/ArcInstitute/evo2/blob/96ac9d9cd/evo2/models.py#L191-L231
1042-
# modelname = repo_id.split("/")[-1]
1043-
# download_dir = str(NEMO_MODELS_CACHE / repo_id)
1044-
# weights_filename = f"{modelname}.pt"
1045-
# try:
1046-
# weights_path = hf_hub_download(
1047-
# repo_id=repo_id, local_dir=download_dir, revision=revision, filename=weights_filename
1048-
# )
1049-
# except Exception:
1050-
# # Try downloading multi-part
1051-
# # If file is split, download and join parts
1052-
# logging.warning(f"Single path download failed, try loading checkpoint shards for {modelname}")
1053-
# # If file is split, get the first part's directory to use the same cache location
1054-
# weights_path = os.path.join(download_dir, weights_filename)
1055-
# if os.path.exists(weights_path):
1056-
# logging.info(f"Found {weights_path}")
1057-
# else:
1058-
# # Download and join parts
1059-
# parts = []
1060-
# part_num = 0
1061-
# while True:
1062-
# try:
1063-
# part_path = hf_hub_download(
1064-
# repo_id=repo_id,
1065-
# local_dir=download_dir,
1066-
# revision=revision,
1067-
# filename=f"{weights_filename}.part{part_num}",
1068-
# )
1069-
# parts.append(part_path)
1070-
# part_num += 1
1071-
# except huggingface_hub.errors.EntryNotFoundError:
1072-
# break
1073-
1074-
# # Join in the same directory
1075-
# with open(weights_path, "wb") as outfile:
1076-
# for part in parts:
1077-
# with open(part, "rb") as infile:
1078-
# while True:
1079-
# chunk = infile.read(8192 * 1024)
1080-
# if not chunk:
1081-
# break
1082-
# outfile.write(chunk)
1083-
1084-
# # Cleaning up the parts
1085-
# for part in parts:
1086-
# try:
1087-
# os.remove(part)
1088-
# except OSError as e:
1089-
# print(f"Error removing {part}: {e}")
1090-
# print("Cleaned up shards, final checkpoint saved to", weights_path)
1091-
1092-
# return torch.load(weights_path, map_location="cpu", weights_only=False)
1093-
1094-
1095767
HYENA_MODEL_OPTIONS: dict[str, Type[HyenaModelProvider]] = {
1096768
# ARC public checkpoint names (evo2_ prefix matches HuggingFace repo names)
1097769
"evo2_1b_base": Hyena1bModelProvider,

0 commit comments

Comments (0)