diff --git a/mlx_lm/convert.py b/mlx_lm/convert.py
index da3502a..0a9d890 100644
--- a/mlx_lm/convert.py
+++ b/mlx_lm/convert.py
@@ -89,7 +89,7 @@ def mixed_quant_predicate(

 def convert(
     hf_path: str,
-    mlx_path: str = "mlx_model",
+    mlx_path: Union[str, Path] = "mlx_model",
     quantize: bool = False,
     q_group_size: int = 64,
     q_bits: int = 4,
diff --git a/mlx_lm/evaluate.py b/mlx_lm/evaluate.py
index 6289cec..bab5e2f 100644
--- a/mlx_lm/evaluate.py
+++ b/mlx_lm/evaluate.py
@@ -10,7 +10,7 @@
 import os
 from importlib.metadata import version
 from pathlib import Path
-from typing import Optional
+from typing import Dict, Optional, Tuple, Union

 import lm_eval
 import mlx.core as mx
@@ -169,7 +169,7 @@ def loglikelihood(self, requests) -> list[tuple[float, bool]]:
         )

         # max length (prefix + completion) and longest common prefix per question.
-        length_stats = {}
+        length_stats: Dict[Tuple[int, ...], Tuple[int, float]] = {}
         for prefix, completed in zip(tokenized[0::2], tokenized[1::2]):
             max_completed_l, min_prefix_l = length_stats.get(prefix, (0, 1e8))
             length_stats[prefix] = (
diff --git a/mlx_lm/generate.py b/mlx_lm/generate.py
index 71ba81d..334a8fa 100644
--- a/mlx_lm/generate.py
+++ b/mlx_lm/generate.py
@@ -289,7 +289,7 @@ def generate_step(
     model: nn.Module,
     *,
     max_tokens: int = 256,
-    sampler: Optional[Callable[mx.array, mx.array]] = None,
+    sampler: Optional[Callable[[mx.array], mx.array]] = None,
     logits_processors: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None,
     max_kv_size: Optional[int] = None,
     prompt_cache: Optional[Any] = None,
@@ -297,7 +297,7 @@
     kv_bits: Optional[int] = None,
     kv_group_size: int = 64,
     quantized_kv_start: int = 0,
-    prompt_progress_callback: Optional[Callable[int, int]] = None,
+    prompt_progress_callback: Optional[Callable[[int, int], None]] = None,
 ) -> Generator[Tuple[mx.array, mx.array], None, None]:
     """
     A generator producing token ids based on the given prompt from the model.
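Aside on the two Callable fixes above (editor's sketch, not part of the patch): typing.Callable must be parameterized as Callable[[arg_types, ...], return_type], so the one-argument sampler and the two-argument progress callback need an explicit argument-type list. The names ProgressCallback and report_progress below are illustrative only.

from typing import Callable, Optional

# A prompt-progress callback receives (processed_tokens, total_tokens) and returns nothing.
ProgressCallback = Optional[Callable[[int, int], None]]

def report_progress(processed: int, total: int) -> None:
    print(f"prefill: {processed}/{total} tokens")

callback: ProgressCallback = report_progress  # accepted; Callable[int, int] is rejected by type checkers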
@@ -409,7 +409,7 @@ def speculative_generate_step(
     *,
     num_draft_tokens=2,
     max_tokens: int = 256,
-    sampler: Optional[Callable[mx.array, mx.array]] = None,
+    sampler: Optional[Callable[[mx.array], mx.array]] = None,
     logits_processors: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None,
     prompt_cache: Optional[Any] = None,
     prefill_step_size: int = 512,
diff --git a/mlx_lm/gguf.py b/mlx_lm/gguf.py
index 241ac35..4ad3cf9 100644
--- a/mlx_lm/gguf.py
+++ b/mlx_lm/gguf.py
@@ -67,7 +67,7 @@ def hf_tokens(self) -> Iterable[Tuple[bytes, float, TokenType]]:
     def get_token_type(
         self, token_id: int, token_text: bytes, special_ids: Set[int]
     ) -> TokenType:
-        if re.fullmatch(r"<0x[0-9A-Fa-f]{2}>", token_text):
+        if re.fullmatch(b"<0x[0-9A-Fa-f]{2}>", token_text):
             return TokenType.BYTE
         return TokenType.CONTROL if token_id in special_ids else TokenType.NORMAL

@@ -77,12 +77,14 @@ def get_token_score(self, token_id: int) -> float:
     def added_tokens(self) -> Iterable[Tuple[bytes, float, TokenType]]:
         for text in self.added_tokens_list:
             if text in self.specials:
-                toktype = self.get_token_type(self.specials[text], "", self.special_ids)
+                toktype = self.get_token_type(
+                    self.specials[text], text.encode("utf-8"), self.special_ids
+                )
                 score = self.get_token_score(self.specials[text])
             else:
                 toktype = TokenType.USER_DEFINED
                 score = -1000.0
-            yield text, score, toktype
+            yield text.encode("utf-8"), score, toktype

     def has_newline_token(self):
         return "<0x0A>" in self.tokenizer.vocab or "\n" in self.tokenizer.vocab
diff --git a/mlx_lm/merge.py b/mlx_lm/merge.py
index c1d1826..a69334d 100644
--- a/mlx_lm/merge.py
+++ b/mlx_lm/merge.py
@@ -114,7 +114,7 @@ def unpack_values(vals):

 def merge(
     config: str,
-    mlx_path: str = "mlx_model",
+    mlx_path: Path = Path("mlx_model"),
     upload_repo: Optional[str] = None,
 ):
     with open(config, "r") as fid:
diff --git a/mlx_lm/models/cache.py b/mlx_lm/models/cache.py
index d666be6..3472223 100644
--- a/mlx_lm/models/cache.py
+++ b/mlx_lm/models/cache.py
@@ -45,12 +45,10 @@ def save_prompt_cache(file_name: str, cache: List[Any], metadata: Dict[str, str]
         metadata (Dict[str, str]): Optional metadata to save along with model
             state.
     """
-    cache_data = [c.state for c in cache]
+    cache_data = dict(tree_flatten([c.state for c in cache]))
     cache_info = [c.meta_state for c in cache]
-    cache_data = dict(tree_flatten(cache_data))
     cache_classes = [type(c).__name__ for c in cache]
-    cache_metadata = [cache_info, metadata, cache_classes]
-    cache_metadata = dict(tree_flatten(cache_metadata))
+    cache_metadata = dict(tree_flatten([cache_info, metadata, cache_classes]))
     mx.save_safetensors(file_name, cache_data, cache_metadata)


@@ -87,7 +85,7 @@ def can_trim_prompt_cache(cache: List[Any]) -> bool:
     return all(c.is_trimmable() for c in cache)


-def trim_prompt_cache(cache: List[Any], num_tokens: int) -> List[Any]:
+def trim_prompt_cache(cache: List[Any], num_tokens: int) -> int:
     """
     Trim the model's cache by the given number of tokens.

diff --git a/mlx_lm/models/cohere2.py b/mlx_lm/models/cohere2.py
index 4fba1f5..0d95e99 100644
--- a/mlx_lm/models/cohere2.py
+++ b/mlx_lm/models/cohere2.py
@@ -1,7 +1,7 @@
 # Copyright © 2023-2024 Apple Inc.

 from dataclasses import dataclass
-from typing import Optional, Tuple
+from typing import Any, Optional, Tuple

 import mlx.core as mx
 import mlx.nn as nn
@@ -61,7 +61,7 @@ def __call__(
         self,
         x: mx.array,
         mask: Optional[mx.array] = None,
-        cache: Optional[Tuple[mx.array, mx.array]] = None,
+        cache: Optional[Any] = None,
     ) -> mx.array:
         B, L, D = x.shape

diff --git a/mlx_lm/models/deepseek_v2.py b/mlx_lm/models/deepseek_v2.py
index 9964481..a09b2e8 100644
--- a/mlx_lm/models/deepseek_v2.py
+++ b/mlx_lm/models/deepseek_v2.py
@@ -38,7 +38,7 @@ class ModelArgs(BaseModelArgs):
     max_position_embeddings: int = 2048
     rms_norm_eps: float = 1e-6
     rope_theta: float = 10000.0
-    rope_scaling: Dict = None
+    rope_scaling: Dict[Any, Any] = {}
     attention_bias: bool = False


@@ -189,6 +189,7 @@ def __init__(self, config: ModelArgs):
             ]
             if key in self.config.rope_scaling
         }
+
         self.rope = DeepseekV2YarnRotaryEmbedding(
             dim=self.qk_rope_head_dim,
             max_position_embeddings=self.max_position_embeddings,
@@ -244,7 +245,10 @@ def __call__(

 class DeepseekV2MLP(nn.Module):
     def __init__(
-        self, config: ModelArgs, hidden_size: int = None, intermediate_size: int = None
+        self,
+        config: ModelArgs,
+        hidden_size: Optional[int] = None,
+        intermediate_size: Optional[int] = None,
     ):
         super().__init__()
         self.config = config
diff --git a/mlx_lm/models/gpt2.py b/mlx_lm/models/gpt2.py
index d557646..4891411 100644
--- a/mlx_lm/models/gpt2.py
+++ b/mlx_lm/models/gpt2.py
@@ -20,7 +20,7 @@ class ModelArgs(BaseModelArgs):
     n_positions: int
     layer_norm_epsilon: float
     vocab_size: int
-    num_key_value_heads: int = None
+    num_key_value_heads: Optional[int] = None

     def __post_init__(self):
         if self.num_key_value_heads is None:
diff --git a/mlx_lm/models/gpt_bigcode.py b/mlx_lm/models/gpt_bigcode.py
index 1d9794b..b9b30d5 100644
--- a/mlx_lm/models/gpt_bigcode.py
+++ b/mlx_lm/models/gpt_bigcode.py
@@ -20,7 +20,7 @@ class ModelArgs(BaseModelArgs):
     n_positions: int
     layer_norm_epsilon: float
     vocab_size: int
-    num_key_value_heads: int = None
+    num_key_value_heads: Optional[int] = None
     multi_query: bool = True
     attention_bias: bool = True
     mlp_bias: bool = True
diff --git a/mlx_lm/models/gpt_neox.py b/mlx_lm/models/gpt_neox.py
index 5e124a6..59fbf40 100644
--- a/mlx_lm/models/gpt_neox.py
+++ b/mlx_lm/models/gpt_neox.py
@@ -24,7 +24,7 @@ class ModelArgs(BaseModelArgs):
     vocab_size: int
     rotary_emb_base: int
     rotary_pct: float
-    num_key_value_heads: int = None
+    num_key_value_heads: Optional[int] = None

     def __post_init__(self):
         if self.num_key_value_heads is None:
diff --git a/mlx_lm/models/hunyuan.py b/mlx_lm/models/hunyuan.py
index 122cebd..94c9276 100644
--- a/mlx_lm/models/hunyuan.py
+++ b/mlx_lm/models/hunyuan.py
@@ -29,8 +29,8 @@ class ModelArgs(BaseModelArgs):
     rms_norm_eps: float
     rope_theta: float
     use_cla: bool
-    cla_share_factor: 2
-    rope_scaling: Optional[Dict[str, Union[float, str]]] = None
+    cla_share_factor: int = 2
+    rope_scaling: Dict[str, Union[float, str]] = {}
     tie_word_embeddings: bool = False

     def __post_init__(self):
@@ -71,7 +71,6 @@ def __init__(self, kv_proj: bool, args: ModelArgs):

         dim = args.hidden_size
         self.n_heads = n_heads = args.num_attention_heads
-        assert args.num_key_value_heads is not None
         self.n_kv_heads = n_kv_heads = args.num_key_value_heads
         head_dim = args.hidden_size // n_heads

@@ -90,10 +89,12 @@ def __init__(self, kv_proj: bool, args: ModelArgs):
         self.query_layernorm = nn.RMSNorm(head_dim, args.rms_norm_eps)
         self.key_layernorm = nn.RMSNorm(head_dim, args.rms_norm_eps)

+        # Since rope_scaling is required, no need for a conditional check
+        scaling_alpha = float(args.rope_scaling["alpha"])
         self.rope = DynamicNTKAlphaRoPE(
             head_dim,
             base=args.rope_theta,
-            scaling_alpha=args.rope_scaling["alpha"],
+            scaling_alpha=scaling_alpha,
         )

     def __call__(
diff --git a/mlx_lm/models/internlm2.py b/mlx_lm/models/internlm2.py
index 28a095e..da122ee 100644
--- a/mlx_lm/models/internlm2.py
+++ b/mlx_lm/models/internlm2.py
@@ -20,7 +20,7 @@ class ModelArgs(BaseModelArgs):
     vocab_size: int
     bias: bool = True
     max_position_embeddings: int = 32768
-    num_key_value_heads: int = None
+    num_key_value_heads: int
     rope_theta: float = 10000
     rope_traditional: bool = False
     rope_scaling: Optional[Dict[str, Union[float, str]]] = None
@@ -88,7 +88,7 @@ def __init__(self, args: ModelArgs):
         dim = args.hidden_size
         self.n_heads = n_heads = args.num_attention_heads
         self.n_kv_heads = n_kv_heads = args.num_key_value_heads
-        self.n_kv_groups = n_heads // args.num_key_value_heads
+        self.n_kv_groups = n_heads // n_kv_heads
         self.head_dim = head_dim = args.hidden_size // n_heads
         self.scale = head_dim**-0.5

@@ -99,7 +99,7 @@ def __init__(self, args: ModelArgs):
         self.wo = nn.Linear(n_heads * head_dim, dim, bias=args.bias)

         rope_scale = (
-            1 / args.rope_scaling["factor"]
+            1 / float(args.rope_scaling["factor"])
             if args.rope_scaling is not None and args.rope_scaling["type"] == "linear"
             else 2.0
         )
diff --git a/mlx_lm/models/minicpm.py b/mlx_lm/models/minicpm.py
index 089f6e0..9c83930 100644
--- a/mlx_lm/models/minicpm.py
+++ b/mlx_lm/models/minicpm.py
@@ -68,7 +68,7 @@ def __init__(self, args: ModelArgs):
         )

         rope_scale = (
-            1 / args.rope_scaling["factor"]
+            1 / float(args.rope_scaling["factor"])
             if args.rope_scaling is not None and args.rope_scaling["type"] == "linear"
             else 1
         )
diff --git a/mlx_lm/models/openelm.py b/mlx_lm/models/openelm.py
index 504fe95..1f6a570 100644
--- a/mlx_lm/models/openelm.py
+++ b/mlx_lm/models/openelm.py
@@ -29,7 +29,7 @@ class ModelArgs(BaseModelArgs):

 def make_divisible(
     v: Union[float, int],
-    divisor: Optional[int] = 8,
+    divisor: int = 8,
     min_value: Optional[Union[float, int]] = None,
 ) -> Union[float, int]:
     """
diff --git a/mlx_lm/sample_utils.py b/mlx_lm/sample_utils.py
index 5b270eb..8c14705 100644
--- a/mlx_lm/sample_utils.py
+++ b/mlx_lm/sample_utils.py
@@ -73,7 +73,7 @@ def sampler(logits):
 def make_logits_processors(
     logit_bias: Optional[Dict[int, float]] = None,
     repetition_penalty: Optional[float] = None,
-    repetition_context_size: Optional[int] = 20,
+    repetition_context_size: int = 20,  # No longer optional
 ):
     """
     Make logits processors for use with ``generate_step``.
@@ -81,7 +81,7 @@
     Args:
         repetition_penalty (float, optional): The penalty factor for repeating
             tokens.
-        repetition_context_size (int, optional): The number of tokens to
+        repetition_context_size (int): The number of tokens to
             consider for repetition penalty. Default: ``20``.
         logit_bias (dictionary, optional): Additive logit bias.

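The repeated num_key_value_heads changes above all apply the same fix: annotating a field as int while defaulting it to None is an implicit Optional that strict type checkers reject, so the field is declared Optional and resolved in __post_init__. A minimal sketch of the pattern (the Args class below is a stand-in for illustration, not code from the patch):

from dataclasses import dataclass
from typing import Optional

@dataclass
class Args:
    num_attention_heads: int
    num_key_value_heads: Optional[int] = None

    def __post_init__(self):
        # Fall back to one KV head per attention head when none is configured.
        if self.num_key_value_heads is None:
            self.num_key_value_heads = self.num_attention_heads

print(Args(num_attention_heads=8).num_key_value_heads)  # 8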
diff --git a/mlx_lm/server.py b/mlx_lm/server.py
index e16e05d..2b5a328 100644
--- a/mlx_lm/server.py
+++ b/mlx_lm/server.py
@@ -163,7 +163,7 @@ def __init__(self, cli_args: argparse.Namespace):
         if self.cli_args.model is not None:
             self.load("default_model", draft_model_path="default_model")

-    def _validate_model_path(self, model_path: str):
+    def _validate_model_path(self, model_path: Union[str, Path]):
         model_path = Path(model_path)
         if model_path.exists() and not model_path.is_relative_to(Path.cwd()):
             raise RuntimeError(
@@ -544,7 +544,7 @@ def handle_completion(
                 to the stopping_criteria function
         """
         tokens = []
-        finish_reason = "length"
+        finish_reason: Union[Literal["length", "stop"], None] = "length"
         stop_sequence_suffix = None
         if self.stream:
             self.end_headers()
@@ -594,10 +594,16 @@ def handle_completion(

             if self.logprobs > 0:
                 sorted_indices = mx.argpartition(-logprobs, kth=self.logprobs - 1)
-                top_indices = sorted_indices[: self.logprobs]
-                top_logprobs = logprobs[top_indices]
-                top_token_info = zip(top_indices.tolist(), top_logprobs.tolist())
-                top_tokens.append(tuple(top_token_info))
+                top_indices = sorted_indices[: self.logprobs].tolist()
+                top_logprobs = logprobs[top_indices].tolist()
+                top_tokens.append(
+                    {
+                        int(idx): float(prob)
+                        for idx, prob in zip(
+                            top_indices, top_logprobs
+                        )  # Iterate over Python lists
+                    }
+                )

             token_logprobs.append(logprobs[token].item())

@@ -666,8 +672,8 @@ def handle_completion(

     def completion_usage_response(
         self,
-        prompt_token_count: Optional[int] = None,
-        completion_token_count: Optional[int] = None,
+        prompt_token_count: int,
+        completion_token_count: int,
     ):
         response = {
             "id": self.request_id,
@@ -788,6 +794,7 @@ def run(
         *server_address, type=socket.SOCK_STREAM, flags=socket.AI_PASSIVE
     )
     server_class.address_family, _, _, _, server_address = next(iter(infos))
+    server_address = server_address[:2]
     httpd = server_class(
         server_address,
         lambda *args, **kwargs: handler_class(
diff --git a/mlx_lm/tuner/trainer.py b/mlx_lm/tuner/trainer.py
index a551d97..3ce1942 100644
--- a/mlx_lm/tuner/trainer.py
+++ b/mlx_lm/tuner/trainer.py
@@ -6,7 +6,7 @@
 from dataclasses import dataclass, field
 from functools import partial
 from pathlib import Path
-from typing import List, Optional, Tuple
+from typing import Callable, List, Optional, Tuple

 import mlx.core as mx
 import mlx.nn as nn
@@ -157,8 +157,8 @@ def evaluate(
     batch_size,
     num_batches,
     max_seq_length=2048,
-    loss: callable = default_loss,
-    iterate_batches: callable = iterate_batches,
+    loss: Callable = default_loss,
+    iterate_batches: Callable = iterate_batches,
 ):
     model.eval()
     all_losses = mx.array(0.0)
@@ -204,9 +204,9 @@ def train(
     train_dataset,
     val_dataset,
     args: TrainingArgs = TrainingArgs(),
-    loss: callable = default_loss,
-    iterate_batches: callable = iterate_batches,
-    training_callback: TrainingCallback = None,
+    loss: Callable = default_loss,
+    iterate_batches: Callable = iterate_batches,
+    training_callback: Optional[TrainingCallback] = None,
 ):
     mx.set_wired_limit(mx.metal.device_info()["max_recommended_working_set_size"])
     print(f"Starting training..., iters: {args.iters}")
@@ -244,7 +244,7 @@ def step(batch):
     n_tokens = 0
     steps = 0
     trained_tokens = 0
-    train_time = 0
+    train_time = 0.0
     # Main training loop
     for it, batch in zip(
         range(1, args.iters + 1),
diff --git a/mlx_lm/tuner/utils.py b/mlx_lm/tuner/utils.py
index d63089a..3ebd8e9 100644
--- a/mlx_lm/tuner/utils.py
+++ b/mlx_lm/tuner/utils.py
@@ -2,7 +2,7 @@
 import json
 import types
 from pathlib import Path
-from typing import Dict
+from typing import Dict, Union

 import mlx.core as mx
 import mlx.nn as nn
@@ -166,13 +166,13 @@ def to_lora(layer):
     model.update_modules(tree_unflatten(lora_modules))


-def load_adapters(model: nn.Module, adapter_path: str) -> nn.Module:
+def load_adapters(model: nn.Module, adapter_path: Union[str, Path]) -> nn.Module:
     """
     Load any fine-tuned adapters / layers.

     Args:
         model (nn.Module): The neural network model.
-        adapter_path (str): Path to the adapter configuration file.
+        adapter_path (Union[str, Path]): Path to the adapter configuration file.

     Returns:
         nn.Module: The updated model with LoRA layers applied.
diff --git a/mlx_lm/utils.py b/mlx_lm/utils.py
index ce0bcb7..e6f51d0 100644
--- a/mlx_lm/utils.py
+++ b/mlx_lm/utils.py
@@ -12,6 +12,7 @@
     Any,
     Callable,
     Dict,
+    List,
     Optional,
     Tuple,
     Type,
@@ -188,9 +189,9 @@ def load_model(
     for wf in weight_files:
         weights.update(mx.load(wf))

-    model_class, model_args_class = get_model_classes(config=config)
+    model_class, model_args_class = get_model_classes(config)

-    model_args = model_args_class.from_dict(config)
+    model_args = model_args_class.from_dict(config=config)
     model = model_class(model_args)

     if hasattr(model, "sanitize"):
@@ -274,24 +275,29 @@ def fetch_from_hub(
     return model, config, tokenizer


-def make_shards(weights: dict, max_file_size_gb: int = MAX_FILE_SIZE_GB) -> list:
+def make_shards(
+    weights: Dict[str, mx.array], max_file_size_gb: int = MAX_FILE_SIZE_GB
+) -> List[Dict[str, mx.array]]:
     """
     Splits the weights into smaller shards.

     Args:
-        weights (dict): Model weights.
+        weights (Dict[str, mx.array]): Model weights.
         max_file_size_gb (int): Maximum size of each shard in gigabytes.

     Returns:
-        list: List of weight shards.
+        List[Dict[str, mx.array]]: List of weight shards.
     """
     max_file_size_bytes = max_file_size_gb << 30
     shards = []
-    shard, shard_size = {}, 0
+    shard = {}
+    shard_size = 0
+
     for k, v in weights.items():
         if shard_size + v.nbytes > max_file_size_bytes:
             shards.append(shard)
-            shard, shard_size = {}, 0
+            shard = {}
+            shard_size = 0
         shard[k] = v
         shard_size += v.nbytes
     shards.append(shard)
@@ -384,7 +390,6 @@ def save_weights(
     *,
     donate_weights: bool = False,
 ) -> None:
-    """Save model weights into specified directory."""
     if isinstance(save_path, str):
         save_path = Path(save_path)
     save_path.mkdir(parents=True, exist_ok=True)
@@ -397,11 +402,14 @@ def save_weights(
         else "model.safetensors"
     )

-    total_size = sum(v.nbytes for v in weights.values())
-    index_data = {"metadata": {"total_size": total_size}, "weight_map": {}}
-
     # Write the weights and make sure no references are kept other than the
     # necessary ones
+    total_size = sum(v.nbytes for v in weights.values())
+    index_data = {
+        "metadata": {"total_size": total_size},
+        "weight_map": {},
+    }
+
     if donate_weights:
         weights.clear()
         del weights
@@ -423,11 +431,7 @@ def save_weights(
     }

     with open(save_path / "model.safetensors.index.json", "w") as f:
-        json.dump(
-            index_data,
-            f,
-            indent=4,
-        )
+        json.dump(index_data, f, indent=4)


 def quantize_model(
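For reference, the sharding loop that make_shards implements (the hunk above only changes its annotations and formatting) can be illustrated with a self-contained sketch; it uses plain byte counts instead of mx.array weights so it runs without mlx, and the helper name shard_by_size is illustrative only.

from typing import Dict, List

def shard_by_size(sizes: Dict[str, int], max_bytes: int) -> List[Dict[str, int]]:
    shards: List[Dict[str, int]] = []
    shard: Dict[str, int] = {}
    shard_size = 0
    for name, nbytes in sizes.items():
        # Close the current shard once adding this tensor would exceed the limit.
        if shard_size + nbytes > max_bytes:
            shards.append(shard)
            shard = {}
            shard_size = 0
        shard[name] = nbytes
        shard_size += nbytes
    shards.append(shard)
    return shards

# Three 2 GB tensors with a 5 GB cap pack into two shards.
print(len(shard_by_size({"a": 2 << 30, "b": 2 << 30, "c": 2 << 30}, 5 << 30)))  # 2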