@@ -764,334 +764,6 @@ class HyenaNV1b2ModelProvider(HyenaNV1bModelProvider):
764764 # glu_linear_offset: float = 1.0
765765
766766
767- # FIXME use the following as a starting point for the new megatron bridge style model importer/exporter.
768- # @io.model_importer(HyenaModel, "pytorch")
769- # class PyTorchHyenaImporter(io.ModelConnector["HyenaModel", HyenaModel]):
770- # """Importer class for converting PyTorch Hyena models to NeMo format."""
771-
772- # def __new__(cls, path: str, model_config=None):
773- # """Creates a new importer instance.
774-
775- # Args:
776- # path: Path to the PyTorch model
777- # model_config: Optional model configuration
778-
779- # Returns:
780- # PyTorchHyenaImporter instance
781- # """
782- # instance = super().__new__(cls, path)
783- # instance.model_config = model_config
784- # return instance
785-
786- # def init(self) -> HyenaModel:
787- # """Initializes a new HyenaModel instance.
788-
789- # Returns:
790- # HyenaModel: Initialized model
791- # """
792- # return HyenaModel(self.config, tokenizer=self.tokenizer)
793-
794- # def get_source_model(self):
795- # """Returns the source model."""
796- # return torch.load(str(self), map_location="cpu")
797-
798- # def apply(self, output_path: Path, checkpoint_format: str = "torch_dist") -> Path:
799- # """Applies the model conversion from PyTorch to NeMo format.
800-
801- # Args:
802- # output_path: Path to save the converted model
803- # checkpoint_format: Format for saving checkpoints
804-
805- # Returns:
806- # Path: Path to the saved NeMo model
807- # """
808- # source = self.get_source_model()
809-
810- # if "model" in source:
811- # source = source["model"]
812-
813- # class ModelState:
814- # """Wrapper around the source model state dictionary that also handles some weight transformations."""
815-
816- # def __init__(self, state_dict, num_layers, fp32_suffixes):
817- # """Wrapper around the source model state dictionary that also handles some weight transformations.
818-
819- # Args:
820- # state_dict: original state dictionary from the source model
821- # num_layers: number of layers in the source model
822- # fp32_suffixes: suffixes of the weights that should be converted to float32
823- # """
824- # self.num_layers = num_layers
825- # state_dict = self.transform_source_dict(state_dict)
826- # self._state_dict = state_dict
827- # self.fp32_suffixes = fp32_suffixes
828-
829- # def state_dict(self):
830- # """Return the state dictionary."""
831- # return self._state_dict
832-
833- # def to(self, dtype):
834- # """Convert the state dictionary to the target dtype."""
835- # for k, v in self._state_dict.items():
836- # if "_extra" not in k:
837- # if v.dtype != dtype:
838- # logging.warning(f"Converting {k} from {v.dtype} (source model) to {dtype} (target model)")
839- # k_suffix = k.split(".")[-1]
840- # if k_suffix in self.fp32_suffixes:
841- # _dtype = torch.float32
842- # else:
843- # _dtype = dtype
844- # self._state_dict[k] = v.to(_dtype)
845-
846- # def adjust_medium_filter(self, updated_data):
847- # """Adjust the medium filter."""
848- # from nemo.collections.llm.gpt.model.megatron.hyena.hyena_config import HyenaConfig
849-
850- # for k, v in updated_data.items():
851- # if "filter.h" in k or "filter.decay" in k:
852- # updated_data[k] = v[:, : HyenaConfig().hyena_medium_conv_len]
853- # return updated_data
854-
855- # def transform_source_dict(self, source):
856- # """Transform the source state dictionary.
857-
858- # This function works by applying some challenging layer name re-mappings and
859- # removing extra keys, as well as truncating a filter that didn't need to extend to the full
860- # sequence length dim.
861- # """
862- # import re
863-
864- # layer_map = {i + 2: i for i in range(self.num_layers)}
865- # layer_map[self.num_layers + 3] = self.num_layers
866- # updated_data = {}
867-
868- # for key in list(source["module"].keys()):
869- # if "_extra" in key:
870- # source["module"].pop(key)
871- # else:
872- # match = re.search(r"sequential\.(\d+)", key)
873- # if match:
874- # original_layer_num = int(match.group(1))
875- # if original_layer_num in layer_map:
876- # # Create the updated key by replacing the layer number
877- # new_key = re.sub(rf"\b{original_layer_num}\b", str(layer_map[original_layer_num]), key)
878- # updated_data[new_key] = source["module"][key]
879- # else:
880- # # Keep the key unchanged if no mapping exists
881- # updated_data[key] = source["module"][key]
882- # else:
883- # updated_data[key] = source["module"][key]
884- # updated_data = self.adjust_medium_filter(updated_data)
885- # return updated_data
886-
887- # target = self.init()
888- # trainer = self.nemo_setup(target, ckpt_async_save=False, save_ckpt_format=checkpoint_format)
889- # target.to(self.config.params_dtype)
890- # fp32_suffixes = {n.split(".")[-1] for n, p in target.named_parameters() if p.dtype == torch.float32}
891- # source = ModelState(source, self.config.num_layers, fp32_suffixes)
892- # source.to(self.config.params_dtype)
893- # self.convert_state(source, target)
894- # self.nemo_save(output_path, trainer)
895-
896- # logging.info(f"Converted Hyena model to Nemo, model saved to {output_path}")
897-
898- # teardown(trainer, target)
899- # del trainer, target
900-
901- # return output_path
902-
903- # def convert_state(self, source, target):
904- # """Converts the state dictionary from source format to target format.
905-
906- # Args:
907- # source: Source model state
908- # target: Target model
909-
910- # Returns:
911- # Result of applying state transforms
912- # """
913- # mapping = {}
914- # mapping["sequential.0.word_embeddings.weight"] = "embedding.word_embeddings.weight"
915- # mapping[f"sequential.{len(self.config.hybrid_override_pattern)}.norm.weight"] = "decoder.final_norm.weight"
916- # te_enabled = self.config.use_te
917- # for i, symbol in enumerate(self.config.hybrid_override_pattern):
918- # if te_enabled:
919- # mapping[f"sequential.{i}.pre_mlp_layernorm.weight"] = (
920- # f"decoder.layers.{i}.mlp.linear_fc1.layer_norm_weight"
921- # )
922- # else:
923- # mapping[f"sequential.{i}.pre_mlp_layernorm.weight"] = f"decoder.layers.{i}.pre_mlp_layernorm.weight"
924- # mapping[f"sequential.{i}.mlp.w3.weight"] = f"decoder.layers.{i}.mlp.linear_fc2.weight"
925-
926- # if symbol != "*":
927- # if te_enabled:
928- # mapping[f"sequential.{i}.input_layernorm.weight"] = (
929- # f"decoder.layers.{i}.mixer.dense_projection.layer_norm_weight"
930- # )
931- # else:
932- # mapping[f"sequential.{i}.input_layernorm.weight"] = f"decoder.layers.{i}.norm.weight"
933-
934- # mapping[f"sequential.{i}.mixer.dense_projection.weight"] = (
935- # f"decoder.layers.{i}.mixer.dense_projection.weight"
936- # )
937- # mapping[f"sequential.{i}.mixer.hyena_proj_conv.short_conv_weight"] = (
938- # f"decoder.layers.{i}.mixer.hyena_proj_conv.short_conv_weight"
939- # )
940- # mapping[f"sequential.{i}.mixer.dense.weight"] = f"decoder.layers.{i}.mixer.dense.weight"
941- # mapping[f"sequential.{i}.mixer.dense.bias"] = f"decoder.layers.{i}.mixer.dense.bias"
942-
943- # if symbol == "S":
944- # mapping[f"sequential.{i}.mixer.mixer.short_conv.short_conv_weight"] = (
945- # f"decoder.layers.{i}.mixer.mixer.short_conv.short_conv_weight"
946- # )
947-
948- # elif symbol == "D":
949- # mapping[f"sequential.{i}.mixer.mixer.conv_bias"] = f"decoder.layers.{i}.mixer.mixer.conv_bias"
950- # mapping[f"sequential.{i}.mixer.mixer.filter.h"] = f"decoder.layers.{i}.mixer.mixer.filter.h"
951- # mapping[f"sequential.{i}.mixer.mixer.filter.decay"] = (
952- # f"decoder.layers.{i}.mixer.mixer.filter.decay"
953- # )
954-
955- # elif symbol == "H":
956- # mapping[f"sequential.{i}.mixer.mixer.conv_bias"] = f"decoder.layers.{i}.mixer.mixer.conv_bias"
957- # mapping[f"sequential.{i}.mixer.mixer.filter.gamma"] = (
958- # f"decoder.layers.{i}.mixer.mixer.filter.gamma"
959- # )
960- # mapping[f"sequential.{i}.mixer.mixer.filter.R"] = f"decoder.layers.{i}.mixer.mixer.filter.R"
961- # mapping[f"sequential.{i}.mixer.mixer.filter.p"] = f"decoder.layers.{i}.mixer.mixer.filter.p"
962-
963- # elif symbol == "*":
964- # if te_enabled:
965- # mapping[f"sequential.{i}.input_layernorm.weight"] = (
966- # f"decoder.layers.{i}.self_attention.linear_qkv.layer_norm_weight"
967- # )
968- # else:
969- # mapping[f"sequential.{i}.input_layernorm.weight"] = f"decoder.layers.{i}.input_layernorm.weight"
970-
971- # mapping[f"sequential.{i}.mixer.dense_projection.weight"] = (
972- # f"decoder.layers.{i}.self_attention.linear_qkv.weight"
973- # )
974- # mapping[f"sequential.{i}.mixer.dense.weight"] = f"decoder.layers.{i}.self_attention.linear_proj.weight"
975- # mapping[f"sequential.{i}.mixer.dense.bias"] = f"decoder.layers.{i}.self_attention.linear_proj.bias"
976- # else:
977- # raise ValueError(f"Unknown symbol: {symbol}")
978-
979- # return io.apply_transforms(
980- # source,
981- # target,
982- # mapping=mapping,
983- # transforms=[
984- # # Transforms that are more complicated than a simple mapping of an old key name to a new one:
985- # io.state_transform(
986- # source_key=("sequential.*.mlp.w1.weight", "sequential.*.mlp.w2.weight"),
987- # target_key="decoder.layers.*.mlp.linear_fc1.weight",
988- # fn=TransformFns.merge_fc1,
989- # )
990- # ],
991- # )
992-
993- # @property
994- # def tokenizer(self):
995- # """Gets the tokenizer for the model.
996-
997- # Returns:
998- # Tokenizer instance
999- # """
1000- # from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer
1001-
1002- # tokenizer = get_nmt_tokenizer(
1003- # library=self.model_config.tokenizer_library,
1004- # )
1005-
1006- # return tokenizer
1007-
1008- # @property
1009- # def config(self) -> HyenaConfig:
1010- # """Gets the model configuration.
1011-
1012- # Returns:
1013- # HyenaConfig: Model configuration
1014- # """
1015- # return self.model_config
1016-
1017-
1018- # @io.model_importer(HyenaModel, "hf")
1019- # class HuggingFaceSavannaHyenaImporter(PyTorchHyenaImporter):
1020- # """Importer class for converting HuggingFace Savanna Hyena models to NeMo format.
1021-
1022- # See: https://huggingface.co/arcinstitute/savanna_evo2_7b for an example of a savanna model that this can
1023- # import and convert to NeMo format. Any of the Arc models that start with "savanna_" should work.
1024- # """
1025-
1026- # def get_source_model(self):
1027- # """Returns the source model."""
1028- # import huggingface_hub.errors
1029- # from huggingface_hub import hf_hub_download
1030-
1031- # if os.path.exists(str(self)):
1032- # logging.info(f"Loading model from local path {self!s}")
1033- # return torch.load(str(self), map_location="cpu", weights_only=False)
1034- # else:
1035- # if ":" in str(self):
1036- # repo_id, revision = str(self).split(":")
1037- # else:
1038- # repo_id = str(self)
1039- # revision = None
1040- # # See HF download logic here:
1041- # # https://github.com/ArcInstitute/evo2/blob/96ac9d9cd/evo2/models.py#L191-L231
1042- # modelname = repo_id.split("/")[-1]
1043- # download_dir = str(NEMO_MODELS_CACHE / repo_id)
1044- # weights_filename = f"{modelname}.pt"
1045- # try:
1046- # weights_path = hf_hub_download(
1047- # repo_id=repo_id, local_dir=download_dir, revision=revision, filename=weights_filename
1048- # )
1049- # except Exception:
1050- # # Try downloading multi-part
1051- # # If file is split, download and join parts
1052- # logging.warning(f"Single path download failed, try loading checkpoint shards for {modelname}")
1053- # # If file is split, get the first part's directory to use the same cache location
1054- # weights_path = os.path.join(download_dir, weights_filename)
1055- # if os.path.exists(weights_path):
1056- # logging.info(f"Found {weights_path}")
1057- # else:
1058- # # Download and join parts
1059- # parts = []
1060- # part_num = 0
1061- # while True:
1062- # try:
1063- # part_path = hf_hub_download(
1064- # repo_id=repo_id,
1065- # local_dir=download_dir,
1066- # revision=revision,
1067- # filename=f"{weights_filename}.part{part_num}",
1068- # )
1069- # parts.append(part_path)
1070- # part_num += 1
1071- # except huggingface_hub.errors.EntryNotFoundError:
1072- # break
1073-
1074- # # Join in the same directory
1075- # with open(weights_path, "wb") as outfile:
1076- # for part in parts:
1077- # with open(part, "rb") as infile:
1078- # while True:
1079- # chunk = infile.read(8192 * 1024)
1080- # if not chunk:
1081- # break
1082- # outfile.write(chunk)
1083-
1084- # # Cleaning up the parts
1085- # for part in parts:
1086- # try:
1087- # os.remove(part)
1088- # except OSError as e:
1089- # print(f"Error removing {part}: {e}")
1090- # print("Cleaned up shards, final checkpoint saved to", weights_path)
1091-
1092- # return torch.load(weights_path, map_location="cpu", weights_only=False)
1093-
1094-
1095767HYENA_MODEL_OPTIONS : dict [str , Type [HyenaModelProvider ]] = {
1096768 # ARC public checkpoint names (evo2_ prefix matches HuggingFace repo names)
1097769 "evo2_1b_base" : Hyena1bModelProvider ,
0 commit comments