Skip to content

Commit d14bc6e

Browse files
committed
Refactor models.py into a folder
1 parent 2d06919 commit d14bc6e

File tree

9 files changed

+753
-635
lines changed

9 files changed

+753
-635
lines changed

src/commands/models.py

Lines changed: 0 additions & 635 deletions
This file was deleted.

src/commands/models/__init__.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
"""
2+
Models Command Group
3+
4+
Model management commands for alex-treBENCH.
5+
"""
6+
7+
import click
8+
from .list import models_list
9+
from .search import models_search
10+
from .info import models_info
11+
from .refresh import models_refresh
12+
from .cache import models_cache
13+
from .test import models_test
14+
from .costs import models_costs
15+
16+
17+
@click.group()
18+
def models():
19+
"""Model management commands."""
20+
pass
21+
22+
23+
# Register all subcommands
24+
models.add_command(models_list, name='list')
25+
models.add_command(models_search, name='search')
26+
models.add_command(models_info, name='info')
27+
models.add_command(models_refresh, name='refresh')
28+
models.add_command(models_cache, name='cache')
29+
models.add_command(models_test, name='test')
30+
models.add_command(models_costs, name='costs')

src/commands/models/cache.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
"""
2+
Models Cache Command
3+
4+
Manage model cache.
5+
"""
6+
7+
import click
8+
from rich.console import Console
9+
from rich.table import Table
10+
11+
from src.utils.logging import get_logger
12+
13+
console = Console()
14+
logger = get_logger(__name__)
15+
16+
17+
@click.command()
18+
@click.option('--clear', is_flag=True, help='Clear the model cache')
19+
@click.option('--info', is_flag=True, help='Show detailed cache information', default=True)
20+
@click.pass_context
21+
def models_cache(ctx, clear, info):
22+
"""Manage model cache."""
23+
try:
24+
from src.models.model_cache import get_model_cache
25+
26+
cache = get_model_cache()
27+
28+
if clear:
29+
if cache.clear_cache():
30+
console.print("[green]✓ Model cache cleared[/green]")
31+
else:
32+
console.print("[red]✗ Failed to clear cache[/red]")
33+
return
34+
35+
if info:
36+
cache_info = cache.get_cache_info()
37+
38+
# Cache status table
39+
status_table = Table(title="Model Cache Status")
40+
status_table.add_column("Property", style="cyan")
41+
status_table.add_column("Value", style="green")
42+
43+
status_table.add_row("Cache Path", cache_info['cache_path'])
44+
status_table.add_row("Exists", "✓ Yes" if cache_info['exists'] else "✗ No")
45+
status_table.add_row("Valid", "✓ Yes" if cache_info['valid'] else "✗ No")
46+
status_table.add_row("TTL", f"{cache_info['ttl_seconds']} seconds")
47+
48+
if cache_info['exists']:
49+
status_table.add_row("Size", f"{cache_info['size_bytes']:,} bytes")
50+
status_table.add_row("Model Count", str(cache_info['model_count']))
51+
52+
if cache_info['cached_at']:
53+
status_table.add_row("Cached At", cache_info['cached_at'])
54+
55+
if cache_info['age_seconds'] is not None:
56+
age_mins = cache_info['age_seconds'] / 60
57+
age_hours = age_mins / 60
58+
if age_hours > 1:
59+
age_str = f"{age_hours:.1f} hours"
60+
else:
61+
age_str = f"{age_mins:.1f} minutes"
62+
status_table.add_row("Age", age_str)
63+
64+
console.print(status_table)
65+
66+
# Cache recommendations
67+
if not cache_info['exists']:
68+
console.print("\n[yellow]💡 Run 'models refresh' to populate the cache[/yellow]")
69+
elif not cache_info['valid']:
70+
console.print("\n[yellow]💡 Cache has expired. Run 'models refresh' to update[/yellow]")
71+
else:
72+
console.print("\n[green]💡 Cache is up to date[/green]")
73+
74+
except Exception as e:
75+
console.print(f"[red]Error managing cache: {str(e)}[/red]")
76+
logger.exception("Cache management failed")

