11import importlib
2- import numpy as np
32import io
43import json
54import os
6- import dill
75import posixpath
86import random
9- import re
10- import subprocess
7+ import shlex
118import time
9+ from functools import lru_cache
10+ from importlib .metadata import PackageNotFoundError , version
11+ from os import path
12+ from typing import Callable , List , Optional
13+
14+ import dill
15+ import numpy as np
1216import torch
1317import torchaudio
1418import webdataset as wds
15-
16- from os import path
1719from torch import nn
1820from torchaudio import transforms as T
19- from typing import Optional , Callable , List
2021
21- from .utils import Stereo , Mono , PhaseFlipper , PadCrop_Normalized_T , VolumeNorm , strip_trailing_silence
22+ from .utils import Mono , PadCrop_Normalized_T , PhaseFlipper , Stereo , VolumeNorm , strip_trailing_silence
2223
2324AUDIO_KEYS = ("flac" , "wav" , "mp3" , "m4a" , "ogg" , "opus" )
2425
@@ -483,65 +484,176 @@ def __getitem__(self, idx):
483484
484485# S3 code and WDS preprocessing code based on implementation by Scott Hawley originally in https://github.com/zqevans/audio-diffusion/blob/main/dataset/dataset.py
485486
486- def get_s3_contents (dataset_path , s3_url_prefix = None , filter = '' , recursive = True , debug = False , profile = None ):
487+ @lru_cache (maxsize = 1 )
488+ def _user_agent ():
489+ "Base ``stable-audio-tools/<version>`` token, looked up once on first use."
490+ try :
491+ ver = version ("stable-audio-tools" )
492+ except PackageNotFoundError : # source/editable checkout without dist metadata
493+ ver = "dev"
494+ return f"stable-audio-tools/{ ver } "
495+
496+
497+ def _build_user_agent_extra (user_agent_extra = None ):
498+ """``stable-audio-tools/<version>`` with any caller- or env-provided
499+ (``STABLE_AUDIO_TOOLS_USER_AGENT_EXTRA``) value appended, not replacing it.
500+ Pass ``user_agent_extra=""`` to suppress the env value and use the base only."""
501+ base = _user_agent ()
502+ if user_agent_extra is None :
503+ user_agent_extra = os .environ .get ("STABLE_AUDIO_TOOLS_USER_AGENT_EXTRA" )
504+ return f"{ base } { user_agent_extra } " if user_agent_extra else base
505+
506+
507+ @lru_cache (maxsize = 32 )
508+ def _build_s3_client (profile , endpoint_url , user_agent_extra ):
509+ try :
510+ import boto3 # local import so boto3 is only required when S3 is used
511+ from botocore .config import Config
512+ except ModuleNotFoundError as e :
513+ raise ImportError (
514+ "S3 dataset access requires boto3. Install it with "
515+ "'pip install boto3' or 'pip install stable-audio-tools[train]'."
516+ ) from e
517+
518+ session = boto3 .Session (profile_name = profile ) if profile else boto3 .Session ()
519+ return session .client (
520+ "s3" ,
521+ endpoint_url = endpoint_url ,
522+ config = Config (user_agent_extra = user_agent_extra ),
523+ )
524+
525+
526+ def _get_s3_client (profile = None , user_agent_extra = None ):
527+ """
528+ Build (and reuse) a boto3 S3 client. Honors AWS_ENDPOINT_URL when set so the
529+ same code path works against any S3-compatible endpoint (AWS S3 by default;
530+ set AWS_ENDPOINT_URL to a Backblaze B2 endpoint to point it at B2). When the
531+ env var is unset, behavior matches the default AWS client.
532+
533+ Clients are cached per (profile, endpoint, user-agent) so listing and
534+ presigning share one client instead of building a new one on each call.
487535 """
488- Returns a list of full S3 paths to files in a given S3 bucket and directory path.
536+ endpoint_url = os .environ .get ("AWS_ENDPOINT_URL" ) or None
537+ return _build_s3_client (
538+ profile , endpoint_url , _build_user_agent_extra (user_agent_extra )
539+ )
540+
541+
542+ def _parse_s3_url (url ):
543+ "Split an ``s3://bucket/key`` URL into (bucket, key). Raises ValueError otherwise."
544+ if not url .startswith ("s3://" ):
545+ raise ValueError (f"expected an s3:// URL, got: { url !r} " )
546+ bucket , _ , key = url [len ("s3://" ):].partition ("/" )
547+ if not bucket .strip ():
548+ raise ValueError (f"s3:// URL is missing a bucket name: { url !r} " )
549+ return bucket , key
550+
551+
552+ def get_s3_contents (
553+ dataset_path ,
554+ s3_url_prefix = None ,
555+ filter = '' , # deprecated alias for filter_str; kept for backwards compat
556+ recursive = True ,
557+ debug = False ,
558+ profile = None ,
559+ relative = False ,
560+ filter_str = None , # preferred substring filter
561+ ):
489562 """
563+ Returns a list of objects in a given S3 bucket and directory path.
564+
565+ By default returns full ``s3://bucket/key`` paths (backwards compatible with
566+ the previous implementation). Pass ``relative=True`` to get keys relative to
567+ ``dataset_path`` instead. Uses boto3 directly so it works against any
568+ S3-compatible endpoint when ``AWS_ENDPOINT_URL`` is set.
569+
570+ ``filter_str`` is the preferred substring-filter argument; ``filter`` is kept
571+ as a backwards-compatible alias.
572+ """
573+ if filter_str is None :
574+ filter_str = filter
490575 # Ensure dataset_path ends with a trailing slash
491576 if dataset_path != '' and not dataset_path .endswith ('/' ):
492577 dataset_path += '/'
493- # Use posixpath to construct the S3 URL path
578+ # Use posixpath to construct the S3 URL path (e.g. "s3://bucket/prefix/")
494579 bucket_path = posixpath .join (s3_url_prefix or '' , dataset_path )
495- # Construct the `aws s3 ls` command
496- cmd = ['aws' , 's3' , 'ls' , bucket_path ]
497580
498- if profile is not None :
499- cmd .extend (['--profile' , profile ])
581+ bucket , prefix = _parse_s3_url (bucket_path )
582+
583+ s3 = _get_s3_client (profile = profile )
584+ paginator = s3 .get_paginator ("list_objects_v2" )
585+ list_kwargs = {"Bucket" : bucket , "Prefix" : prefix }
586+ if not recursive :
587+ list_kwargs ["Delimiter" ] = "/"
588+
589+ keys = []
590+ for page in paginator .paginate (** list_kwargs ):
591+ for obj in page .get ("Contents" , []) or []:
592+ key = obj .get ("Key" , "" )
593+ if not key or key .endswith ("/" ):
594+ continue
595+ keys .append (key )
500596
501- if recursive :
502- # Add the --recursive flag if requested
503- cmd .append ('--recursive' )
504-
505- # Run the `aws s3 ls` command and capture the output
506- run_ls = subprocess .run (cmd , capture_output = True , check = True )
507- # Split the output into lines and strip whitespace from each line
508- contents = run_ls .stdout .decode ('utf-8' ).split ('\n ' )
509- contents = [x .strip () for x in contents if x ]
510- # Remove the timestamp from lines that begin with a timestamp
511- contents = [re .sub (r'^\S+\s+\S+\s+\d+\s+' , '' , x )
512- if re .match (r'^\S+\s+\S+\s+\d+\s+' , x ) else x for x in contents ]
513- # Construct a full S3 path for each file in the contents list
514- contents = [posixpath .join (s3_url_prefix or '' , x )
515- for x in contents if not x .endswith ('/' )]
516597 # Apply the filter, if specified
517- if filter :
518- contents = [x for x in contents if filter in x ]
519- # Remove redundant directory names in the S3 URL
520- if recursive :
521- # Get the main directory name from the S3 URL
522- main_dir = "/" .join (bucket_path .split ('/' )[3 :])
523- # Remove the redundant directory names from each file path
524- contents = [x .replace (f'{ main_dir } ' , '' ).replace (
525- '//' , '/' ) for x in contents ]
526- # Print debugging information, if requested
598+ if filter_str :
599+ keys = [k for k in keys if filter_str in k ]
600+
601+ if relative :
602+ # Strip the bucket-level prefix so keys are relative to dataset_path.
603+ if prefix :
604+ keys = [k [len (prefix ):] if k .startswith (prefix ) else k for k in keys ]
605+ keys = [k .lstrip ('/' ) for k in keys ]
606+ contents = keys
607+ else :
608+ # Backwards-compatible default: full s3://bucket/key paths.
609+ contents = [f"s3://{ bucket } /{ k } " for k in keys ]
610+
527611 if debug :
528612 print ("contents = \n " , contents )
529- # Return the list of S3 paths to files
613+
530614 return contents
531615
532616
617+ # 7 days (SigV4 max) so shard URLs outlast long training runs. Override per
618+ # call or via STABLE_AUDIO_TOOLS_S3_PRESIGN_EXPIRY.
619+ _DEFAULT_PRESIGN_EXPIRY_SECONDS = 7 * 24 * 3600
620+
621+
533622def get_all_s3_urls (
534- names = [], # list of all valid [LAION AudioDataset] dataset names
535- # list of subsets you want from those datasets, e.g. ['train','valid']
536- subsets = ['' ],
623+ names = None , # list of [LAION AudioDataset] dataset names; None -> []
624+ subsets = None , # list of subsets, e.g. ['train','valid']; None -> ['']
537625 s3_url_prefix = None , # prefix for those dataset names
538626 recursive = True , # recursively list all tar files in all subdirs
539627 filter_str = 'tar' , # only grab files with this substring
540628 # print debugging info -- note: info displayed likely to change at dev's whims
541629 debug = False ,
542- profiles = {}, # dictionary of profiles for each item in names, e.g. {'dataset1': 'profile1', 'dataset2': 'profile2'}
630+ profiles = None , # dict of profiles per name, e.g. {'dataset1': 'profile1'}; None -> {}
631+ presign_expiry_seconds = None , # presigned-URL lifetime; None -> env var or default
543632):
544- "get urls of shards (tar files) for multiple datasets in one s3 bucket"
633+ """Get urls of shards (tar files) for multiple datasets in one s3 bucket.
634+
635+ Shards are fetched via presigned URLs handed to ``curl``. The URL carries
636+ short-lived credentials on the command line, so on shared/multi-tenant hosts
637+ it can be visible to other local users (e.g. via ``ps``); keep
638+ ``presign_expiry_seconds`` short in those environments.
639+ """
640+ names = [] if names is None else names
641+ subsets = ['' ] if subsets is None else subsets
642+ profiles = profiles or {}
643+ if presign_expiry_seconds is None :
644+ presign_expiry_seconds = os .environ .get (
645+ "STABLE_AUDIO_TOOLS_S3_PRESIGN_EXPIRY" , _DEFAULT_PRESIGN_EXPIRY_SECONDS )
646+ try :
647+ presign_expiry_seconds = int (presign_expiry_seconds )
648+ except (TypeError , ValueError ):
649+ raise ValueError (
650+ "presign_expiry_seconds (or STABLE_AUDIO_TOOLS_S3_PRESIGN_EXPIRY) must be "
651+ f"an integer number of seconds, got: { presign_expiry_seconds !r} "
652+ )
653+ if presign_expiry_seconds <= 0 :
654+ raise ValueError (
655+ f"presign_expiry_seconds must be positive, got: { presign_expiry_seconds } "
656+ )
545657 urls = []
546658 for name in names :
547659 # If s3_url_prefix is not specified, assume the full S3 path is included in each element of the names list
@@ -559,23 +671,33 @@ def get_all_s3_urls(
559671 # Get the list of tar files in the current subset directory
560672 profile = profiles .get (name , None )
561673 tar_list = get_s3_contents (
562- subset_str , s3_url_prefix = None , recursive = recursive , filter = filter_str , debug = debug , profile = profile )
674+ subset_str , s3_url_prefix = None , recursive = recursive , filter_str = filter_str , debug = debug , profile = profile , relative = True )
675+ # Reuse the cached S3 client (shared with get_s3_contents) for presigning.
676+ s3_client = _get_s3_client (profile = profile )
563677 for tar in tar_list :
564- # Escape spaces and parentheses in the tar filename for use in the shell command
565- tar = tar .replace (" " , "\ " ).replace (
566- "(" , "\(" ).replace (")" , "\)" )
567- # Construct the S3 path to the current tar file
568- s3_path = posixpath .join (name , subset , tar ) + " -"
569- # Construct the AWS CLI command to download the current tar file
678+ # Construct the full s3:// URL for the current tar file.
570679 if s3_url_prefix is None :
571- request_str = f"pipe:aws s3 --cli-connect-timeout 0 cp { s3_path } "
680+ full_s3_url = posixpath . join ( name , subset , tar )
572681 else :
573- request_str = f"pipe:aws s3 --cli-connect-timeout 0 cp { posixpath .join (s3_url_prefix , s3_path )} "
574- if profiles .get (name ):
575- request_str += f" --profile { profiles .get (name )} "
682+ full_s3_url = posixpath .join (s3_url_prefix , name , subset , tar )
683+
684+ bucket , key = _parse_s3_url (full_s3_url )
685+
686+ # Presigned GET URL works against AWS and any S3-compatible
687+ # endpoint when AWS_ENDPOINT_URL is set. Expiry is configurable
688+ # so long training runs do not outlive their shard URLs.
689+ presigned = s3_client .generate_presigned_url (
690+ "get_object" ,
691+ Params = {"Bucket" : bucket , "Key" : key },
692+ ExpiresIn = presign_expiry_seconds ,
693+ )
694+ # --retry restores the transient-failure resilience the AWS CLI
695+ # had; shlex.quote keeps URL contents safe in the pipe: shell command.
696+ request_str = f"pipe:curl -fsSL --retry 5 { shlex .quote (presigned )} "
576697 if debug :
577- print ("request_str = " , request_str )
578- # Add the constructed URL to the list of URLs
698+ # Strip the signed query string so signatures are not logged.
699+ redacted = presigned .split ("?" , 1 )[0 ] + "?<redacted>"
700+ print ("request_str = pipe:curl -fsSL --retry 5" , shlex .quote (redacted ))
579701 urls .append (request_str )
580702 return urls
581703
0 commit comments