Skip to content

Commit e73bdf4

Browse files
committed
Replace 'aws s3 ls' shell-out in dataset.py with boto3
1 parent d594521 commit e73bdf4

7 files changed

Lines changed: 729 additions & 109 deletions

File tree

.gitignore

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -163,4 +163,7 @@ cython_debug/
163163
*.wav
164164
wandb/*
165165
*.out
166-
test_*
166+
test_*
167+
!tests/test_*.py
168+
# macOS
169+
.DS_Store

README.md

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,20 @@ The following properties are defined in the top level of the model configuration
171171
## Dataset config
172172
`stable-audio-tools` currently supports two kinds of data sources: local directories of audio files, and WebDataset datasets stored in Amazon S3. More information can be found in [the dataset config documentation](docs/datasets.md)
173173

174+
## S3-compatible storage (Backblaze B2)
175+
The S3 dataset loader uses `boto3`, which ships in the `train` extra. If you installed without that extra, add it with `pip install boto3` (or `pip install "stable-audio-tools[train]"`).
176+
177+
The loader honors the `AWS_ENDPOINT_URL` environment variable, so you can point it at any S3-compatible host without changing the dataset config.
178+
179+
Example for [Backblaze B2](https://github.com/backblaze-labs/):
180+
```bash
181+
export AWS_ENDPOINT_URL=https://s3.us-west-004.backblazeb2.com
182+
export AWS_ACCESS_KEY_ID=<B2 application key ID>
183+
export AWS_SECRET_ACCESS_KEY=<B2 application key>
184+
```
185+
186+
When `AWS_ENDPOINT_URL` is unset, the loader uses default AWS S3, so existing setups are unaffected.
187+
174188
# Todo
175189
- [ ] Add troubleshooting section
176190
- [ ] Add contribution guidelines

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ dependencies = [
3737
[project.optional-dependencies]
3838
train = [
3939
"auraloss==0.4.0",
40+
"boto3>=1.26,<2",
4041
"descript-audio-codec==1.0.0",
4142
"encodec==0.1.1",
4243
"inf-cl",

stable_audio_tools/data/dataset.py

Lines changed: 23 additions & 108 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,21 @@
11
import importlib
2-
import numpy as np
32
import io
43
import json
54
import os
6-
import dill
7-
import posixpath
85
import random
9-
import re
10-
import subprocess
116
import time
7+
from os import path
8+
from typing import Callable, List, Optional
9+
10+
import dill
11+
import numpy as np
1212
import torch
1313
import torchaudio
1414
import webdataset as wds
15-
16-
from os import path
1715
from torch import nn
1816
from torchaudio import transforms as T
19-
from typing import Optional, Callable, List
2017

21-
from .utils import Stereo, Mono, PhaseFlipper, PadCrop_Normalized_T, VolumeNorm, strip_trailing_silence
18+
from .utils import Mono, PadCrop_Normalized_T, PhaseFlipper, Stereo, VolumeNorm, strip_trailing_silence
2219

2320
AUDIO_KEYS = ("flac", "wav", "mp3", "m4a", "ogg", "opus")
2421

@@ -481,105 +478,23 @@ def __getitem__(self, idx):
481478
print(f'Couldn\'t load file {latent_filename}: {e}')
482479
return self[random.randrange(len(self))]
483480

484-
# S3 code and WDS preprocessing code based on implementation by Scott Hawley originally in https://github.com/zqevans/audio-diffusion/blob/main/dataset/dataset.py
485-
486-
def get_s3_contents(dataset_path, s3_url_prefix=None, filter='', recursive=True, debug=False, profile=None):
487-
"""
488-
Returns a list of full S3 paths to files in a given S3 bucket and directory path.
489-
"""
490-
# Ensure dataset_path ends with a trailing slash
491-
if dataset_path != '' and not dataset_path.endswith('/'):
492-
dataset_path += '/'
493-
# Use posixpath to construct the S3 URL path
494-
bucket_path = posixpath.join(s3_url_prefix or '', dataset_path)
495-
# Construct the `aws s3 ls` command
496-
cmd = ['aws', 's3', 'ls', bucket_path]
497-
498-
if profile is not None:
499-
cmd.extend(['--profile', profile])
500-
501-
if recursive:
502-
# Add the --recursive flag if requested
503-
cmd.append('--recursive')
504-
505-
# Run the `aws s3 ls` command and capture the output
506-
run_ls = subprocess.run(cmd, capture_output=True, check=True)
507-
# Split the output into lines and strip whitespace from each line
508-
contents = run_ls.stdout.decode('utf-8').split('\n')
509-
contents = [x.strip() for x in contents if x]
510-
# Remove the timestamp from lines that begin with a timestamp
511-
contents = [re.sub(r'^\S+\s+\S+\s+\d+\s+', '', x)
512-
if re.match(r'^\S+\s+\S+\s+\d+\s+', x) else x for x in contents]
513-
# Construct a full S3 path for each file in the contents list
514-
contents = [posixpath.join(s3_url_prefix or '', x)
515-
for x in contents if not x.endswith('/')]
516-
# Apply the filter, if specified
517-
if filter:
518-
contents = [x for x in contents if filter in x]
519-
# Remove redundant directory names in the S3 URL
520-
if recursive:
521-
# Get the main directory name from the S3 URL
522-
main_dir = "/".join(bucket_path.split('/')[3:])
523-
# Remove the redundant directory names from each file path
524-
contents = [x.replace(f'{main_dir}', '').replace(
525-
'//', '/') for x in contents]
526-
# Print debugging information, if requested
527-
if debug:
528-
print("contents = \n", contents)
529-
# Return the list of S3 paths to files
530-
return contents
531-
532-
533-
def get_all_s3_urls(
534-
names=[], # list of all valid [LAION AudioDataset] dataset names
535-
# list of subsets you want from those datasets, e.g. ['train','valid']
536-
subsets=[''],
537-
s3_url_prefix=None, # prefix for those dataset names
538-
recursive=True, # recursively list all tar files in all subdirs
539-
filter_str='tar', # only grab files with this substring
540-
# print debugging info -- note: info displayed likely to change at dev's whims
541-
debug=False,
542-
profiles={}, # dictionary of profiles for each item in names, e.g. {'dataset1': 'profile1', 'dataset2': 'profile2'}
543-
):
544-
"get urls of shards (tar files) for multiple datasets in one s3 bucket"
545-
urls = []
546-
for name in names:
547-
# If s3_url_prefix is not specified, assume the full S3 path is included in each element of the names list
548-
if s3_url_prefix is None:
549-
contents_str = name
550-
else:
551-
# Construct the S3 path using the s3_url_prefix and the current name value
552-
contents_str = posixpath.join(s3_url_prefix, name)
553-
if debug:
554-
print(f"get_all_s3_urls: {contents_str}:")
555-
for subset in subsets:
556-
subset_str = posixpath.join(contents_str, subset)
557-
if debug:
558-
print(f"subset_str = {subset_str}")
559-
# Get the list of tar files in the current subset directory
560-
profile = profiles.get(name, None)
561-
tar_list = get_s3_contents(
562-
subset_str, s3_url_prefix=None, recursive=recursive, filter=filter_str, debug=debug, profile=profile)
563-
for tar in tar_list:
564-
# Escape spaces and parentheses in the tar filename for use in the shell command
565-
tar = tar.replace(" ", "\ ").replace(
566-
"(", "\(").replace(")", "\)")
567-
# Construct the S3 path to the current tar file
568-
s3_path = posixpath.join(name, subset, tar) + " -"
569-
# Construct the AWS CLI command to download the current tar file
570-
if s3_url_prefix is None:
571-
request_str = f"pipe:aws s3 --cli-connect-timeout 0 cp {s3_path}"
572-
else:
573-
request_str = f"pipe:aws s3 --cli-connect-timeout 0 cp {posixpath.join(s3_url_prefix, s3_path)}"
574-
if profiles.get(name):
575-
request_str += f" --profile {profiles.get(name)}"
576-
if debug:
577-
print("request_str = ", request_str)
578-
# Add the constructed URL to the list of URLs
579-
urls.append(request_str)
580-
return urls
581-
582-
481+
# S3 helpers live in the import-light s3_utils module (no torch) so it can also
482+
# run as a `pipe:` subprocess that streams individual shards. Re-exported here
483+
# for backwards compatibility.
484+
from .s3_utils import ( # noqa: E402
485+
_build_s3_client,
486+
_build_user_agent_extra,
487+
_get_s3_client,
488+
_parse_s3_url,
489+
_user_agent,
490+
get_all_s3_urls,
491+
get_s3_contents,
492+
shard_pipe_command,
493+
stream_object,
494+
)
495+
496+
497+
# WDS preprocessing code based on implementation by Scott Hawley originally in https://github.com/zqevans/audio-diffusion/blob/main/dataset/dataset.py
583498
def log_and_continue(exn):
584499
"""Call in an exception handler to ignore any exception, isssue a warning, and continue."""
585500
print(f"Handling webdataset error ({repr(exn)}). Ignoring.")

0 commit comments

Comments
 (0)