src/commands/models/costs.py

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
"""
2+
Models Costs Command
3+
4+
Estimate costs for running benchmarks with a model.
5+
"""
6+
7+
import asyncio
8+
import click
9+
from rich.console import Console
10+
from rich.table import Table
11+
12+
from src.utils.logging import get_logger
13+
14+
console = Console()
15+
logger = get_logger(__name__)
16+
17+
18+
@click.command()
19+
@click.option('--model', '-m', required=True, help='Model ID to estimate costs for')
20+
@click.option('--questions', '-q', type=int, default=100, help='Number of questions')
21+
@click.option('--input-tokens', type=int, help='Average input tokens per question')
22+
@click.option('--output-tokens', type=int, help='Average output tokens per question')
23+
@click.pass_context
24+
def models_costs(ctx, model, questions, input_tokens, output_tokens):
25+
"""Estimate costs for running benchmarks with a model."""
26+
27+
async def calculate_costs_async():
28+
try:
29+
from src.models.model_registry import model_registry
30+
from src.models.cost_calculator import CostCalculator
31+
32+
# Validate model using dynamic system
33+
models = await model_registry.get_available_models()
34+
model_info = None
35+
36+
for m in models:
37+
if m.get('id', '').lower() == model.lower():
38+
model_info = m
39+
break
40+
41+
if not model_info:
42+
console.print(f"[red]Model not found: {model}[/red]")
43+
console.print("[dim]Use 'models list' or 'models search' to find available models[/dim]")
44+
return
45+
46+
# Use defaults if not specified - fix variable scoping
47+
default_input_tokens = 100
48+
default_output_tokens = 50
49+
50+
config = ctx.obj.get('config') if ctx.obj else None
51+
if config and hasattr(config, 'costs') and hasattr(config.costs, 'estimation'):
52+
try:
53+
default_input_tokens = getattr(config.costs.estimation, 'default_input_tokens_per_question', 100)
54+
default_output_tokens = getattr(config.costs.estimation, 'default_output_tokens_per_question', 50)
55+
except AttributeError:
56+
pass # Use defaults
57+
58+
# Apply the values - use different variable names to avoid shadowing
59+
actual_input_tokens = input_tokens if input_tokens is not None else default_input_tokens
60+
actual_output_tokens = output_tokens if output_tokens is not None else default_output_tokens
61+
62+
# Calculate costs using the proper ModelRegistry method
63+
total_input_tokens = questions * actual_input_tokens
64+
total_output_tokens = questions * actual_output_tokens
65+
total_tokens = total_input_tokens + total_output_tokens
66+
67+
# Use ModelRegistry.estimate_cost for proper cost calculation
68+
from src.models.model_registry import ModelRegistry
69+
total_cost = ModelRegistry.estimate_cost(model, total_input_tokens, total_output_tokens)
70+
input_cost = ModelRegistry.estimate_cost(model, total_input_tokens, 0)
71+
output_cost = ModelRegistry.estimate_cost(model, 0, total_output_tokens)
72+
cost_per_question = total_cost / questions if questions > 0 else 0
73+
74+
# Get pricing information for display purposes
75+
pricing = model_info.get('pricing', {})
76+
input_cost_per_1m = pricing.get('input_cost_per_1m_tokens', 0)
77+
output_cost_per_1m = pricing.get('output_cost_per_1m_tokens', 0)
78+
79+
# If not found in dynamic model info, try static config
80+
if input_cost_per_1m == 0 and output_cost_per_1m == 0:
81+
static_config = ModelRegistry.get_model_config(model)
82+
if static_config:
83+
input_cost_per_1m = static_config.input_cost_per_1m_tokens
84+
output_cost_per_1m = static_config.output_cost_per_1m_tokens
85+
86+
# Display estimate
87+
table = Table(title=f"Cost Estimate: {model_info.get('name', model)}")
88+
table.add_column("Parameter", style="cyan")
89+
table.add_column("Value", style="green")
90+
91+
table.add_row("Model ID", model)
92+
table.add_row("Model Name", model_info.get('name', 'N/A'))
93+
table.add_row("Provider", (model_info.get('provider', 'Unknown')).title())
94+
table.add_row("Questions", f"{questions:,}")
95+
table.add_row("Input Tokens per Question", f"{actual_input_tokens:,}")
96+
table.add_row("Output Tokens per Question", f"{actual_output_tokens:,}")
97+
table.add_row("Total Input Tokens", f"{total_input_tokens:,}")
98+
table.add_row("Total Output Tokens", f"{total_output_tokens:,}")
99+
table.add_row("Total Tokens", f"{total_tokens:,}")
100+
table.add_row("Input Cost", f"${input_cost:.6f}")
101+
table.add_row("Output Cost", f"${output_cost:.6f}")
102+
table.add_row("Total Cost", f"${total_cost:.4f}")
103+
table.add_row("Cost per Question", f"${cost_per_question:.6f}")
104+
105+
console.print(table)
106+
107+
# Add context about pricing
108+
if input_cost_per_1m == 0 and output_cost_per_1m == 0:
109+
console.print("\n[yellow]⚠️ No pricing information available for this model[/yellow]")
110+
else:
111+
console.print(f"\n[dim]Based on: ${input_cost_per_1m:.2f}/${output_cost_per_1m:.2f} per 1M input/output tokens[/dim]")
112+
113+
except Exception as e:
114+
console.print(f"[red]Error calculating costs: {str(e)}[/red]")
115+
logger.exception("Cost calculation failed")
116+
117+
asyncio.run(calculate_costs_async())

