Skip to content

Commit ba00c75

Browse files
committed
Add setup_logger in main entrypoint
1 parent f16e12d commit ba00c75

File tree

7 files changed

+424
-58
lines changed

7 files changed

+424
-58
lines changed

pyproject.toml

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
[project]
22
name = "pg2-benchmark"
3-
version = "0.1.0"
43
description = "Add your description here"
54
readme = "README.md"
65
authors = [
@@ -10,6 +9,7 @@ authors = [
109
{ name = "Henning Redestig", email = "Henning.Redestig@iff.com" },
1110
{ name = "Karel van der Weg", email = "karel.weg@iff.com" },
1211
]
12+
dynamic = ["version"] # Dynamically set hatch version using git tags
1313
requires-python = ">=3.12"
1414
dependencies = [
1515
"typer>=0.15.2",
@@ -28,18 +28,21 @@ pg2-benchmark = "pg2_benchmark.__main__:app"
2828
requires = ["hatchling"]
2929
build-backend = "hatchling.build"
3030

31+
[tool.hatch.version]
32+
path = "src/pg2_benchmark/__about__.py"
33+
34+
[tool.hatch.build.targets.wheel]
35+
packages = ["src/pg2_benchmark"]
36+
3137
[tool.uv.sources]
3238
pg2-dataset = { git = "https://github.com/ProteinGym2/pg2-dataset.git", rev = "58c327e13bade1effe1312eb2b8d5445016a5a8f" }
3339

3440
[tool.ruff]
3541
line-length = 88
3642

37-
[tool.click]
38-
logging_level = "INFO"
39-
logging_format = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
40-
4143
[dependency-groups]
4244
dev = [
45+
"hatch>=1.14.1",
4346
"isort>=6.0.1",
4447
"pre-commit>=4.2.0",
4548
"pytest>=8.4.1",

src/pg2_benchmark/__about__.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
"""Metadata file to store version used by hatch.
2+
3+
> BE CAREFUL WITH MANUALLY CHANGING THIS FILE!
4+
5+
References:
6+
https://hatch.pypa.io/latest/version/
7+
"""
8+
9+
__version__ = "0.1.0b1"

src/pg2_benchmark/__main__.py

Lines changed: 58 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,10 @@
44

55
import typer
66

7+
from pg2_benchmark.__about__ import __version__
78
from pg2_benchmark.cli.dataset import dataset_app
89
from pg2_benchmark.cli.metric import metric_app
910
from pg2_benchmark.cli.sagemaker import sagemaker_app
10-
from pg2_benchmark.logger import setup_logger
1111
from pg2_benchmark.model_card import ModelCard
1212

1313
app = typer.Typer(
@@ -20,9 +20,6 @@
2020
app.add_typer(metric_app, name="metric", help="Metric operations")
2121
app.add_typer(sagemaker_app, name="sagemaker", help="SageMaker operations")
2222

23-
setup_logger()
24-
logger = logging.getLogger("pg2_benchmark")
25-
2623

2724
class ModelPath:
2825
"""Configuration class for model-related file paths.
@@ -35,28 +32,75 @@ class ModelPath:
3532
"""Default location for model card files relative to model root directory."""
3633

3734

35+
def setup_logger(*, level: int = logging.CRITICAL) -> None:
36+
"""Set up the logger for the application.
37+
38+
Args:
39+
log_level (int): The logging level to set. Defaults to `logging.CRITICAL`.
40+
"""
41+
logger = logging.getLogger("pg2_benchmark")
42+
logger.setLevel(level)
43+
44+
stream_handler = logging.StreamHandler()
45+
stream_handler.setLevel(level)
46+
47+
formatter = logging.Formatter(
48+
"%(asctime)s - %(name)s - %(levelname)s - %(message)s"
49+
)
50+
stream_handler.setFormatter(formatter)
51+
logger.addHandler(stream_handler)
52+
53+
54+
@app.callback(invoke_without_command=True)
55+
def main(
56+
ctx: typer.Context,
57+
verbose: Annotated[int, typer.Option("--verbose", "-v", count=True)] = 3,
58+
version: Annotated[
59+
bool, typer.Option("--version", help="Show version and exit")
60+
] = False,
61+
) -> None:
62+
"""Main entry point for the CLI.
63+
64+
Args:
65+
ctx (typer.Context): The context for the CLI.
66+
verbose (int): The verbosity level. Use `-v` or `--verbose` to increase
67+
verbosity. Each `-v` increases the verbosity level:
68+
0: CRITICAL, 1: ERROR, 2: WARNING, 3: INFO, 4: DEBUG.
69+
Defaults to 3 (INFO).
70+
version (bool): If `True`, show the package version. Defaults to `False`.
71+
72+
Raises:
73+
typer.Exit: If version is `True`, exits after showing the version.
74+
"""
75+
setup_logger(level=logging.CRITICAL - verbose * 10)
76+
77+
if version:
78+
typer.echo(f"v{__version__}")
79+
raise typer.Exit()
80+
81+
if not ctx.invoked_subcommand:
82+
typer.echo("Welcome to the PG2 Dataset CLI!")
83+
typer.echo("Use --help to see available commands.")
84+
85+
3886
@app.command()
3987
def validate(
40-
model_name: Annotated[
41-
str, typer.Argument(help="Model name defined in the model card")
42-
],
4388
model_path: Annotated[
4489
Path,
4590
typer.Argument(
4691
help="Root path containting the model source code and model card",
4792
exists=True,
4893
dir_okay=True,
49-
file_okay=False,
94+
file_okay=True,
5095
),
5196
],
5297
):
53-
model_card_path = model_path / ModelPath.MODEL_CARD_PATH
98+
logger = logging.getLogger("pg2_benchmark")
5499

55-
if not model_card_path.exists():
56-
logger.error(
57-
f"❌ Model {model_name} does not have a model card at {str(model_card_path)}"
58-
)
59-
raise typer.Exit(1)
100+
if model_path.is_dir():
101+
model_card_path = model_path / ModelPath.MODEL_CARD_PATH
102+
else:
103+
model_card_path = model_path
60104

61105
try:
62106
model_card = ModelCard.from_path(model_card_path)

src/pg2_benchmark/cli/metric.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,6 @@
55
from pycm import ConfusionMatrix
66
from scipy.stats import spearmanr
77

8-
from pg2_benchmark.logger import setup_logger
9-
10-
setup_logger()
11-
logger = logging.getLogger("pg2_benchmark")
12-
138
metric_app = typer.Typer()
149

1510

@@ -18,6 +13,7 @@ def calc(
1813
output_path: str = typer.Option(help="Path to the model output file"),
1914
metric_path: str = typer.Option(help="Path to the metric output file"),
2015
):
16+
logger = logging.getLogger("pg2_benchmark")
2117
logger.info("Calculating metrics...")
2218

2319
df = pd.read_csv(output_path)

src/pg2_benchmark/logger.py

Lines changed: 0 additions & 24 deletions
This file was deleted.

tests/test_validate.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -49,23 +49,22 @@ def test_model_card_validation_success(
4949
model_card_file = model_dir / ModelPath.MODEL_CARD_PATH
5050
model_card_file.write_text(valid_model_card_content)
5151

52-
result = runner.invoke(app, ["validate", model_name, str(model_dir)])
52+
result = runner.invoke(app, ["validate", str(model_dir)])
5353

5454
assert result.exit_code == 0
5555
assert "✅ Loaded test_model" in caplog.text
5656
assert "learning_rate" in caplog.text and "batch_size" in caplog.text
5757

5858

59-
def test_model_card_validation_missing_file(tmp_path: Path, runner: CliRunner, caplog):
59+
def test_model_card_validation_missing_file(runner: CliRunner):
6060
"""Test validation when model card file doesn't exist."""
6161
model_name = "nonexistent_model"
62-
model_dir = tmp_path / "models" / model_name
63-
model_dir.mkdir(parents=True)
6462

65-
result = runner.invoke(app, ["validate", model_name, str(model_dir)])
63+
result = runner.invoke(app, ["validate", str(model_name)])
6664

67-
assert result.exit_code == 1
68-
assert "❌ Model nonexistent_model does not have a model card" in caplog.text
65+
# Typer returns exit code 2 for parameter validation errors
66+
assert result.exit_code == 2
67+
assert "does not exist" in result.output
6968

7069

7170
def test_model_card_validation_invalid_content(
@@ -79,7 +78,7 @@ def test_model_card_validation_invalid_content(
7978
model_card_file = model_dir / ModelPath.MODEL_CARD_PATH
8079
model_card_file.write_text(invalid_model_card_content)
8180

82-
result = runner.invoke(app, ["validate", model_name, str(model_dir)])
81+
result = runner.invoke(app, ["validate", str(model_dir)])
8382

8483
assert result.exit_code == 1
8584
assert "❌ Error loading model card" in caplog.text
@@ -94,7 +93,7 @@ def test_model_card_validation_empty_file(tmp_path: Path, runner: CliRunner, cap
9493
model_card_file = model_dir / ModelPath.MODEL_CARD_PATH
9594
model_card_file.write_text("")
9695

97-
result = runner.invoke(app, ["validate", model_name, str(model_dir)])
96+
result = runner.invoke(app, ["validate", str(model_dir)])
9897

9998
assert result.exit_code == 1
10099
assert "❌ Error loading model card" in caplog.text

0 commit comments

Comments
 (0)