Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions src/debsbom/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from uuid import UUID
from urllib.parse import urlparse
from pathlib import Path
import requests

from .generate import Debsbom, SBOMType
from .download import PackageDownloader, PackageResolver, PersistentResolverCache
Expand Down Expand Up @@ -145,8 +146,9 @@ def run(args):
outdir.mkdir(exist_ok=True)
cache = PersistentResolverCache(outdir / ".cache")
resolver = PackageResolver.create(Path(args.bomfile))
sdl = sdlclient.SnapshotDataLake()
downloader = PackageDownloader(args.outdir)
rs = requests.Session()
sdl = sdlclient.SnapshotDataLake(session=rs)
downloader = PackageDownloader(args.outdir, session=rs)

pkgs = []
local_pkgs = []
Expand Down
20 changes: 17 additions & 3 deletions src/debsbom/download/download.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,13 @@
from functools import reduce
import hashlib
import json
import shutil
import sys
from typing import Generator, Tuple, Type
from pathlib import Path
from urllib.request import urlretrieve
from packageurl import PackageURL
import requests

from ..dpkg import package
from ..snapshot import client as sdlclient
Expand Down Expand Up @@ -61,7 +63,13 @@ def lookup(
if not entry.is_file():
return None
with open(entry, "r") as f:
data = json.load(f)
try:
data = json.load(f)
except json.decoder.JSONDecodeError:
print(
f"cache file {entry.name} ({p.name}@{p.version}) is corrupted", file=sys.stderr
)
return None
return [sdlclient.RemoteFile(**d) for d in data]

def insert(
Expand Down Expand Up @@ -146,10 +154,13 @@ def create(filename: Path) -> Type["PackageResolver"]:


class PackageDownloader:
def __init__(self, outdir: Path | str = "downloads"):
def __init__(
self, outdir: Path | str = "downloads", session: requests.Session = requests.Session()
):
self.dldir = Path(outdir)
self.dldir.mkdir(exist_ok=True)
self.to_download: list["sdlclient.RemoteFile"] = []
self.rs = session

def register(self, files: list["sdlclient.RemoteFile"]):
self.to_download.extend(list(files))
Expand All @@ -174,5 +185,8 @@ def download(self, progress_cb):
else:
print(f"Checksum mismatch on {f.filename}. Download again.", file=sys.stderr)
fdst = target.with_suffix(target.suffix + ".tmp")
urlretrieve(f.downloadurl, fdst)
with self.rs.get(f.downloadurl, stream=True) as r:
r.raise_for_status()
with open(fdst, "wb") as f:
shutil.copyfileobj(r.raw, f)
fdst.rename(target)
18 changes: 11 additions & 7 deletions src/debsbom/snapshot/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def __init__(self, sdl, name: str):

def versions(self):
try:
r = requests.get(self.sdl.url + f"/mr/package/{self.name}/")
r = self.sdl.rs.get(self.sdl.url + f"/mr/package/{self.name}/")
except RequestException as e:
raise SnapshotDataLakeError(e)
for v in r.json().get("result", []):
Expand All @@ -55,7 +55,7 @@ def srcfiles(self) -> Generator["RemoteFile", None, None]:
All files associated with the source package
"""
try:
r = requests.get(
r = self.sdl.rs.get(
self.sdl.url + f"/mr/package/{self.name}/{self.version}" "/srcfiles?fileinfo=1"
)
if r.status_code == 404:
Expand All @@ -78,7 +78,7 @@ def binpackages(self) -> Generator["BinaryPackage", None, None]:
All binary packages created from this source package
"""
try:
r = requests.get(
r = self.sdl.rs.get(
self.sdl.url + f"/mr/package/{self.name}/{self.version}" "/binpackages"
)
data = r.json()
Expand Down Expand Up @@ -121,7 +121,7 @@ def files(self, arch: str = None) -> Generator["RemoteFile", None, None]:
# resolve via binary only
api = self.sdl.url + f"/mr/binary/{self.binname}/{self.binversion}/binfiles?fileinfo=1"
try:
r = requests.get(api)
r = self.sdl.rs.get(api)
if r.status_code == 404:
raise NotFoundOnSnapshotError()
data = r.json()
Expand Down Expand Up @@ -170,12 +170,16 @@ class SnapshotDataLake:
Snapshot instance to query against
"""

def __init__(self, url="https://snapshot.debian.org"):
def __init__(
self, url="https://snapshot.debian.org", session: requests.Session = requests.Session()
):
self.url = url
# reuse the same connection for all requests
self.rs = session

def packages(self) -> Generator[Package, None, None]:
try:
r = requests.get(self.url + "/mr/package/")
r = self.rs.get(self.url + "/mr/package/")
data = r.json()
except RequestException as e:
raise SnapshotDataLakeError(e)
Expand All @@ -184,7 +188,7 @@ def packages(self) -> Generator[Package, None, None]:

def fileinfo(self, hash):
try:
r = requests.get(self.url + f"/mr/file/{hash}/info")
r = self.rs.get(self.url + f"/mr/file/{hash}/info")
data = r.json()
except RequestException as e:
raise SnapshotDataLakeError(e)
Expand Down