-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathvalidate.py
More file actions
147 lines (127 loc) · 5.45 KB
/
validate.py
File metadata and controls
147 lines (127 loc) · 5.45 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
"""
WayInfer validation script.
Usage:
python validate.py --model <path.gguf> --prompt "What is 2+2?" --max-tokens 20
Handles tokenization, runs the native engine, decodes output, optionally compares with reference.
Requires: pip install llama-cpp-python (used for tokenization; model weights are loaded only when --reference is given).
"""
import argparse, struct, subprocess, sys, time, os
# Scratch file used to hand the prompt token IDs to the native engine.
_IDS_FILE = '_validate_ids.bin'
# Upper bound, in seconds, on a single engine run.
_ENGINE_TIMEOUT_S = 1200


def _format_prompt(model_path, prompt):
    """Wrap the prompt in [INST] tags when the model path mentions mistral/mixtral."""
    model_lower = model_path.lower()
    if 'mixtral' in model_lower or 'mistral' in model_lower:
        return f'[INST] {prompt} [/INST]'
    return prompt


def _write_ids(path, ids):
    """Write the token count followed by every token ID as little-endian uint32."""
    with open(path, 'wb') as f:
        f.write(struct.pack(f'<{len(ids) + 1}I', len(ids), *ids))


def _parse_tokens(stdout):
    """Return the integer IDs found on the engine's 'TOKEN:<id>' stdout lines."""
    tokens = []
    for line in stdout.splitlines():
        line = line.strip()
        if not line.startswith('TOKEN:'):
            continue
        try:
            tokens.append(int(line.split(':')[1]))
        except ValueError:
            # One malformed token line should not abort the whole report.
            continue
    return tokens


def _parse_engine_stderr(stderr):
    """Return (load_time, split_lines, warning_lines) scraped from engine stderr."""
    load_time = ''
    splits_info = []
    warnings = []
    for line in stderr.split('\n'):
        if 'Loading GGUF' in line:
            # Keep whatever follows the last '...' on the loading line.
            load_time = line.strip().split('...')[-1].strip()
        if 'split' in line and 'mapped' in line:
            splits_info.append(line.strip())
        if 'WARNING' in line:
            warnings.append(line.strip())
    return load_time, splits_info, warnings


def _strip_tags(text):
    """Trim whitespace and drop bracketed tags such as [RESP], [SRC], [AGENT]."""
    import re
    return re.sub(r'\[.*?\]\s*', '', text.strip())


