diff --git a/pyproject.toml b/pyproject.toml
index e56dc9db..35821271 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -37,6 +37,9 @@ dev = [
     "snakefmt==0.11.0",
     "pytest==8.4.1",
 ]
+ai = [
+    "openai==1.88.0",
+]
 
 [project.urls]
 "Homepage" = "https://github.com/sunbeam-labs/sunbeam"
diff --git a/sunbeam/bfx/decontam.py b/sunbeam/bfx/decontam.py
index 8498ac28..0a9a938c 100755
--- a/sunbeam/bfx/decontam.py
+++ b/sunbeam/bfx/decontam.py
@@ -17,7 +17,7 @@ def get_mapped_reads(fp: str, min_pct_id: float, min_len_frac: float) -> Iterato
 
 
 def _get_pct_identity(
-    read: Dict[str, Union[int, float, str, Tuple[int, str]]]
+    read: Dict[str, Union[int, float, str, Tuple[int, str]]],
 ) -> float:
     edit_dist = read.get("NM", 0)
     pct_mm = float(edit_dist) / len(read["SEQ"])
diff --git a/sunbeam/logging.py b/sunbeam/logging.py
index 57d72623..0f891efc 100644
--- a/sunbeam/logging.py
+++ b/sunbeam/logging.py
@@ -2,6 +2,26 @@
 from pathlib import Path
 
 
+class StreamToLogger:
+    def __init__(self, logger, level=logging.INFO):
+        self.logger = logger
+        self.level = level
+        self.buffer = ""
+
+    def write(self, message):
+        if message != "\n":
+            self.buffer += message
+        if "\n" in self.buffer:
+            for line in self.buffer.splitlines():
+                self.logger.log(self.level, line.strip())
+            self.buffer = ""
+
+    def flush(self):
+        if self.buffer:
+            self.logger.log(self.level, self.buffer.strip())
+            self.buffer = ""
+
+
 class ConditionalLevelFormatter(logging.Formatter):
     def format(self, record):
         # For WARNING and above, include "LEVELNAME: message"
diff --git a/sunbeam/scripts/run.py b/sunbeam/scripts/run.py
index edc39275..f8b9a98f 100644
--- a/sunbeam/scripts/run.py
+++ b/sunbeam/scripts/run.py
@@ -1,11 +1,57 @@
 import argparse
+import contextlib
 import datetime
+import logging
 import os
 import sys
 from pathlib import Path
 from snakemake.cli import main as snakemake_main
 from sunbeam import __version__
-from sunbeam.logging import get_pipeline_logger
+from sunbeam.logging import get_pipeline_logger, StreamToLogger
+
+
+def analyze_run(log: str, logger: logging.Logger, ai: bool) -> None:
+    """Analyze the run log and provide insights or suggestions."""
+    # We could do some rule-based analysis here but I'd rather lean into the AI features and see how far they can take us
+    if ai:
+        try:
+            import openai
+        except ImportError:  # pragma: no cover - this is a soft dependency
+            logger.error(
+                "AI analysis requested, but the 'openai' package is not installed. Try `pip install -e sunbeamlib[ai]`.\n"
+            )
+            return
+
+        api_key = os.getenv("OPENAI_API_KEY")
+        if not api_key:
+            logger.error("OPENAI_API_KEY not set; skipping AI analysis.\n")
+            return
+
+        try:
+            client = openai.OpenAI(api_key=api_key)
+            resp = client.chat.completions.create(
+                model="gpt-4.1-nano",
+                messages=[
+                    {
+                        "role": "system",
+                        "content": "You diagnose errors from Sunbeam pipeline runs. If there are problems, suggest possible causes and solutions. Keep the answer short and sweet. If there are relevant file paths for debugging (like log files), mention them.",
+                    },
+                    {
+                        "role": "user",
+                        "content": f"Sunbeam ran with the following output:\n{log}\n",
+                    },
+                ],
+                max_tokens=1500,
+            )
+            logger.info(
+                "\n\nAI diagnosis:\n"
+                + resp.choices[0].message.content
+                + "\nCheck out the Sunbeam documentation (https://sunbeam.readthedocs.io/en/stable/) and the GitHub issues page (https://github.com/sunbeam-labs/sunbeam/issues) for more information or to open a new issue.\n"
+            )
+        except (
+            Exception
+        ) as exc:  # pragma: no cover - network errors are non-deterministic
+            logger.error(f"AI analysis failed: {exc}\n")
 
 
 def main(argv: list[str] = sys.argv):
