Skip to content

Commit 86979bd

Browse files
committed
feat: Implement CLI
- Add single CLI entrypoint `predict` with option for calibration and fine-tuning
1 parent b3df454 commit 86979bd

2 files changed

Lines changed: 119 additions & 23 deletions

File tree

deeplc/__main__.py

Lines changed: 115 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,127 @@
11
"""Main command line interface to DeepLC."""
22

33
import logging
4-
import sys
4+
from pathlib import Path
55

6-
LOGGER = logging.getLogger(__name__)
6+
import click
7+
import pandas as pd
8+
from psm_utils.io import READERS, read_file
9+
from rich.logging import RichHandler
10+
from rich.traceback import install as install_rich_traceback
711

8-
# TODO: Add CLI functionality
12+
import deeplc.core
13+
from deeplc import __version__
914

15+
logger = logging.getLogger(__name__)
1016

11-
def _setup_logging(passed_level):
12-
log_mapping = {
13-
"critical": logging.CRITICAL,
14-
"error": logging.ERROR,
15-
"warning": logging.WARNING,
16-
"info": logging.INFO,
17-
"debug": logging.DEBUG,
18-
}
17+
LOGGING_LEVELS = {
18+
"DEBUG": logging.DEBUG,
19+
"INFO": logging.INFO,
20+
"WARNING": logging.WARNING,
21+
"ERROR": logging.ERROR,
22+
"CRITICAL": logging.CRITICAL,
23+
}
1924

20-
if passed_level.lower() not in log_mapping:
21-
print(
22-
"Invalid log level. Should be one of the following: ",
23-
", ".join(log_mapping.keys()),
24-
)
25-
exit(1)
25+
PSM_FILETYPES = list(READERS.keys())
2626

27+
28+
def _infer_output_name(input_filename: str, output_name: str | None = None) -> Path:
29+
"""Infer output filename from input filename if not provided."""
30+
if output_name:
31+
return Path(output_name)
32+
else:
33+
input_path = Path(input_filename)
34+
return input_path.with_name(input_path.stem + "_deeplc_predictions").with_suffix(".csv")
35+
36+
37+
def _read_psm_file(psms: str, psm_filetype: str | None = None):
38+
"""Read a PSM file and return a PSMList."""
39+
logger.info(f"Reading PSM file: {psms}")
40+
kwargs = {"filetype": psm_filetype} if psm_filetype else {}
41+
return read_file(psms, **kwargs)
42+
43+
44+
def _write_predictions(psm_list, predictions, output_path: Path):
45+
"""Write predictions to a CSV file."""
46+
df = pd.DataFrame(
47+
{
48+
"peptidoform": [str(psm.peptidoform) for psm in psm_list],
49+
"observed_rt": [psm.retention_time for psm in psm_list],
50+
"predicted_rt": predictions,
51+
}
52+
)
53+
logger.info(f"Writing predictions to {output_path}")
54+
df.to_csv(output_path, index=False)
55+
56+
57+
@click.group()
58+
@click.option(
59+
"--logging-level",
60+
"-l",
61+
type=click.Choice(LOGGING_LEVELS.keys()),
62+
default="INFO",
63+
help="Set the logging level.",
64+
)
65+
@click.version_option(version=__version__)
66+
def cli(logging_level, **kwargs):
67+
"""DeepLC: Retention time prediction for peptides carrying any modification."""
68+
install_rich_traceback(show_locals=True)
2769
logging.basicConfig(
28-
stream=sys.stdout,
29-
format="%(asctime)s - %(levelname)s - %(message)s",
70+
format="%(message)s",
3071
datefmt="%Y-%m-%d %H:%M:%S",
31-
level=log_mapping[passed_level.lower()],
72+
level=LOGGING_LEVELS[logging_level],
73+
handlers=[RichHandler(rich_tracebacks=True, show_level=True, show_path=False)],
3274
)
75+
76+
77+
@cli.command()
78+
@click.argument("psms", required=True, type=click.Path(exists=True, dir_okay=False))
79+
@click.option("--psm-filetype", "-t", type=click.Choice(PSM_FILETYPES), default=None)
80+
@click.option(
81+
"--reference", type=click.Path(exists=True, dir_okay=False), default=None,
82+
help="Reference PSM file for calibration or fine-tuning.",
83+
)
84+
@click.option(
85+
"--reference-filetype", "-r", type=click.Choice(PSM_FILETYPES), default=None,
86+
help="File type for the reference file. Inferred if not provided.",
87+
)
88+
@click.option(
89+
"--finetune", is_flag=True, default=False,
90+
help="Fine-tune the model to the reference before predicting. Requires --reference.",
91+
)
92+
@click.option("--output", "-o", type=str, default=None, help="Output file path.")
93+
@click.option("--model", "-m", type=click.Path(exists=True, dir_okay=False), default=None)
94+
def predict(psms, psm_filetype, reference, reference_filetype, finetune, output, model):
95+
"""Predict retention times, optionally calibrating or fine-tuning to a reference."""
96+
if finetune and not reference:
97+
raise click.UsageError("--finetune requires --reference.")
98+
99+
psm_list = _read_psm_file(psms, psm_filetype)
100+
output_path = _infer_output_name(psms, output)
101+
102+
if reference:
103+
psm_list_reference = _read_psm_file(reference, reference_filetype)
104+
if finetune:
105+
predictions = deeplc.core.finetune_and_predict(
106+
psm_list=psm_list,
107+
psm_list_reference=psm_list_reference,
108+
model=model,
109+
)
110+
else:
111+
predictions = deeplc.core.predict_and_calibrate(
112+
psm_list=psm_list,
113+
psm_list_reference=psm_list_reference,
114+
model=model,
115+
)
116+
else:
117+
predictions = deeplc.core.predict(psm_list=psm_list, model=model)
118+
119+
_write_predictions(psm_list, predictions, output_path)
120+
121+
122+
def main():
123+
cli()
124+
125+
126+
if __name__ == "__main__":
127+
main()

pyproject.toml

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,11 +40,12 @@ dependencies = [
4040
"scikit-learn>=1.2.0,<2",
4141
"psm_utils>=1.5,<2",
4242
"onnx2torch>=1.5,<2",
43+
"click>=8,<9",
44+
"rich>=13,<15",
4345
]
4446

45-
# [project.scripts]
46-
# deeplc = "deeplc.__main__:main"
47-
# deeplc-gui = "deeplc_gui.__main__:main"
47+
[project.scripts]
48+
deeplc = "deeplc.__main__:main"
4849

4950
[project.urls]
5051
GitHub = "https://github.com/compomics/deeplc"

0 commit comments

Comments
 (0)