Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 17 additions & 2 deletions modelaudit/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -952,6 +952,18 @@ def scan_command(
# This prevents FileNotFoundError when URLs are downloaded to local paths
scanned_paths: list[str] = []

def _track_streaming_paths_for_sbom(streaming_result: ModelAuditResultModel, fallback_path: str) -> None:
    """Track concrete streamed artifact paths so SBOM includes all scanned components."""
    # Collect every asset that carries a concrete path; skips assets with
    # empty/None paths, matching the original truthiness check.
    concrete_paths = [asset.path for asset in streaming_result.assets if asset.path]

    if concrete_paths:
        scanned_paths.extend(concrete_paths)
    else:
        # Fallback keeps previous behavior if no assets were recorded.
        scanned_paths.append(fallback_path)

# Track temporary directories to clean up after SBOM generation
temp_dirs_to_cleanup: list[str] = []

Expand Down Expand Up @@ -1102,6 +1114,9 @@ def scan_command(
# Merge streaming results into audit_result
audit_result.aggregate_scan_result(streaming_result.model_dump())

# Track streamed artifact paths so SBOM includes all components.
_track_streaming_paths_for_sbom(streaming_result, path)

# Record download/scan completion for streaming mode
download_duration = time.time() - download_start
record_download_completed("huggingface", download_duration, 0, path)
Expand Down Expand Up @@ -1685,8 +1700,8 @@ def enhanced_progress_callback(message, percentage):
elif show_styled_output:
click.echo(style_text("✅ Streaming scan complete", fg="green", bold=True))

# Track the scanned path for SBOM
scanned_paths.append(actual_path)
# Track streamed artifact paths so SBOM includes all components.
_track_streaming_paths_for_sbom(streaming_result, actual_path)

# Skip normal scanning flow - continue to next path
continue
Expand Down
57 changes: 57 additions & 0 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -775,6 +775,63 @@ def file_generator():
pytest.fail("Output is not valid JSON")


@patch("modelaudit.cli.is_huggingface_url")
@patch("modelaudit.utils.sources.huggingface.download_model_streaming")
@patch("modelaudit.core.scan_model_streaming")
def test_scan_huggingface_streaming_sbom_contains_all_components(
    mock_scan_streaming, mock_download_streaming, mock_is_hf_url, tmp_path
):
    """Regression test for issue #671: --stream should still produce full SBOM components."""
    mock_is_hf_url.return_value = True

    # Two sharded safetensors files; both must surface as SBOM components.
    shard_names = (
        "model-00001-of-00002.safetensors",
        "model-00002-of-00002.safetensors",
    )
    shard_sizes = (123, 456)

    # The generator itself is not consumed in this test because scan_model_streaming is mocked.
    def _yield_shards():
        yield (tmp_path / shard_names[0], False)
        yield (tmp_path / shard_names[1], True)

    mock_download_streaming.return_value = _yield_shards()

    mock_scan_streaming.return_value = create_mock_scan_result(
        bytes_scanned=579,
        files_scanned=2,
        assets=[
            {
                "path": str(tmp_path / name),
                "type": "safetensors",
                "size": size,
            }
            for name, size in zip(shard_names, shard_sizes)
        ],
    )

    sbom_path = tmp_path / "streaming_sbom.json"
    outcome = CliRunner().invoke(
        cli,
        [
            "scan",
            "--stream",
            "--sbom",
            str(sbom_path),
            "hf://openai-community/gpt2",
        ],
    )

    assert outcome.exit_code == 0
    assert sbom_path.exists()

    sbom_payload = json.loads(sbom_path.read_text(encoding="utf-8"))
    names_in_sbom = {entry["name"] for entry in sbom_payload.get("components", [])}

    # Every streamed shard must appear as a component in the generated SBOM.
    for shard in shard_names:
        assert shard in names_in_sbom


@patch("modelaudit.cli.is_huggingface_url")
@patch("modelaudit.utils.sources.huggingface.download_model_streaming")
@patch("modelaudit.core.scan_model_streaming")
Expand Down