
Modified openelm.py, minicpm.py, gpt2.py, gpt_bigcode.py, and internlm2.py to make them work with mypy #61

Status: Open. Wants to merge 24 commits into base: main.

Changes from 19 commits
4 changes: 2 additions & 2 deletions mlx_lm/convert.py
@@ -86,12 +86,12 @@ def mixed_quant_predicate(

def convert(
hf_path: str,
mlx_path: str = "mlx_model",
mlx_path: Union[str, Path] = "mlx_model",
quantize: bool = False,
q_group_size: int = 64,
q_bits: int = 4,
dtype: str = "float16",
upload_repo: str = None,
upload_repo: Optional[str] = None,
revision: Optional[str] = None,
dequantize: bool = False,
quant_predicate: Optional[
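
Both convert.py changes are the same class of fix: recent mypy (by default) does not treat `arg: str = None` as implicitly optional, and a path argument that callers may pass as either `str` or `Path` needs a `Union`. A minimal, hypothetical sketch of the pattern (names here are illustrative, not mlx_lm's API):

```python
from pathlib import Path
from typing import Optional, Union

def export_model(
    out_dir: Union[str, Path] = "mlx_model",   # callers may pass str or Path
    upload_repo: Optional[str] = None,          # `str = None` is rejected by mypy
) -> Path:
    out_path = Path(out_dir)                    # normalize to Path internally
    if upload_repo is not None:
        print(f"would upload to {upload_repo}")
    return out_path
```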
4 changes: 2 additions & 2 deletions mlx_lm/evaluate.py
@@ -10,7 +10,7 @@
import os
from importlib.metadata import version
from pathlib import Path
from typing import Optional
from typing import Dict, Optional, Tuple, Union

import lm_eval
import mlx.core as mx
@@ -169,7 +169,7 @@ def loglikelihood(self, requests) -> list[tuple[float, bool]]:
)

# max length (prefix + completion) and longest common prefix per question.
length_stats = {}
length_stats: Dict[Tuple[int, ...], Tuple[int, float]] = {}
for prefix, completed in zip(tokenized[0::2], tokenized[1::2]):
max_completed_l, min_prefix_l = length_stats.get(prefix, (0, 1e8))
length_stats[prefix] = (
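
The annotation on `length_stats` exists because an empty `{}` literal gives mypy nothing to infer key and value types from; spelling the types out lets the later tuple assignments type-check. A small sketch of the same idea, assuming token prefixes are tuples of ints:

```python
from typing import Dict, Tuple

# Keys: tokenized prefixes; values: (max completion length, min common prefix length).
length_stats: Dict[Tuple[int, ...], Tuple[int, float]] = {}

prefix = (101, 2023, 2003)
max_completed, min_prefix = length_stats.get(prefix, (0, 1e8))
length_stats[prefix] = (max(max_completed, 4), min(min_prefix, 3))
```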
6 changes: 3 additions & 3 deletions mlx_lm/generate.py
@@ -271,15 +271,15 @@ def generate_step(
model: nn.Module,
*,
max_tokens: int = 256,
sampler: Optional[Callable[mx.array, mx.array]] = None,
sampler: Optional[Callable[[mx.array], mx.array]] = None,
logits_processors: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None,
max_kv_size: Optional[int] = None,
prompt_cache: Optional[Any] = None,
prefill_step_size: int = 2048,
kv_bits: Optional[int] = None,
kv_group_size: int = 64,
quantized_kv_start: int = 0,
prompt_progress_callback: Optional[Callable[int, int]] = None,
prompt_progress_callback: Optional[Callable[[int, int], None]] = None,
) -> Generator[Tuple[mx.array, mx.array], None, None]:
"""
A generator producing token ids based on the given prompt from the model.
@@ -391,7 +391,7 @@ def speculative_generate_step(
*,
num_draft_tokens=2,
max_tokens: int = 256,
sampler: Optional[Callable[mx.array, mx.array]] = None,
sampler: Optional[Callable[[mx.array], mx.array]] = None,
logits_processors: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None,
prompt_cache: Optional[Any] = None,
prefill_step_size: int = 512,
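
The sampler and callback annotations fix how `Callable` is subscripted: the parameter types must be given as an inner list, so `Callable[[mx.array], mx.array]` reads "takes one array, returns an array", while `Callable[mx.array, mx.array]` is rejected by mypy (and raises at runtime). A hedged sketch with plain Python types standing in for `mx.array`:

```python
from typing import Callable, Optional

Sampler = Callable[[float], float]              # one argument, one return value
ProgressCallback = Callable[[int, int], None]   # (tokens_processed, total) -> None

def generate(
    sampler: Optional[Sampler] = None,
    progress: Optional[ProgressCallback] = None,
) -> float:
    value = 0.25
    if sampler is not None:
        value = sampler(value)
    if progress is not None:
        progress(1, 1)
    return value

print(generate(sampler=lambda x: x * 2, progress=lambda done, total: None))
```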
8 changes: 5 additions & 3 deletions mlx_lm/gguf.py
@@ -67,7 +67,7 @@ def hf_tokens(self) -> Iterable[Tuple[bytes, float, TokenType]]:
def get_token_type(
self, token_id: int, token_text: bytes, special_ids: Set[int]
) -> TokenType:
if re.fullmatch(r"<0x[0-9A-Fa-f]{2}>", token_text):
if re.fullmatch(b"<0x[0-9A-Fa-f]{2}>", token_text):
return TokenType.BYTE
return TokenType.CONTROL if token_id in special_ids else TokenType.NORMAL

@@ -77,12 +77,14 @@ def get_token_score(self, token_id: int) -> float:
def added_tokens(self) -> Iterable[Tuple[bytes, float, TokenType]]:
for text in self.added_tokens_list:
if text in self.specials:
toktype = self.get_token_type(self.specials[text], "", self.special_ids)
toktype = self.get_token_type(
self.specials[text], text.encode("utf-8"), self.special_ids

Member: This doesn't look like the same behavior. I'm just wondering what motivated the change here and if this code is still working. Or maybe it wasn't working before?

Author: It gave some mypy errors initially; doing this change fixed them.

)
score = self.get_token_score(self.specials[text])
else:
toktype = TokenType.USER_DEFINED
score = -1000.0
yield text, score, toktype
yield text.encode("utf-8"), score, toktype

def has_newline_token(self):
return "<0x0A>" in self.tokenizer.vocab or "\n" in self.tokenizer.vocab
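
The gguf.py changes keep the `str`/`bytes` split consistent: `re.fullmatch` needs a bytes pattern when the subject is bytes, and tokens declared as `bytes` in the return type have to be encoded before being yielded. A small sketch of the pattern check, assuming UTF-8 token text:

```python
import re

def is_byte_token(token_text: bytes) -> bool:
    # Pattern and subject must both be bytes; mixing str and bytes raises TypeError.
    return re.fullmatch(rb"<0x[0-9A-Fa-f]{2}>", token_text) is not None

print(is_byte_token(b"<0x0A>"))          # True
print(is_byte_token("hello".encode()))   # False
```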
2 changes: 1 addition & 1 deletion mlx_lm/merge.py
@@ -114,7 +114,7 @@ def unpack_values(vals):

def merge(
config: str,
mlx_path: str = "mlx_model",
mlx_path: Path = Path("mlx_model"),
upload_repo: Optional[str] = None,
):
with open(config, "r") as fid:
8 changes: 3 additions & 5 deletions mlx_lm/models/cache.py
@@ -45,12 +45,10 @@ def save_prompt_cache(file_name: str, cache: List[Any], metadata: Dict[str, str]
metadata (Dict[str, str]): Optional metadata to save along with model
state.
"""
cache_data = [c.state for c in cache]
cache_data = dict(tree_flatten([c.state for c in cache]))
cache_info = [c.meta_state for c in cache]
cache_data = dict(tree_flatten(cache_data))
cache_classes = [type(c).__name__ for c in cache]
cache_metadata = [cache_info, metadata, cache_classes]
cache_metadata = dict(tree_flatten(cache_metadata))
cache_metadata = dict(tree_flatten([cache_info, metadata, cache_classes]))
mx.save_safetensors(file_name, cache_data, cache_metadata)


@@ -87,7 +85,7 @@ def can_trim_prompt_cache(cache: List[Any]) -> bool:
return all(c.is_trimmable() for c in cache)


def trim_prompt_cache(cache: List[Any], num_tokens: int) -> List[Any]:
def trim_prompt_cache(cache: List[Any], num_tokens: int) -> int:
"""
Trim the model's cache by the given number of tokens.

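
The reshuffle in `save_prompt_cache` avoids rebinding one variable to two different types (`cache_data` was first a list, then a dict), which mypy reports as an incompatible assignment; building each dict in a single expression keeps one inferred type per name. An illustrative sketch, not the mlx_lm implementation:

```python
from typing import Any, Dict, List, Tuple

def flatten(items: List[Any]) -> List[Tuple[str, Any]]:
    return [(f"item_{i}", v) for i, v in enumerate(items)]

def pack(states: List[Any]) -> Dict[str, Any]:
    # One expression, so `packed` has a single type: Dict[str, Any].
    packed = dict(flatten(states))
    return packed
```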
4 changes: 2 additions & 2 deletions mlx_lm/models/cohere2.py
@@ -1,7 +1,7 @@
# Copyright © 2023-2024 Apple Inc.

from dataclasses import dataclass
from typing import Optional, Tuple
from typing import Any, Optional, Tuple

import mlx.core as mx
import mlx.nn as nn
@@ -61,7 +61,7 @@ def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
cache: Optional[Any] = None,
) -> mx.array:
B, L, D = x.shape

52 changes: 6 additions & 46 deletions mlx_lm/models/deepseek_v2.py
@@ -38,7 +38,7 @@ class ModelArgs(BaseModelArgs):
max_position_embeddings: int = 2048
rms_norm_eps: float = 1e-6
rope_theta: float = 10000.0
rope_scaling: Dict = None
rope_scaling: Dict[Any, Any] = {}
attention_bias: bool = False


@@ -189,6 +189,7 @@ def __init__(self, config: ModelArgs):
]
if key in self.config.rope_scaling
}

self.rope = DeepseekV2YarnRotaryEmbedding(
dim=self.qk_rope_head_dim,
max_position_embeddings=self.max_position_embeddings,
@@ -197,54 +198,13 @@ def __init__(self, config: ModelArgs):
**rope_kwargs,
)

def __call__(

Member: What's happening here?

Author: Not sure how it got deleted; it wasn't intentional. Probably while fixing merge conflicts? Just guessing.

self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array:
B, L, D = x.shape

if self.q_lora_rank is None:
q = self.q_proj(x)
else:
q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(x)))

q = q.reshape(B, L, self.num_heads, self.q_head_dim).transpose(0, 2, 1, 3)
q_nope, q_pe = mx.split(q, [self.qk_nope_head_dim], axis=-1)
compressed_kv = self.kv_a_proj_with_mqa(x)
compressed_kv, k_pe = mx.split(compressed_kv, [self.kv_lora_rank], axis=-1)
k_pe = k_pe.reshape(B, L, 1, self.qk_rope_head_dim).transpose(0, 2, 1, 3)
kv = self.kv_b_proj(self.kv_a_layernorm(compressed_kv))
kv = kv.reshape(B, L, self.num_heads, -1).transpose(0, 2, 1, 3)

k_nope, values = mx.split(kv, [self.qk_nope_head_dim], axis=-1)

if cache is not None:
q_pe = self.rope(q_pe, cache.offset)
k_pe = self.rope(k_pe, cache.offset)
k_pe = mx.repeat(k_pe, self.num_heads, axis=1)
keys, values = cache.update_and_fetch(
mx.concatenate([k_nope, k_pe], axis=-1), values
)
else:
q_pe = self.rope(q_pe)
k_pe = self.rope(k_pe)
k_pe = mx.repeat(k_pe, self.num_heads, axis=1)
keys = mx.concatenate([k_nope, k_pe], axis=-1)

queries = mx.concatenate([q_nope, q_pe], axis=-1)

output = scaled_dot_product_attention(
queries, keys, values, cache=cache, scale=self.scale, mask=mask
)
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.o_proj(output)


class DeepseekV2MLP(nn.Module):
def __init__(
self, config: ModelArgs, hidden_size: int = None, intermediate_size: int = None
self,
config: ModelArgs,
hidden_size: Optional[int] = None,
intermediate_size: Optional[int] = None,
):
super().__init__()
self.config = config
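
On the `rope_scaling` default in this file (and the similar defaults in hunyuan.py and llama.py below): since `ModelArgs` is a dataclass, a field typed `Dict` cannot silently default to `None` under mypy. The usual mypy-clean spellings are either an explicit `Optional` with a `None` sentinel, or an empty dict supplied through `field(default_factory=dict)`, since dataclasses reject a bare mutable `= {}` default at class-definition time. A hedged sketch of both options:

```python
from dataclasses import dataclass, field
from typing import Any, Dict, Optional

@dataclass
class RopeArgs:
    # Option 1: keep None as the "not configured" sentinel, typed explicitly.
    rope_scaling: Optional[Dict[str, Any]] = None
    # Option 2: default to an empty dict without a mutable class-level default.
    rope_extra: Dict[str, Any] = field(default_factory=dict)

args = RopeArgs()
print(args.rope_scaling, args.rope_extra)  # None {}
```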
2 changes: 1 addition & 1 deletion mlx_lm/models/gpt2.py
@@ -20,7 +20,7 @@ class ModelArgs(BaseModelArgs):
n_positions: int
layer_norm_epsilon: float
vocab_size: int
num_key_value_heads: int = None
num_key_value_heads: Optional[int] = None

def __post_init__(self):
if self.num_key_value_heads is None:
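
The `num_key_value_heads` fields in gpt2.py (and in gpt_bigcode.py and gpt_neox.py below) all follow one pattern: the value legitimately starts as `None` and is filled in from `num_attention_heads` in `__post_init__`, so the annotation has to be `Optional[int]` rather than `int = None`. A minimal sketch of that pattern:

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class HeadArgs:
    num_attention_heads: int
    num_key_value_heads: Optional[int] = None   # `int = None` is what mypy flags

    def __post_init__(self) -> None:
        if self.num_key_value_heads is None:
            self.num_key_value_heads = self.num_attention_heads

print(HeadArgs(num_attention_heads=12).num_key_value_heads)  # 12
```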
2 changes: 1 addition & 1 deletion mlx_lm/models/gpt_bigcode.py
@@ -20,7 +20,7 @@ class ModelArgs(BaseModelArgs):
n_positions: int
layer_norm_epsilon: float
vocab_size: int
num_key_value_heads: int = None
num_key_value_heads: Optional[int] = None
multi_query: bool = True
attention_bias: bool = True
mlp_bias: bool = True
2 changes: 1 addition & 1 deletion mlx_lm/models/gpt_neox.py
@@ -24,7 +24,7 @@ class ModelArgs(BaseModelArgs):
vocab_size: int
rotary_emb_base: int
rotary_pct: float
num_key_value_heads: int = None
num_key_value_heads: Optional[int] = None

def __post_init__(self):
if self.num_key_value_heads is None:
9 changes: 5 additions & 4 deletions mlx_lm/models/hunyuan.py
@@ -29,8 +29,8 @@ class ModelArgs(BaseModelArgs):
rms_norm_eps: float
rope_theta: float
use_cla: bool
cla_share_factor: 2
rope_scaling: Optional[Dict[str, Union[float, str]]] = None
cla_share_factor: int = 2
rope_scaling: Dict[str, Union[float, str]] = {}
tie_word_embeddings: bool = False

def __post_init__(self):
@@ -71,7 +71,6 @@ def __init__(self, kv_proj: bool, args: ModelArgs):

dim = args.hidden_size
self.n_heads = n_heads = args.num_attention_heads
assert args.num_key_value_heads is not None
self.n_kv_heads = n_kv_heads = args.num_key_value_heads

head_dim = args.hidden_size // n_heads
@@ -90,10 +89,12 @@ def __init__(self, kv_proj: bool, args: ModelArgs):
self.query_layernorm = nn.RMSNorm(head_dim, args.rms_norm_eps)
self.key_layernorm = nn.RMSNorm(head_dim, args.rms_norm_eps)

# Since rope_scaling is required, no need for a conditional check
scaling_alpha = float(args.rope_scaling["alpha"])
self.rope = DynamicNTKAlphaRoPE(
head_dim,
base=args.rope_theta,
scaling_alpha=args.rope_scaling["alpha"],
scaling_alpha=scaling_alpha,
)

def __call__(
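
The `float(...)` cast in hunyuan.py (and the identical ones in internlm2.py and minicpm.py below) narrows values read from a `Dict[str, Union[float, str]]`: indexing such a dict yields `Union[float, str]`, which cannot be passed directly where a `float` is required or used in arithmetic. A sketch of the narrowing, with a hypothetical helper name:

```python
from typing import Dict, Union

def rope_scale(rope_scaling: Dict[str, Union[float, str]], theta: float) -> float:
    alpha = float(rope_scaling["alpha"])   # Union[float, str] -> float
    return theta * alpha

print(rope_scale({"type": "dynamic", "alpha": 1000.0}, 10000.0))  # 10000000.0
```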
6 changes: 3 additions & 3 deletions mlx_lm/models/internlm2.py
@@ -20,7 +20,7 @@ class ModelArgs(BaseModelArgs):
vocab_size: int
bias: bool = True
max_position_embeddings: int = 32768
num_key_value_heads: int = None
num_key_value_heads: int
rope_theta: float = 10000
rope_traditional: bool = False
rope_scaling: Optional[Dict[str, Union[float, str]]] = None
@@ -88,7 +88,7 @@ def __init__(self, args: ModelArgs):
dim = args.hidden_size
self.n_heads = n_heads = args.num_attention_heads
self.n_kv_heads = n_kv_heads = args.num_key_value_heads
self.n_kv_groups = n_heads // args.num_key_value_heads
self.n_kv_groups = n_heads // n_kv_heads

self.head_dim = head_dim = args.hidden_size // n_heads
self.scale = head_dim**-0.5
@@ -99,7 +99,7 @@ def __init__(self, args: ModelArgs):
self.wo = nn.Linear(n_heads * head_dim, dim, bias=args.bias)

rope_scale = (
1 / args.rope_scaling["factor"]
1 / float(args.rope_scaling["factor"])
if args.rope_scaling is not None and args.rope_scaling["type"] == "linear"
else 2.0
)
2 changes: 1 addition & 1 deletion mlx_lm/models/llama.py
@@ -26,7 +26,7 @@ class ModelArgs(BaseModelArgs):
mlp_bias: bool = False
rope_theta: float = 10000
rope_traditional: bool = False
rope_scaling: Optional[Dict[str, Union[float, str]]] = None
rope_scaling: Dict[str, Union[float, str]] = {}
tie_word_embeddings: bool = True

def __post_init__(self):
2 changes: 1 addition & 1 deletion mlx_lm/models/minicpm.py
@@ -68,7 +68,7 @@ def __init__(self, args: ModelArgs):
)

rope_scale = (
1 / args.rope_scaling["factor"]
1 / float(args.rope_scaling["factor"])
if args.rope_scaling is not None and args.rope_scaling["type"] == "linear"
else 1
)
2 changes: 1 addition & 1 deletion mlx_lm/models/openelm.py
@@ -29,7 +29,7 @@ class ModelArgs(BaseModelArgs):

def make_divisible(
v: Union[float, int],
divisor: Optional[int] = 8,
divisor: int = 8,
min_value: Optional[Union[float, int]] = None,
) -> Union[float, int]:
"""
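
The openelm.py change goes the other way: `divisor` always has a usable integer default and is never meant to be `None`, so `Optional[int] = 8` only adds a `None` case that mypy then asks every use site to guard against. A hypothetical sketch (not the library's `make_divisible` implementation):

```python
from typing import Optional

def round_to_multiple(v: float, divisor: int = 8, min_value: Optional[int] = None) -> int:
    # divisor is always an int here, so arithmetic on it needs no None check.
    rounded = max(divisor, int(v // divisor) * divisor)
    if min_value is not None:           # min_value genuinely may be absent
        rounded = max(rounded, min_value)
    return rounded

print(round_to_multiple(20.0))  # 16
```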
10 changes: 5 additions & 5 deletions mlx_lm/sample_utils.py
@@ -13,7 +13,7 @@ def make_sampler(
min_p: float = 0.0,
min_tokens_to_keep: int = 1,
top_k: int = -1,
) -> Callable[mx.array, mx.array]:
) -> Callable[[mx.array], mx.array]:
"""
Make a sampler function for use with ``generate_step``.

@@ -57,17 +57,17 @@ def sampler(logits):


def make_logits_processors(
logit_bias: Optional[Dict[int, float]] = None,
repetition_penalty: Optional[float] = None,
repetition_context_size: Optional[int] = 20,
logit_bias: Dict[int, float] = None,
repetition_penalty: float = None,
repetition_context_size: int = 20, # No longer optional
):
"""
Make logits processors for use with ``generate_step``.

Args:
repetition_penalty (float, optional): The penalty factor for repeating
tokens.
repetition_context_size (int, optional): The number of tokens to
repetition_context_size (int): The number of tokens to
consider for repetition penalty. Default: ``20``.
logit_bias (dictionary, optional): Additive logit bias.