def main():
    """Run the native engine on a prompt and report its decoded output.

    Parses command-line arguments, tokenizes the prompt with a vocab-only
    llama-cpp-python instance, writes the token IDs to a scratch file, runs
    the engine in greedy mode, and prints the decoded text plus timing.
    With --reference, also runs a full llama-cpp-python chat completion and
    prints whether one output is a prefix of the other.

    Exits with status 1 when the model or engine file is missing, when
    llama-cpp-python cannot be imported, or when the engine times out.
    """
    parser = argparse.ArgumentParser(description='WayInfer validation')
    parser.add_argument('--model', required=True, help='Path to GGUF model file')
    parser.add_argument('--prompt', default='What is 2+2?', help='Prompt to test')
    parser.add_argument('--max-tokens', type=int, default=20, help='Max tokens to generate')
    parser.add_argument('--reference', action='store_true', help='Also run llama-cpp-python reference (slow)')
    script_dir = os.path.dirname(os.path.abspath(__file__))
    default_engine = os.path.join(script_dir, 'build', 'wayinfer.exe')
    parser.add_argument('--engine', default=default_engine, help='Path to native engine')
    args = parser.parse_args()

    if not os.path.exists(args.model):
        print(f'Error: model not found: {args.model}')
        sys.exit(1)
    if not os.path.exists(args.engine):
        print(f'Error: engine not found: {args.engine}')
        print('Run build.cmd first')
        sys.exit(1)

    # Imported late so the path checks above work without the dependency.
    try:
        from llama_cpp import Llama
    except ImportError:
        print('Error: llama-cpp-python is required (pip install llama-cpp-python)')
        sys.exit(1)

    print(f'Model: {os.path.basename(args.model)}')
    print(f'Prompt: {args.prompt}')
    print()

    # Step 1: Fast tokenizer load (vocab only)
    print('Loading tokenizer...', end=' ', flush=True)
    t0 = time.time()
    tok = Llama(model_path=args.model, n_ctx=1, n_gpu_layers=0, verbose=False, vocab_only=True)
    print(f'{time.time()-t0:.1f}s')

    # Step 2: Format prompt + tokenize
    formatted = _format_prompt(args.model, args.prompt)
    ids = tok.tokenize(formatted.encode(), add_bos=True, special=True)
    bos = tok.token_bos()
    # Guard the empty case so ids[0] cannot raise IndexError.
    if not ids or ids[0] != bos:
        ids = [bos] + ids
    print(f'Tokens: {len(ids)}')

    rule = '=' * 50
    ids_file = _IDS_FILE
    try:
        # Step 3: Write token IDs
        _write_ids(ids_file, ids)

        # Step 4: Run native engine
        print(f'\n{rule}')
        print(' WayInfer')
        print(rule)
        t0 = time.time()
        try:
            result = subprocess.run(
                [args.engine, '--model', args.model,
                 '--ids-file', ids_file, '--greedy',
                 '--max-tokens', str(args.max_tokens)],
                # errors='replace' keeps undecodable engine output from raising.
                capture_output=True, text=True, errors='replace',
                timeout=_ENGINE_TIMEOUT_S,
            )
        except subprocess.TimeoutExpired:
            print(f'Error: engine timed out after {_ENGINE_TIMEOUT_S}s')
            sys.exit(1)
        elapsed = time.time() - t0
    finally:
        # Best-effort cleanup that also runs on timeout or any other failure.
        try:
            os.remove(ids_file)
        except OSError:
            pass

    if result.returncode != 0:
        print(f' Engine exited with code {result.returncode}')

    # Parse tokens and decode, dropping a trailing EOS token
    our_tokens = _parse_tokens(result.stdout)
    eos = tok.token_eos()
    if our_tokens and our_tokens[-1] == eos:
        our_tokens = our_tokens[:-1]
    if our_tokens:
        our_text = tok.detokenize(our_tokens).decode('utf-8', errors='replace')
    else:
        our_text = '(no output)'

    # Parse engine info from stderr
    load_time, splits_info, warnings = _parse_engine_stderr(result.stderr)
    print(f' Load: {load_time}')
    for s in splits_info:
        print(f' {s}')
    if warnings:
        print(f' Warnings: {len(warnings)} layers with missing tensors')
        if len(warnings) > 5:
            print(' (model may use unsupported tensor names or quant format)')
    print(f' Output: {our_text}')
    # max() avoids dividing by a near-zero elapsed time.
    print(f' Tokens: {len(our_tokens)} in {elapsed:.1f}s ({len(our_tokens)/max(elapsed,0.01):.2f} tok/s)')

    # Step 5: Reference (optional, slow - loads full model)
    if args.reference:
        print(f'\n{rule}')
        print(' Reference (llama-cpp-python)')
        print(rule)
        print(' Loading full model (this is slow)...', flush=True)
        t0 = time.time()
        ref_llm = Llama(model_path=args.model, n_ctx=512, n_gpu_layers=0, verbose=False)
        ref = ref_llm.create_chat_completion(
            messages=[{'role': 'user', 'content': args.prompt}],
            max_tokens=args.max_tokens, temperature=0
        )
        # NOTE(review): content may be None for some responses - confirm; '' keeps .strip() safe.
        ref_text = ref['choices'][0]['message']['content'] or ''
        ref_elapsed = time.time() - t0
        print(f' Output: {ref_text}')
        print(f' Time: {ref_elapsed:.1f}s')

        # Compare: a match means one cleaned output is a prefix of the other
        print(f'\n{rule}')
        our_content = _strip_tags(our_text)
        ref_content = _strip_tags(ref_text)
        if our_content and ref_content and (
                ref_content.startswith(our_content) or our_content.startswith(ref_content)):
            print(' MATCH (content aligns)')
        else:
            print(f' Ours: {repr(our_content[:80])}')
            print(f' Ref: {repr(ref_content[:80])}')
# Run the validation only when this file is executed as a script, not when imported.
if __name__ == '__main__':
    main()