Skip to content

Commit 8a98954

Browse files
committed
add a missing file and make packaging work.
1 parent 6f51776 commit 8a98954

4 files changed

Lines changed: 217 additions & 5 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ issues = "https://github.com/ParagEkbote/hf-model-inspector/issues"
2828
source-directories = ["src"]
2929

3030
[tool.hatch.build.targets.wheel]
31-
packages = ["hf_model_inspector"]
31+
packages = ["src/hf_model_inspector"]
3232

3333
[tool.ruff]
3434
line-length = 88

src/hf_model_inspector/analyzer.py

Lines changed: 212 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,212 @@
1+
from typing import Dict, Any, Optional, Tuple, List
2+
import logging
3+
4+
logger = logging.getLogger(__name__)
5+
6+
7+
def estimate_param_count(
8+
repo_id: str, config: Optional[Dict], siblings: List[str]
9+
) -> Tuple[Optional[int], str]:
10+
"""
11+
Estimate parameter count for a model.
12+
Returns (param_count_estimate, method_description)
13+
14+
Strategy:
15+
1. Try index.json parsing (preferred, if available)
16+
2. Sum shard file sizes and convert to parameter estimate using dtype heuristics
17+
3. Fallback to config-based heuristic
18+
"""
19+
# Fallback since index.json parsing not implemented in this module
20+
# 2) sum shard bytes
21+
bytes_total = 0
22+
if siblings:
23+
joined = " ".join(siblings).lower()
24+
bytes_per_param = 2 # default
25+
if "fp16" in joined or "float16" in joined or "bf16" in joined:
26+
precision = "fp16/bf16"
27+
bytes_per_param = 2
28+
elif "fp8" in joined:
29+
precision = "fp8"
30+
bytes_per_param = 1
31+
elif "int8" in joined or "int4" in joined or "gptq" in joined:
32+
precision = "int"
33+
bytes_per_param = 1
34+
else:
35+
precision = "unknown"
36+
if config:
37+
cfg_dtype = config.get("torch_dtype") or config.get("dtype")
38+
if isinstance(cfg_dtype, str):
39+
if "16" in cfg_dtype:
40+
precision = "fp16"
41+
bytes_per_param = 2
42+
elif "8" in cfg_dtype:
43+
precision = "fp8_or_int"
44+
bytes_per_param = 1
45+
elif "32" in cfg_dtype:
46+
precision = "fp32"
47+
bytes_per_param = 4
48+
# Estimation cannot actually sum file sizes here, so returning None
49+
return None, f"shard_size_sum ({precision})"
50+
51+
# 3) fallback: config heuristics
52+
if config:
53+
try:
54+
h = config.get("hidden_size") or config.get("d_model") or 0
55+
l = config.get("num_hidden_layers") or config.get("n_layer") or 0
56+
v = config.get("vocab_size") or config.get("n_vocab") or 0
57+
if h and l:
58+
approx = v * h + l * (h * h * 12)
59+
return int(approx), "config_heuristic"
60+
except Exception:
61+
pass
62+
63+
return None, "unknown"
64+
65+
66+
def detect_quant_and_precision(
67+
repo_id: str, config: Optional[Dict], siblings: List[str], load_json_quiet=None
68+
) -> Dict[str, Any]:
69+
"""
70+
Detect quantization and precision.
71+
72+
Returns:
73+
{
74+
"quantized": bool,
75+
"quant_methods": [...],
76+
"precision": "fp16|bf16|fp8|int8|unknown"
77+
}
78+
"""
79+
result = {"quantized": False, "quant_methods": [], "precision": "unknown"}
80+
81+
# check quantization config if loader function provided
82+
if load_json_quiet:
83+
qconf = load_json_quiet(repo_id, "quantization_config.json")
84+
if qconf:
85+
result["quantized"] = True
86+
m = qconf.get("method") or qconf.get("quantization_method")
87+
result["quant_methods"].append(m or "unknown")
88+
89+
# filename hints
90+
joined = " ".join(siblings).lower() if siblings else ""
91+
for method, keywords in {
92+
"gptq": ["gptq"],
93+
"bitsandbytes": ["bnb", "bitsandbytes"],
94+
"awq": ["awq"],
95+
}.items():
96+
if any(k in joined for k in keywords):
97+
result["quantized"] = True
98+
result["quant_methods"].append(method)
99+
100+
# detect precision by filename
101+
for p, keywords in {
102+
"fp16": ["fp16", "float16"],
103+
"bf16": ["bf16"],
104+
"fp8": ["fp8"],
105+
"int8": ["int8"],
106+
"int4": ["int4"],
107+
}.items():
108+
if any(k in joined for k in keywords):
109+
result["precision"] = p
110+
break
111+
112+
# try config hints
113+
if result["precision"] == "unknown" and config:
114+
cfg_dtype = config.get("torch_dtype") or config.get("dtype") or config.get("torch_dtype_str")
115+
if isinstance(cfg_dtype, str):
116+
if "16" in cfg_dtype:
117+
result["precision"] = "fp16"
118+
elif "8" in cfg_dtype:
119+
result["precision"] = "fp8"
120+
elif "32" in cfg_dtype:
121+
result["precision"] = "fp32"
122+
123+
return result
124+
125+
126+
def analyze_tokenizer(tokenizer: Optional[Dict[str, Any]]) -> Dict[str, Any]:
127+
"""Analyze tokenizer config."""
128+
if not tokenizer:
129+
return {"present": False}
130+
131+
info = {
132+
"present": True,
133+
"type": None,
134+
"vocab_size": None,
135+
"model_max_length": None,
136+
"special_tokens": [],
137+
"truncation": None,
138+
"normalization": None,
139+
"lowercase": None,
140+
}
141+
142+
model_part = tokenizer.get("model") if isinstance(tokenizer, dict) else None
143+
if model_part:
144+
info["type"] = model_part.get("type") or model_part.get("model_type")
145+
vocab = model_part.get("vocab")
146+
if isinstance(vocab, dict):
147+
info["vocab_size"] = len(vocab)
148+
elif isinstance(vocab, list):
149+
info["vocab_size"] = len(vocab)
150+
151+
# tokenizer_config.json style keys
152+
for k in ["tokenizer_class", "model_max_length", "truncation", "do_lower_case"]:
153+
if k in tokenizer:
154+
info_key = "lowercase" if k == "do_lower_case" else k
155+
info[info_key] = tokenizer[k]
156+
157+
# special tokens
158+
at = tokenizer.get("added_tokens") or tokenizer.get("added_tokens_decoder") or tokenizer.get("special_tokens_map")
159+
if isinstance(at, dict):
160+
info["special_tokens"] = list(at.keys())
161+
elif isinstance(at, list):
162+
toks = []
163+
for t in at:
164+
if isinstance(t, dict):
165+
toks.append(t.get("content") or t.get("token"))
166+
info["special_tokens"] = toks
167+
168+
# normalizer
169+
if "normalizer" in tokenizer:
170+
info["normalization"] = tokenizer.get("normalizer")
171+
172+
return info
173+
174+
175+
def extract_architecture_extras(config: Optional[Dict[str, Any]]) -> Dict[str, Any]:
176+
"""Extract additional model config information."""
177+
extras = {}
178+
if not config:
179+
return extras
180+
181+
keys = [
182+
"intermediate_size",
183+
"hidden_dropout_prob",
184+
"attention_probs_dropout_prob",
185+
"layer_norm_eps",
186+
"activation_function",
187+
"rope_theta",
188+
"rope_scaling",
189+
"sliding_window",
190+
"use_cache",
191+
"tie_word_embeddings",
192+
"num_key_value_heads",
193+
"kv_head_dim",
194+
]
195+
196+
for k in keys:
197+
if k in config:
198+
extras[k] = config[k]
199+
200+
# layer norm type inference
201+
norm_type = None
202+
if config.get("rms_norm") or "rms" in str(config.get("norm_type", "")).lower():
203+
norm_type = "RMSNorm"
204+
elif "layer_norm" in str(config.get("norm_type", "")).lower() or "layernorm" in str(config.get("norm_type", "")).lower():
205+
norm_type = "LayerNorm"
206+
elif "layer_norm_eps" in config:
207+
norm_type = "LayerNorm"
208+
209+
if norm_type:
210+
extras["norm_type"] = norm_type
211+
212+
return extras

src/hf_model_inspector/loader.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
HfApi,
44
hf_hub_download,
55
HfFolder,
6-
RepositoryNotFoundError,
76
whoami,
87
)
98
import json
@@ -47,10 +46,11 @@ def __init__(self, token: Optional[str] = None):
4746
def fetch_model_info(self, repo_id: str) -> Optional[Dict[str, Any]]:
4847
try:
4948
info = self.api.model_info(repo_id, token=self.token)
50-
except RepositoryNotFoundError:
51-
logger.error(f"Repository not found: {repo_id}")
52-
return None
5349
except Exception as e:
50+
error_msg = str(e).lower()
51+
if "not found" in error_msg or "404" in error_msg:
52+
logger.error(f"Repository not found: {repo_id}")
53+
else:
5454
logger.warning(f"Failed to fetch model info for {repo_id}: {e}")
5555
return None
5656

test/test_analyzer.py

Whitespace-only changes.

0 commit comments

Comments
 (0)