src/commands/models/info.py

Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
"""
2+
Models Info Command
3+
4+
Show detailed information about a specific model.
5+
"""
6+
7+
import asyncio
8+
import click
9+
from rich.console import Console
10+
from rich.panel import Panel
11+
from rich.table import Table
12+
13+
from src.utils.logging import get_logger
14+
15+
console = Console()
16+
logger = get_logger(__name__)
17+
18+
19+
@click.command()
20+
@click.argument('model_id', required=True)
21+
@click.pass_context
22+
def models_info(ctx, model_id):
23+
"""Show detailed information about a specific model."""
24+
25+
async def show_model_info_async():
26+
try:
27+
from src.models.model_registry import model_registry
28+
29+
console.print(f"[blue]Getting information for model: {model_id}[/blue]")
30+
31+
# Get all models and find the specific one
32+
models = await model_registry.get_available_models()
33+
model_info = None
34+
35+
for model in models:
36+
if model.get('id', '').lower() == model_id.lower():
37+
model_info = model
38+
break
39+
40+
if not model_info:
41+
console.print(f"[red]Model not found: {model_id}[/red]")
42+
console.print("[dim]Use 'models list' or 'models search' to find available models[/dim]")
43+
44+
# Show similar models
45+
similar = model_registry.search_models(model_id.split('/')[-1], models)[:5]
46+
if similar:
47+
console.print(f"\n[yellow]Similar models:[/yellow]")
48+
for sim in similar:
49+
console.print(f" • {sim.get('id', 'N/A')}")
50+
return
51+
52+
# Display detailed information
53+
console.print(Panel.fit(
54+
f"[bold blue]{model_info.get('name', 'N/A')}[/bold blue]\n"
55+
f"[dim]{model_info.get('description', 'No description available')}[/dim]",
56+
title="Model Information",
57+
border_style="blue"
58+
))
59+
60+
# Basic details table
61+
details_table = Table(title="Model Details")
62+
details_table.add_column("Property", style="cyan")
63+
details_table.add_column("Value", style="green")
64+
65+
details_table.add_row("Model ID", model_info.get('id', 'N/A'))
66+
details_table.add_row("Provider", (model_info.get('provider', 'Unknown')).title())
67+
details_table.add_row("Context Length", f"{model_info.get('context_length', 0):,} tokens")
68+
details_table.add_row("Available", "✓ Yes" if model_info.get('available', True) else "✗ No")
69+
details_table.add_row("Modality", (model_info.get('modality', 'text')).title())
70+
71+
# Add architecture info if available
72+
architecture = model_info.get('architecture', {})
73+
if architecture:
74+
if 'tokenizer' in architecture:
75+
details_table.add_row("Tokenizer", architecture['tokenizer'])
76+
if 'instruct_type' in architecture:
77+
details_table.add_row("Instruction Type", architecture['instruct_type'])
78+
79+
console.print(details_table)
80+
81+
# Pricing table
82+
pricing = model_info.get('pricing', {})
83+
if pricing:
84+
pricing_table = Table(title="Pricing Information")
85+
pricing_table.add_column("Type", style="cyan")
86+
pricing_table.add_column("Cost per 1M tokens", style="yellow")
87+
88+
input_cost = pricing.get('input_cost_per_1m_tokens', 0)
89+
output_cost = pricing.get('output_cost_per_1m_tokens', 0)
90+
91+
# Format costs properly, handling scientific notation
92+
def format_cost(cost):
93+
if cost == 0:
94+
return "$0"
95+
# Check if values are already per-million-tokens (larger values) or per-token (very small values)
96+
if cost < 0.01:
97+
# Values are per-token, convert to per-million-tokens
98+
price_per_million = cost * 1_000_000
99+
else:
100+
# Values are already per-million-tokens
101+
price_per_million = cost
102+
103+
if price_per_million < 0.01:
104+
# For very small values, show more decimal places
105+
return f"${price_per_million:.4f}"
106+
elif price_per_million < 1:
107+
return f"${price_per_million:.2f}"
108+
else:
109+
return f"${price_per_million:.0f}"
110+
111+
pricing_table.add_row("Input", format_cost(input_cost))
112+
pricing_table.add_row("Output", format_cost(output_cost))
113+
pricing_table.add_row("Combined", format_cost(input_cost + output_cost))
114+
115+
console.print(pricing_table)
116+
117+
# Top provider info
118+
top_provider = model_info.get('top_provider', {})
119+
if top_provider:
120+
console.print(f"\n[bold]Top Provider:[/bold]")
121+
console.print(f"• Max completion tokens: {top_provider.get('max_completion_tokens', 'N/A')}")
122+
console.print(f"• Max throughput: {top_provider.get('max_throughput_tokens_per_minute', 'N/A')} tokens/min")
123+
124+
# Per-request limits
125+
limits = model_info.get('per_request_limits', {})
126+
if limits:
127+
console.print(f"\n[bold]Request Limits:[/bold]");
128+
if 'prompt_tokens' in limits:
129+
console.print(f"• Max prompt tokens: {limits['prompt_tokens']:,}")
130+
if 'completion_tokens' in limits:
131+
console.print(f"• Max completion tokens: {limits['completion_tokens']:,}")
132+
133+
except Exception as e:
134+
console.print(f"[red]Error getting model info: {str(e)}[/red]")
135+
logger.exception("Model info retrieval failed")
136+
137+
asyncio.run(show_model_info_async())

0 commit comments

Comments
 (0)