diff --git a/modelaudit/cli.py b/modelaudit/cli.py index 499c4de80..f76643bae 100644 --- a/modelaudit/cli.py +++ b/modelaudit/cli.py @@ -952,6 +952,18 @@ def scan_command( # This prevents FileNotFoundError when URLs are downloaded to local paths scanned_paths: list[str] = [] + def _track_streaming_paths_for_sbom(streaming_result: ModelAuditResultModel, fallback_path: str) -> None: + """Track concrete streamed artifact paths so SBOM includes all scanned components.""" + added_path = False + for asset in streaming_result.assets: + if asset.path: + scanned_paths.append(asset.path) + added_path = True + + # Fallback keeps previous behavior if no assets were recorded. + if not added_path: + scanned_paths.append(fallback_path) + # Track temporary directories to clean up after SBOM generation temp_dirs_to_cleanup: list[str] = [] @@ -1102,6 +1114,9 @@ def scan_command( # Merge streaming results into audit_result audit_result.aggregate_scan_result(streaming_result.model_dump()) + # Track streamed artifact paths so SBOM includes all components. + _track_streaming_paths_for_sbom(streaming_result, path) + # Record download/scan completion for streaming mode download_duration = time.time() - download_start record_download_completed("huggingface", download_duration, 0, path) @@ -1203,6 +1218,9 @@ def scan_command( cache_dir=final_cache_dir, ) + # Track streamed artifact paths so SBOM includes all components. + _track_streaming_paths_for_sbom(streaming_result, path) + # Merge streaming results audit_result.aggregate_scan_result(streaming_result.model_dump()) @@ -1356,6 +1374,9 @@ def scan_command( cache_dir=final_cache_dir, ) + # Track streamed artifact paths so SBOM includes all components. + _track_streaming_paths_for_sbom(streaming_result, path) + # Merge streaming results audit_result.aggregate_scan_result(streaming_result.model_dump()) @@ -1685,8 +1706,8 @@ def enhanced_progress_callback(message, percentage): elif show_styled_output: click.echo(style_text("✅ Streaming scan complete", fg="green", bold=True)) - # Track the scanned path for SBOM - scanned_paths.append(actual_path) + # Track streamed artifact paths so SBOM includes all components. + _track_streaming_paths_for_sbom(streaming_result, actual_path) # Skip normal scanning flow - continue to next path continue diff --git a/tests/test_cli.py b/tests/test_cli.py index e4277821d..2f227077b 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -1,7 +1,8 @@ import json import os import re -from unittest.mock import patch +from pathlib import Path +from unittest.mock import Mock, patch import pytest from click.testing import CliRunner @@ -775,6 +776,63 @@ def file_generator(): pytest.fail("Output is not valid JSON") +@patch("modelaudit.cli.is_huggingface_url") +@patch("modelaudit.utils.sources.huggingface.download_model_streaming") +@patch("modelaudit.core.scan_model_streaming") +def test_scan_huggingface_streaming_sbom_contains_all_components( + mock_scan_streaming: Mock, mock_download_streaming: Mock, mock_is_hf_url: Mock, tmp_path: Path +) -> None: + """Regression test for issue #671: --stream should still produce full SBOM components.""" + mock_is_hf_url.return_value = True + + # The generator itself is not consumed in this test because scan_model_streaming is mocked. + def file_generator(): + yield (tmp_path / "model-00001-of-00002.safetensors", False) + yield (tmp_path / "model-00002-of-00002.safetensors", True) + + mock_download_streaming.return_value = file_generator() + + streamed_assets = [ + { + "path": str(tmp_path / "model-00001-of-00002.safetensors"), + "type": "safetensors", + "size": 123, + }, + { + "path": str(tmp_path / "model-00002-of-00002.safetensors"), + "type": "safetensors", + "size": 456, + }, + ] + mock_scan_streaming.return_value = create_mock_scan_result( + bytes_scanned=579, + files_scanned=2, + assets=streamed_assets, + ) + + sbom_file = tmp_path / "streaming_sbom.json" + runner = CliRunner() + result = runner.invoke( + cli, + [ + "scan", + "--stream", + "--sbom", + str(sbom_file), + "https://huggingface.co/test/model", + ], + ) + + assert result.exit_code == 0 + assert sbom_file.exists() + + sbom_json = json.loads(sbom_file.read_text(encoding="utf-8")) + component_names = {component["name"] for component in sbom_json.get("components", [])} + + assert "model-00001-of-00002.safetensors" in component_names + assert "model-00002-of-00002.safetensors" in component_names + + @patch("modelaudit.cli.is_huggingface_url") @patch("modelaudit.utils.sources.huggingface.download_model_streaming") @patch("modelaudit.core.scan_model_streaming")