@@ -27,6 +73,7 @@
     # You could argue it would make more sense to start this at the actual snakemake call
     # but this way we can log some relevant setup information that might be useful on post-mortem analysis
     logger = get_pipeline_logger(log_file)
+    logger.debug("Sunbeam pipeline logger initialized.")
 
     snakefile = Path(__file__).parent.parent / "workflow" / "Snakefile"
     if not snakefile.exists():
@@ -73,10 +120,13 @@
     logger.info("Running: " + " ".join(snakemake_args))
 
     try:
-        snakemake_main(snakemake_args)
-    except Exception as e:
-        logger.exception("An error occurred while running Sunbeam")
-        sys.exit(1)
+        stream_logger = StreamToLogger(logger, level=logging.INFO)
+
+        with contextlib.redirect_stderr(stream_logger):
+            snakemake_main(snakemake_args)
+    finally:
+        with open(log_file, "r") as f:
+            analyze_run(f.read(), logger, args.ai)
 
 
 def main_parser():
@@ -128,6 +178,11 @@
         default=__version__,
         help="The tag to use when pulling docker images for the core pipeline environments, defaults to sunbeam's current version, a good alternative is 'latest' for the latest stable release",
     )
+    parser.add_argument(
+        "--ai",
+        action="store_true",
+        help="Use OpenAI to diagnose failures after the run",
+    )
     parser.add_argument(
         "--log_file",
         default=None,
diff --git a/sunbeam/workflow/Snakefile b/sunbeam/workflow/Snakefile
index 58510bc4..651b8f67 100644
--- a/sunbeam/workflow/Snakefile
+++ b/sunbeam/workflow/Snakefile
@@ -17,6 +17,7 @@
 from sunbeam.project.post import compile_benchmarks
 
 logger = get_pipeline_logger()
+logger.debug("Sunbeam pipeline starting...")
 
 MIN_MEM_MB = int(os.getenv("SUNBEAM_MIN_MEM_MB", 8000))
 MIN_RUNTIME = int(os.getenv("SUNBEAM_MIN_RUNTIME", 15))
diff --git a/tests/e2e/test_sunbeam_run.py b/tests/e2e/test_sunbeam_run.py
index 6cb32341..1b28ef48 100755
--- a/tests/e2e/test_sunbeam_run.py
+++ b/tests/e2e/test_sunbeam_run.py
@@ -1,5 +1,7 @@
 import pytest
 import subprocess as sp
+import sys
+import types
 
 from sunbeam.scripts.init import main as Init
 from sunbeam.scripts.run import main as Run
@@ -113,3 +115,48 @@ def test_sunbeam_run_with_target_after_exclude(tmp_path, DATA_DIR, capsys):
     assert ret.returncode == 0
     assert "clean_qc" in ret.stderr.decode("utf-8")
     assert "filter_reads" not in ret.stderr.decode("utf-8")
+
+
+def test_sunbeam_run_ai_option(tmp_path, monkeypatch, DATA_DIR):
+    project_dir = tmp_path / "test"
+
+    called = {"flag": False}
+
+    def dummy_create(**kwargs):
+        called["flag"] = True
+        return types.SimpleNamespace(
+            choices=[
+                types.SimpleNamespace(message=types.SimpleNamespace(content="analysis"))
+            ]
+        )
+
+    fake_openai = types.SimpleNamespace()
+
+    fake_openai.OpenAI = lambda *args, **kwargs: types.SimpleNamespace(
+        chat=types.SimpleNamespace(
+            completions=types.SimpleNamespace(create=dummy_create)
+        )
+    )
+
+    monkeypatch.setitem(sys.modules, "openai", fake_openai)
+    monkeypatch.setenv("OPENAI_API_KEY", "token")
+
+    Init(
+        [
+            str(project_dir),
+            "--data_fp",
+            str(DATA_DIR / "reads"),
+        ]
+    )
+
+    with pytest.raises(SystemExit):
+        Run(
+            [
+                "--profile",
+                str(project_dir),
+                "--ai",
+                "-n",
+            ]
+        )
+
+    assert called["flag"]