Skip to content

Commit 4ea29bd

Browse files
committed
Replace 'aws s3 ls' shell-out in dataset.py with boto3
1 parent d594521 commit 4ea29bd

5 files changed

Lines changed: 593 additions & 60 deletions

File tree

.gitignore

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -163,4 +163,6 @@ cython_debug/
163163
*.wav
164164
wandb/*
165165
*.out
166-
test_*
166+
test_*
167+
# macOS
168+
.DS_Store

README.md

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,20 @@ The following properties are defined in the top level of the model configuration
171171
## Dataset config
172172
`stable-audio-tools` currently supports two kinds of data sources: local directories of audio files, and WebDataset datasets stored in Amazon S3. More information can be found in [the dataset config documentation](docs/datasets.md)
173173

174+
## S3-compatible storage (Backblaze B2)
175+
The S3 dataset loader uses `boto3`, which ships in the `train` extra. If you installed without that extra, add it with `pip install boto3` (or `pip install "stable-audio-tools[train]"`).
176+
177+
The loader honors the `AWS_ENDPOINT_URL` environment variable, so you can point it at any S3-compatible host without changing the dataset config.
178+
179+
Example for [Backblaze B2](https://github.com/backblaze-labs/):
180+
```bash
181+
export AWS_ENDPOINT_URL=https://s3.us-west-004.backblazeb2.com
182+
export AWS_ACCESS_KEY_ID=<B2 application key ID>
183+
export AWS_SECRET_ACCESS_KEY=<B2 application key>
184+
```
185+
186+
When `AWS_ENDPOINT_URL` is unset, the loader uses default AWS S3, so existing setups are unaffected.
187+
174188
# Todo
175189
- [ ] Add troubleshooting section
176190
- [ ] Add contribution guidelines

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ dependencies = [
3737
[project.optional-dependencies]
3838
train = [
3939
"auraloss==0.4.0",
40+
"boto3>=1.26",
4041
"descript-audio-codec==1.0.0",
4142
"encodec==0.1.1",
4243
"inf-cl",

stable_audio_tools/data/dataset.py

Lines changed: 181 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,25 @@
11
import importlib
2-
import numpy as np
32
import io
43
import json
54
import os
6-
import dill
75
import posixpath
86
import random
9-
import re
10-
import subprocess
7+
import shlex
118
import time
9+
from functools import lru_cache
10+
from importlib.metadata import PackageNotFoundError, version
11+
from os import path
12+
from typing import Callable, List, Optional
13+
14+
import dill
15+
import numpy as np
1216
import torch
1317
import torchaudio
1418
import webdataset as wds
15-
16-
from os import path
1719
from torch import nn
1820
from torchaudio import transforms as T
19-
from typing import Optional, Callable, List
2021

21-
from .utils import Stereo, Mono, PhaseFlipper, PadCrop_Normalized_T, VolumeNorm, strip_trailing_silence
22+
from .utils import Mono, PadCrop_Normalized_T, PhaseFlipper, Stereo, VolumeNorm, strip_trailing_silence
2223

2324
AUDIO_KEYS = ("flac", "wav", "mp3", "m4a", "ogg", "opus")
2425

@@ -483,65 +484,176 @@ def __getitem__(self, idx):
483484

484485
# S3 code and WDS preprocessing code based on implementation by Scott Hawley originally in https://github.com/zqevans/audio-diffusion/blob/main/dataset/dataset.py
485486

486-
def get_s3_contents(dataset_path, s3_url_prefix=None, filter='', recursive=True, debug=False, profile=None):
487+
@lru_cache(maxsize=1)
488+
def _user_agent():
489+
"Base ``stable-audio-tools/<version>`` token, looked up once on first use."
490+
try:
491+
ver = version("stable-audio-tools")
492+
except PackageNotFoundError: # source/editable checkout without dist metadata
493+
ver = "dev"
494+
return f"stable-audio-tools/{ver}"
495+
496+
497+
def _build_user_agent_extra(user_agent_extra=None):
498+
"""``stable-audio-tools/<version>`` with any caller- or env-provided
499+
(``STABLE_AUDIO_TOOLS_USER_AGENT_EXTRA``) value appended, not replacing it.
500+
Pass ``user_agent_extra=""`` to suppress the env value and use the base only."""
501+
base = _user_agent()
502+
if user_agent_extra is None:
503+
user_agent_extra = os.environ.get("STABLE_AUDIO_TOOLS_USER_AGENT_EXTRA")
504+
return f"{base} {user_agent_extra}" if user_agent_extra else base
505+
506+
507+
@lru_cache(maxsize=32)
508+
def _build_s3_client(profile, endpoint_url, user_agent_extra):
509+
try:
510+
import boto3 # local import so boto3 is only required when S3 is used
511+
from botocore.config import Config
512+
except ModuleNotFoundError as e:
513+
raise ImportError(
514+
"S3 dataset access requires boto3. Install it with "
515+
"'pip install boto3' or 'pip install stable-audio-tools[train]'."
516+
) from e
517+
518+
session = boto3.Session(profile_name=profile) if profile else boto3.Session()
519+
return session.client(
520+
"s3",
521+
endpoint_url=endpoint_url,
522+
config=Config(user_agent_extra=user_agent_extra),
523+
)
524+
525+
526+
def _get_s3_client(profile=None, user_agent_extra=None):
527+
"""
528+
Build (and reuse) a boto3 S3 client. Honors AWS_ENDPOINT_URL when set so the
529+
same code path works against any S3-compatible endpoint (AWS S3 by default;
530+
set AWS_ENDPOINT_URL to a Backblaze B2 endpoint to point it at B2). When the
531+
env var is unset, behavior matches the default AWS client.
532+
533+
Clients are cached per (profile, endpoint, user-agent) so listing and
534+
presigning share one client instead of building a new one on each call.
487535
"""
488-
Returns a list of full S3 paths to files in a given S3 bucket and directory path.
536+
endpoint_url = os.environ.get("AWS_ENDPOINT_URL") or None
537+
return _build_s3_client(
538+
profile, endpoint_url, _build_user_agent_extra(user_agent_extra)
539+
)
540+
541+
542+
def _parse_s3_url(url):
543+
"Split an ``s3://bucket/key`` URL into (bucket, key). Raises ValueError otherwise."
544+
if not url.startswith("s3://"):
545+
raise ValueError(f"expected an s3:// URL, got: {url!r}")
546+
bucket, _, key = url[len("s3://"):].partition("/")
547+
if not bucket.strip():
548+
raise ValueError(f"s3:// URL is missing a bucket name: {url!r}")
549+
return bucket, key
550+
551+
552+
def get_s3_contents(
553+
dataset_path,
554+
s3_url_prefix=None,
555+
filter='', # deprecated alias for filter_str; kept for backwards compat
556+
recursive=True,
557+
debug=False,
558+
profile=None,
559+
relative=False,
560+
filter_str=None, # preferred substring filter
561+
):
489562
"""
563+
Returns a list of objects in a given S3 bucket and directory path.
564+
565+
By default returns full ``s3://bucket/key`` paths (backwards compatible with
566+
the previous implementation). Pass ``relative=True`` to get keys relative to
567+
``dataset_path`` instead. Uses boto3 directly so it works against any
568+
S3-compatible endpoint when ``AWS_ENDPOINT_URL`` is set.
569+
570+
``filter_str`` is the preferred substring-filter argument; ``filter`` is kept
571+
as a backwards-compatible alias.
572+
"""
573+
if filter_str is None:
574+
filter_str = filter
490575
# Ensure dataset_path ends with a trailing slash
491576
if dataset_path != '' and not dataset_path.endswith('/'):
492577
dataset_path += '/'
493-
# Use posixpath to construct the S3 URL path
578+
# Use posixpath to construct the S3 URL path (e.g. "s3://bucket/prefix/")
494579
bucket_path = posixpath.join(s3_url_prefix or '', dataset_path)
495-
# Construct the `aws s3 ls` command
496-
cmd = ['aws', 's3', 'ls', bucket_path]
497580

498-
if profile is not None:
499-
cmd.extend(['--profile', profile])
581+
bucket, prefix = _parse_s3_url(bucket_path)
582+
583+
s3 = _get_s3_client(profile=profile)
584+
paginator = s3.get_paginator("list_objects_v2")
585+
list_kwargs = {"Bucket": bucket, "Prefix": prefix}
586+
if not recursive:
587+
list_kwargs["Delimiter"] = "/"
588+
589+
keys = []
590+
for page in paginator.paginate(**list_kwargs):
591+
for obj in page.get("Contents", []) or []:
592+
key = obj.get("Key", "")
593+
if not key or key.endswith("/"):
594+
continue
595+
keys.append(key)
500596

501-
if recursive:
502-
# Add the --recursive flag if requested
503-
cmd.append('--recursive')
504-
505-
# Run the `aws s3 ls` command and capture the output
506-
run_ls = subprocess.run(cmd, capture_output=True, check=True)
507-
# Split the output into lines and strip whitespace from each line
508-
contents = run_ls.stdout.decode('utf-8').split('\n')
509-
contents = [x.strip() for x in contents if x]
510-
# Remove the timestamp from lines that begin with a timestamp
511-
contents = [re.sub(r'^\S+\s+\S+\s+\d+\s+', '', x)
512-
if re.match(r'^\S+\s+\S+\s+\d+\s+', x) else x for x in contents]
513-
# Construct a full S3 path for each file in the contents list
514-
contents = [posixpath.join(s3_url_prefix or '', x)
515-
for x in contents if not x.endswith('/')]
516597
# Apply the filter, if specified
517-
if filter:
518-
contents = [x for x in contents if filter in x]
519-
# Remove redundant directory names in the S3 URL
520-
if recursive:
521-
# Get the main directory name from the S3 URL
522-
main_dir = "/".join(bucket_path.split('/')[3:])
523-
# Remove the redundant directory names from each file path
524-
contents = [x.replace(f'{main_dir}', '').replace(
525-
'//', '/') for x in contents]
526-
# Print debugging information, if requested
598+
if filter_str:
599+
keys = [k for k in keys if filter_str in k]
600+
601+
if relative:
602+
# Strip the bucket-level prefix so keys are relative to dataset_path.
603+
if prefix:
604+
keys = [k[len(prefix):] if k.startswith(prefix) else k for k in keys]
605+
keys = [k.lstrip('/') for k in keys]
606+
contents = keys
607+
else:
608+
# Backwards-compatible default: full s3://bucket/key paths.
609+
contents = [f"s3://{bucket}/{k}" for k in keys]
610+
527611
if debug:
528612
print("contents = \n", contents)
529-
# Return the list of S3 paths to files
613+
530614
return contents
531615

532616

617+
# 7 days (SigV4 max) so shard URLs outlast long training runs. Override per
618+
# call or via STABLE_AUDIO_TOOLS_S3_PRESIGN_EXPIRY.
619+
_DEFAULT_PRESIGN_EXPIRY_SECONDS = 7 * 24 * 3600
620+
621+
533622
def get_all_s3_urls(
534-
names=[], # list of all valid [LAION AudioDataset] dataset names
535-
# list of subsets you want from those datasets, e.g. ['train','valid']
536-
subsets=[''],
623+
names=None, # list of [LAION AudioDataset] dataset names; None -> []
624+
subsets=None, # list of subsets, e.g. ['train','valid']; None -> ['']
537625
s3_url_prefix=None, # prefix for those dataset names
538626
recursive=True, # recursively list all tar files in all subdirs
539627
filter_str='tar', # only grab files with this substring
540628
# print debugging info -- note: info displayed likely to change at dev's whims
541629
debug=False,
542-
profiles={}, # dictionary of profiles for each item in names, e.g. {'dataset1': 'profile1', 'dataset2': 'profile2'}
630+
profiles=None, # dict of profiles per name, e.g. {'dataset1': 'profile1'}; None -> {}
631+
presign_expiry_seconds=None, # presigned-URL lifetime; None -> env var or default
543632
):
544-
"get urls of shards (tar files) for multiple datasets in one s3 bucket"
633+
"""Get urls of shards (tar files) for multiple datasets in one s3 bucket.
634+
635+
Shards are fetched via presigned URLs handed to ``curl``. The URL carries
636+
short-lived credentials on the command line, so on shared/multi-tenant hosts
637+
it can be visible to other local users (e.g. via ``ps``); keep
638+
``presign_expiry_seconds`` short in those environments.
639+
"""
640+
names = [] if names is None else names
641+
subsets = [''] if subsets is None else subsets
642+
profiles = profiles or {}
643+
if presign_expiry_seconds is None:
644+
presign_expiry_seconds = os.environ.get(
645+
"STABLE_AUDIO_TOOLS_S3_PRESIGN_EXPIRY", _DEFAULT_PRESIGN_EXPIRY_SECONDS)
646+
try:
647+
presign_expiry_seconds = int(presign_expiry_seconds)
648+
except (TypeError, ValueError):
649+
raise ValueError(
650+
"presign_expiry_seconds (or STABLE_AUDIO_TOOLS_S3_PRESIGN_EXPIRY) must be "
651+
f"an integer number of seconds, got: {presign_expiry_seconds!r}"
652+
)
653+
if presign_expiry_seconds <= 0:
654+
raise ValueError(
655+
f"presign_expiry_seconds must be positive, got: {presign_expiry_seconds}"
656+
)
545657
urls = []
546658
for name in names:
547659
# If s3_url_prefix is not specified, assume the full S3 path is included in each element of the names list
@@ -559,23 +671,33 @@ def get_all_s3_urls(
559671
# Get the list of tar files in the current subset directory
560672
profile = profiles.get(name, None)
561673
tar_list = get_s3_contents(
562-
subset_str, s3_url_prefix=None, recursive=recursive, filter=filter_str, debug=debug, profile=profile)
674+
subset_str, s3_url_prefix=None, recursive=recursive, filter_str=filter_str, debug=debug, profile=profile, relative=True)
675+
# Reuse the cached S3 client (shared with get_s3_contents) for presigning.
676+
s3_client = _get_s3_client(profile=profile)
563677
for tar in tar_list:
564-
# Escape spaces and parentheses in the tar filename for use in the shell command
565-
tar = tar.replace(" ", "\ ").replace(
566-
"(", "\(").replace(")", "\)")
567-
# Construct the S3 path to the current tar file
568-
s3_path = posixpath.join(name, subset, tar) + " -"
569-
# Construct the AWS CLI command to download the current tar file
678+
# Construct the full s3:// URL for the current tar file.
570679
if s3_url_prefix is None:
571-
request_str = f"pipe:aws s3 --cli-connect-timeout 0 cp {s3_path}"
680+
full_s3_url = posixpath.join(name, subset, tar)
572681
else:
573-
request_str = f"pipe:aws s3 --cli-connect-timeout 0 cp {posixpath.join(s3_url_prefix, s3_path)}"
574-
if profiles.get(name):
575-
request_str += f" --profile {profiles.get(name)}"
682+
full_s3_url = posixpath.join(s3_url_prefix, name, subset, tar)
683+
684+
bucket, key = _parse_s3_url(full_s3_url)
685+
686+
# Presigned GET URL works against AWS and any S3-compatible
687+
# endpoint when AWS_ENDPOINT_URL is set. Expiry is configurable
688+
# so long training runs do not outlive their shard URLs.
689+
presigned = s3_client.generate_presigned_url(
690+
"get_object",
691+
Params={"Bucket": bucket, "Key": key},
692+
ExpiresIn=presign_expiry_seconds,
693+
)
694+
# --retry restores the transient-failure resilience the AWS CLI
695+
# had; shlex.quote keeps URL contents safe in the pipe: shell command.
696+
request_str = f"pipe:curl -fsSL --retry 5 {shlex.quote(presigned)}"
576697
if debug:
577-
print("request_str = ", request_str)
578-
# Add the constructed URL to the list of URLs
698+
# Strip the signed query string so signatures are not logged.
699+
redacted = presigned.split("?", 1)[0] + "?<redacted>"
700+
print("request_str = pipe:curl -fsSL --retry 5", shlex.quote(redacted))
579701
urls.append(request_str)
580702
return urls
581703

0 commit comments

Comments
 (0)