|
1 | 1 | import importlib |
2 | | -import numpy as np |
3 | 2 | import io |
4 | 3 | import json |
5 | 4 | import os |
6 | | -import dill |
7 | | -import posixpath |
8 | 5 | import random |
9 | | -import re |
10 | | -import subprocess |
11 | 6 | import time |
| 7 | +from os import path |
| 8 | +from typing import Callable, List, Optional |
| 9 | + |
| 10 | +import dill |
| 11 | +import numpy as np |
12 | 12 | import torch |
13 | 13 | import torchaudio |
14 | 14 | import webdataset as wds |
15 | | - |
16 | | -from os import path |
17 | 15 | from torch import nn |
18 | 16 | from torchaudio import transforms as T |
19 | | -from typing import Optional, Callable, List |
20 | 17 |
|
21 | | -from .utils import Stereo, Mono, PhaseFlipper, PadCrop_Normalized_T, VolumeNorm, strip_trailing_silence |
| 18 | +from .utils import Mono, PadCrop_Normalized_T, PhaseFlipper, Stereo, VolumeNorm, strip_trailing_silence |
22 | 19 |
|
23 | 20 | AUDIO_KEYS = ("flac", "wav", "mp3", "m4a", "ogg", "opus") |
24 | 21 |
|
@@ -481,105 +478,23 @@ def __getitem__(self, idx): |
481 | 478 | print(f'Couldn\'t load file {latent_filename}: {e}') |
482 | 479 | return self[random.randrange(len(self))] |
483 | 480 |
|
484 | | -# S3 code and WDS preprocessing code based on implementation by Scott Hawley originally in https://github.com/zqevans/audio-diffusion/blob/main/dataset/dataset.py |
485 | | - |
486 | | -def get_s3_contents(dataset_path, s3_url_prefix=None, filter='', recursive=True, debug=False, profile=None): |
487 | | - """ |
488 | | - Returns a list of full S3 paths to files in a given S3 bucket and directory path. |
489 | | - """ |
490 | | - # Ensure dataset_path ends with a trailing slash |
491 | | - if dataset_path != '' and not dataset_path.endswith('/'): |
492 | | - dataset_path += '/' |
493 | | - # Use posixpath to construct the S3 URL path |
494 | | - bucket_path = posixpath.join(s3_url_prefix or '', dataset_path) |
495 | | - # Construct the `aws s3 ls` command |
496 | | - cmd = ['aws', 's3', 'ls', bucket_path] |
497 | | - |
498 | | - if profile is not None: |
499 | | - cmd.extend(['--profile', profile]) |
500 | | - |
501 | | - if recursive: |
502 | | - # Add the --recursive flag if requested |
503 | | - cmd.append('--recursive') |
504 | | - |
505 | | - # Run the `aws s3 ls` command and capture the output |
506 | | - run_ls = subprocess.run(cmd, capture_output=True, check=True) |
507 | | - # Split the output into lines and strip whitespace from each line |
508 | | - contents = run_ls.stdout.decode('utf-8').split('\n') |
509 | | - contents = [x.strip() for x in contents if x] |
510 | | - # Remove the timestamp from lines that begin with a timestamp |
511 | | - contents = [re.sub(r'^\S+\s+\S+\s+\d+\s+', '', x) |
512 | | - if re.match(r'^\S+\s+\S+\s+\d+\s+', x) else x for x in contents] |
513 | | - # Construct a full S3 path for each file in the contents list |
514 | | - contents = [posixpath.join(s3_url_prefix or '', x) |
515 | | - for x in contents if not x.endswith('/')] |
516 | | - # Apply the filter, if specified |
517 | | - if filter: |
518 | | - contents = [x for x in contents if filter in x] |
519 | | - # Remove redundant directory names in the S3 URL |
520 | | - if recursive: |
521 | | - # Get the main directory name from the S3 URL |
522 | | - main_dir = "/".join(bucket_path.split('/')[3:]) |
523 | | - # Remove the redundant directory names from each file path |
524 | | - contents = [x.replace(f'{main_dir}', '').replace( |
525 | | - '//', '/') for x in contents] |
526 | | - # Print debugging information, if requested |
527 | | - if debug: |
528 | | - print("contents = \n", contents) |
529 | | - # Return the list of S3 paths to files |
530 | | - return contents |
531 | | - |
532 | | - |
533 | | -def get_all_s3_urls( |
534 | | - names=[], # list of all valid [LAION AudioDataset] dataset names |
535 | | - # list of subsets you want from those datasets, e.g. ['train','valid'] |
536 | | - subsets=[''], |
537 | | - s3_url_prefix=None, # prefix for those dataset names |
538 | | - recursive=True, # recursively list all tar files in all subdirs |
539 | | - filter_str='tar', # only grab files with this substring |
540 | | - # print debugging info -- note: info displayed likely to change at dev's whims |
541 | | - debug=False, |
542 | | - profiles={}, # dictionary of profiles for each item in names, e.g. {'dataset1': 'profile1', 'dataset2': 'profile2'} |
543 | | -): |
544 | | - "get urls of shards (tar files) for multiple datasets in one s3 bucket" |
545 | | - urls = [] |
546 | | - for name in names: |
547 | | - # If s3_url_prefix is not specified, assume the full S3 path is included in each element of the names list |
548 | | - if s3_url_prefix is None: |
549 | | - contents_str = name |
550 | | - else: |
551 | | - # Construct the S3 path using the s3_url_prefix and the current name value |
552 | | - contents_str = posixpath.join(s3_url_prefix, name) |
553 | | - if debug: |
554 | | - print(f"get_all_s3_urls: {contents_str}:") |
555 | | - for subset in subsets: |
556 | | - subset_str = posixpath.join(contents_str, subset) |
557 | | - if debug: |
558 | | - print(f"subset_str = {subset_str}") |
559 | | - # Get the list of tar files in the current subset directory |
560 | | - profile = profiles.get(name, None) |
561 | | - tar_list = get_s3_contents( |
562 | | - subset_str, s3_url_prefix=None, recursive=recursive, filter=filter_str, debug=debug, profile=profile) |
563 | | - for tar in tar_list: |
564 | | - # Escape spaces and parentheses in the tar filename for use in the shell command |
565 | | - tar = tar.replace(" ", "\ ").replace( |
566 | | - "(", "\(").replace(")", "\)") |
567 | | - # Construct the S3 path to the current tar file |
568 | | - s3_path = posixpath.join(name, subset, tar) + " -" |
569 | | - # Construct the AWS CLI command to download the current tar file |
570 | | - if s3_url_prefix is None: |
571 | | - request_str = f"pipe:aws s3 --cli-connect-timeout 0 cp {s3_path}" |
572 | | - else: |
573 | | - request_str = f"pipe:aws s3 --cli-connect-timeout 0 cp {posixpath.join(s3_url_prefix, s3_path)}" |
574 | | - if profiles.get(name): |
575 | | - request_str += f" --profile {profiles.get(name)}" |
576 | | - if debug: |
577 | | - print("request_str = ", request_str) |
578 | | - # Add the constructed URL to the list of URLs |
579 | | - urls.append(request_str) |
580 | | - return urls |
581 | | - |
582 | | - |
| 481 | +# S3 helpers live in the import-light s3_utils module (no torch) so it can also |
| 482 | +# run as a `pipe:` subprocess that streams individual shards. Re-exported here |
| 483 | +# for backwards compatibility. |
| 484 | +from .s3_utils import ( # noqa: E402 |
| 485 | + _build_s3_client, |
| 486 | + _build_user_agent_extra, |
| 487 | + _get_s3_client, |
| 488 | + _parse_s3_url, |
| 489 | + _user_agent, |
| 490 | + get_all_s3_urls, |
| 491 | + get_s3_contents, |
| 492 | + shard_pipe_command, |
| 493 | + stream_object, |
| 494 | +) |
| 495 | + |
| 496 | + |
| 497 | +# WDS preprocessing code based on implementation by Scott Hawley originally in https://github.com/zqevans/audio-diffusion/blob/main/dataset/dataset.py |
583 | 498 | def log_and_continue(exn): |
584 | 499 | """Call in an exception handler to ignore any exception, isssue a warning, and continue.""" |
585 | 500 | print(f"Handling webdataset error ({repr(exn)}). Ignoring.") |
|
0 commit comments