Skip to content

Commit 8ddee79

Browse files
committed
Add test to fix the model name is empty error
1 parent f6f08db commit 8ddee79

2 files changed

Lines changed: 108 additions & 4 deletions

File tree

src/pg2_benchmark/model_card.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
1-
import frontmatter
2-
from pydantic import BaseModel, Field, ConfigDict
31
from pathlib import Path
4-
from typing import Self, Any
2+
from typing import Any, Self
3+
4+
import frontmatter
5+
from pydantic import BaseModel, ConfigDict, Field
56

67

78
class ModelCard(BaseModel):
@@ -20,7 +21,7 @@ class ModelCard(BaseModel):
2021

2122
model_config = ConfigDict(extra="allow")
2223

23-
name: str = ""
24+
name: str
2425
hyper_params: dict[str, Any] = Field(default_factory=dict)
2526

2627
@classmethod

tests/test_validate.py

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
from pathlib import Path
2+
from unittest.mock import patch
3+
4+
import pytest
5+
from typer.testing import CliRunner
6+
7+
from pg2_benchmark.cli.validate import ModelPath, validate_app
8+
9+
10+
@pytest.fixture
11+
def runner():
12+
"""Test runner for CLI commands."""
13+
return CliRunner()
14+
15+
16+
@pytest.fixture
17+
def valid_model_card_content() -> str:
18+
"""Valid model card content for testing."""
19+
return """---
20+
name: "test_model"
21+
hyper_params:
22+
learning_rate: 0.001
23+
batch_size: 32
24+
---
25+
26+
# Test Model
27+
This is a test model card.
28+
"""
29+
30+
31+
@pytest.fixture
32+
def invalid_model_card_content() -> str:
33+
"""Invalid model card content for testing."""
34+
return """---
35+
invalid_yaml: [
36+
---
37+
38+
# Invalid Model Card
39+
"""
40+
41+
42+
def test_model_card_validation_success(
43+
tmp_path: Path, valid_model_card_content: str, runner: CliRunner
44+
):
45+
"""Test successful model card validation."""
46+
model_name = "test_model"
47+
model_dir = tmp_path / "models" / model_name
48+
model_dir.mkdir(parents=True)
49+
50+
model_card_file = model_dir / "README.md"
51+
model_card_file.write_text(valid_model_card_content)
52+
53+
with patch.object(ModelPath, "ROOT_PATH", tmp_path / "models"):
54+
result = runner.invoke(validate_app, [model_name])
55+
56+
assert result.exit_code == 0
57+
assert "✅ Loaded test_model" in result.stdout
58+
assert "learning_rate" in result.stdout and "batch_size" in result.stdout
59+
60+
61+
def test_model_card_validation_missing_file(tmp_path: Path, runner: CliRunner):
62+
"""Test validation when model card file doesn't exist."""
63+
model_name = "nonexistent_model"
64+
65+
with patch.object(ModelPath, "ROOT_PATH", tmp_path / "models"):
66+
result = runner.invoke(validate_app, [model_name])
67+
68+
assert result.exit_code == 1
69+
assert "❌ Model nonexistent_model does not have a model card" in result.stdout
70+
71+
72+
def test_model_card_validation_invalid_content(
73+
tmp_path: Path, invalid_model_card_content: str, runner: CliRunner
74+
):
75+
"""Test validation with invalid model card content."""
76+
model_name = "invalid_model"
77+
model_dir = tmp_path / "models" / model_name
78+
model_dir.mkdir(parents=True)
79+
80+
model_card_file = model_dir / "README.md"
81+
model_card_file.write_text(invalid_model_card_content)
82+
83+
with patch.object(ModelPath, "ROOT_PATH", tmp_path / "models"):
84+
result = runner.invoke(validate_app, [model_name])
85+
86+
assert result.exit_code == 1
87+
assert "❌ Error loading model card" in result.stdout
88+
89+
90+
def test_model_card_validation_empty_file(tmp_path: Path, runner: CliRunner):
91+
"""Test validation with empty model card file."""
92+
model_name = "empty_model"
93+
model_dir = tmp_path / "models" / model_name
94+
model_dir.mkdir(parents=True)
95+
96+
model_card_file = model_dir / "README.md"
97+
model_card_file.write_text("")
98+
99+
with patch.object(ModelPath, "ROOT_PATH", tmp_path / "models"):
100+
result = runner.invoke(validate_app, [model_name])
101+
102+
assert result.exit_code == 1
103+
assert "❌ Error loading model card" in result.stdout

0 commit comments

Comments
 (0)