How to input paired audio and video into the pipeline efficiently? #5935

@Ash-one

Description

Describe the question.

My use case is simple to describe: paired audio fbank features and video frames. But once the dataset grows to a huge number of samples, the loading phase takes too long.

I have tried several ways to load them from the HDD, given my limited resources. In all of the following cases the audio fbank features are saved as .npy files:

  1. Use DALI external source: load both the .mp4 and the fbank .npy with numpy, then decode the video on the GPU (a usage sketch follows the code below):
code
# Shared imports for the snippets below.
import json
import os
from glob import glob

import cv2
import numpy as np
from nvidia.dali import fn, pipeline_def, types


class DALIDatasetCallable:
    def __init__(
        self,
        ann_file,
        batch_size=1,
        shuffled=True,
        media_type='audio_video',
        shard_id=0,
        num_shards=1,
        **kwargs,
    ):
        self.media_type = media_type
        self.batch_size = batch_size
        with open(ann_file.anno_path, 'r') as f:
            self.label_file = json.load(f)
        self.data_root = ann_file.data_root
        
        self.indices = np.arange(len(self.label_file))
        
        if shuffled:
            np.random.shuffle(self.indices)
            self.label_file = [self.label_file[i] for i in self.indices]
        
        self.filenames = list(zip(self.label_file, self.indices))
        self.epoch = 0
        
        self.shard_id = shard_id
        self.num_shards = num_shards

        self.shard_size = len(self.label_file) // num_shards
        self.shard_offset = self.shard_size * shard_id

        self.full_iterations = self.shard_size // batch_size
        self.perm = None
        self.last_seen_epoch = None
        
        
    def __len__(self):
        return self.full_iterations

    def reset(self):
        self.perm = None
        self.last_seen_epoch = None
    
    def __call__(self, sample_info):
        # print(sample_info.epoch_idx, sample_info.iteration,sample_info.idx_in_epoch)
        if sample_info.iteration >= self.full_iterations:
            raise StopIteration
        if self.last_seen_epoch != sample_info.epoch_idx:
            self.last_seen_epoch = sample_info.epoch_idx
            rng = np.random.default_rng(seed=42 + sample_info.epoch_idx)
            self.perm = rng.permutation(len(self.filenames))
        sample_idx = self.perm[sample_info.idx_in_epoch + self.shard_offset]

        sample, index = self.filenames[sample_idx]
        
        vfilename, afilename, frame_count = os.path.join(self.data_root,sample['video']), \
                                            os.path.join(self.data_root,sample['audio']), \
                                            sample['num_frames']
                                            
        video = np.fromfile(vfilename, dtype=np.uint8)
        fbank = np.load(afilename)
        frame_idxs = np.array(get_frame_indices(16,frame_count),dtype=np.int32)
        index = np.array([np.int32(index)])
        
        return video, fbank, frame_idxs, index

@pipeline_def(py_num_workers=4, py_start_method="spawn")
def AudioVideoPipeline(eii, parallel):
    vid, audio, frame_idxs, index = fn.external_source(device="cpu", batch=False, 
                                                        num_outputs=4, source=eii, parallel=parallel,
                                                        prefetch_queue_depth=4)
    video = fn.experimental.decoders.video(vid, device="mixed", frames=frame_idxs)
    resize_v = fn.resize(video,resize_shorter=224, device='gpu')
    crop_v = fn.crop_mirror_normalize(resize_v,crop=(224,224),device='gpu',
                                        mean=[0.485 * 255, 0.456 * 255, 0.406 * 255],
                                        std=[0.229 * 255, 0.224 * 255, 0.225 * 255],
                                        output_layout="FCHW")
    
    return audio, crop_v, index
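
For reference, this is roughly how I build and run the pipeline in 1. with the PyTorch iterator (a minimal sketch; the DALIGenericIterator usage and the batch_size/num_threads/device_id values are illustrative, and ann_file is the same annotation object as above):
code
from nvidia.dali.plugin.pytorch import DALIGenericIterator

# Build the pipeline around the callable source and wrap it for PyTorch.
eii = DALIDatasetCallable(ann_file, batch_size=4, shard_id=0, num_shards=1)
pipe = AudioVideoPipeline(eii, parallel=True,
                          batch_size=4, num_threads=4, device_id=0)
pipe.build()  # also spawns the py_num_workers external-source workers

loader = DALIGenericIterator([pipe], ["fbank", "video", "index"],
                             size=len(eii) * 4)  # iterations * batch_size
for batch in loader:
    fbank = batch[0]["fbank"]  # paired audio features
    video = batch[0]["video"]  # normalized FCHW frames, already on GPU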
  2. Tar all the video .mp4 files and fbank .npy files together, then use the DALI webdataset reader to load them (a sketch of the shard layout follows the code below):
code
def get_indices(num_frames, vlen):
    indices = np.array(get_frame_indices(num_frames, vlen),dtype=np.int32)
    return indices
    
@pipeline_def(batch_size=4, num_threads=4)
def WebDatasetAVPipeline(wds_data, index_paths, shard_id, num_shards):
    vid, fbank, vlen = fn.readers.webdataset(paths=wds_data, 
                                             index_paths=index_paths,
                                             dtypes=[types.UINT8, types.FLOAT, types.INT16], 
                                             ext=["mp4","fbank","npy"], 
                                             num_shards=num_shards,
                                             shard_id=shard_id,
                                             random_shuffle=False,
                                             prefetch_queue_depth=4,
                                             name="reader",
                                             missing_component_behavior="error")
    
    indices = fn.python_function(16, vlen[0], function=get_indices, num_outputs=1)
    
    fbank = fn.reshape(fbank, shape=(1024, 128))
    video = fn.experimental.decoders.video(vid, device="mixed", frames=indices)
    resize_v = fn.resize(video,resize_shorter=224, device='gpu')
    crop_v = fn.crop_mirror_normalize(resize_v,crop=(224,224),device='gpu',
                                        mean=[0.485 * 255, 0.456 * 255, 0.406 * 255],
                                        std=[0.229 * 255, 0.224 * 255, 0.225 * 255],
                                        output_layout="FCHW")
    index = types.Constant(0,dtype=types.UINT8)
    return fbank, crop_v, index
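
For context on 2., each shard is just a tar in which the components of one sample share a basename and differ only by extension, matching ext= in the reader above; the index_paths files are then generated with DALI's wds2idx utility. A minimal sketch of building one shard with the standard tarfile module (samples is a hypothetical dict mapping a key to the three component paths):
code
import tarfile

# Components with a common basename ("key") are grouped into one sample
# by the webdataset reader.
with tarfile.open("shard_000.tar", "w") as tar:
    for key, (mp4_path, fbank_path, vlen_path) in samples.items():
        tar.add(mp4_path, arcname=f"{key}.mp4")      # encoded video bytes
        tar.add(fbank_path, arcname=f"{key}.fbank")  # raw fbank feature bytes
        tar.add(vlen_path, arcname=f"{key}.npy")     # frame-count component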
  3. Extract 16 frames from each .mp4 into .jpg files offline, then use DALI external source to load the paired frames (via OpenCV) and fbank features (via numpy); an extraction sketch follows the code below:
code
class DALIFrameDatasetCallable:
    def __init__(
        self,
        ann_file,
        batch_size=1,
        shuffled=True,
        media_type='audio_video',
        shard_id=0,
        num_shards=1,
        **kwargs,
    ):
        self.media_type = media_type
        self.batch_size = batch_size
        with open(ann_file.anno_path, 'r') as f:
            self.label_file = json.load(f)
        self.data_root = ann_file.data_root
        assert "frames" in self.label_file[0], "Please make sure the label file has frames"
        
        self.indices = np.arange(len(self.label_file))
        
        if shuffled:
            np.random.shuffle(self.indices)
            self.label_file = [self.label_file[i] for i in self.indices]
        
        self.filenames = list(zip(self.label_file, self.indices))
        self.epoch = 0

        self.shard_id = shard_id
        self.num_shards = num_shards

        self.shard_size = len(self.label_file) // num_shards
        self.shard_offset = self.shard_size * shard_id

        self.full_iterations = self.shard_size // batch_size
        self.perm = None
        self.last_seen_epoch = None
        
        
    def __len__(self):
        return self.full_iterations

    def reset(self):
        self.perm = None
        self.last_seen_epoch = None
    
    @staticmethod
    def load_video_frames(path, frame_count):
        jpg_paths = sorted(glob(os.path.join(path, "*.jpg")))
        # load with cv2
        frames = [cv2.imread(jpg_path)[:, :, ::-1] for jpg_path in jpg_paths[:frame_count]]
        return np.stack(frames)
    
    def __call__(self, sample_info):
        # print(sample_info.epoch_idx, sample_info.iteration,sample_info.idx_in_epoch)
        if sample_info.iteration >= self.full_iterations:
            # Indicate end of the epoch
            raise StopIteration
        if self.last_seen_epoch != sample_info.epoch_idx:
            self.last_seen_epoch = sample_info.epoch_idx
            rng = np.random.default_rng(seed=42 + sample_info.epoch_idx)
            self.perm = rng.permutation(len(self.filenames))
        sample_idx = self.perm[sample_info.idx_in_epoch + self.shard_offset]

        sample, index = self.filenames[sample_idx]
        
        vfilename, afilename = os.path.join(self.data_root,sample['frames']), \
                                            os.path.join(self.data_root,sample['audio'])
                                            
        video = self.load_video_frames(vfilename, 16)
        fbank = np.load(afilename,allow_pickle=True)
        index = np.array([np.int32(index)])
        
        return video, fbank, index

@pipeline_def(py_num_workers=4, py_start_method="spawn")
def AudioFramesPipeline(eii, parallel):
    vid, audio, index = fn.external_source(device="cpu", batch=False, 
                                                        num_outputs=3, source=eii, parallel=parallel,
                                                        prefetch_queue_depth=4)
    resize_v = fn.resize(vid.gpu(), resize_shorter=224, device='gpu')
    crop_v = fn.crop_mirror_normalize(resize_v,crop=(224,224),device='gpu',
                                        mean=[0.485 * 255, 0.456 * 255, 0.406 * 255],
                                        std=[0.229 * 255, 0.224 * 255, 0.225 * 255],
                                        output_layout="FCHW")
    
    return audio, crop_v, index
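
The offline frame extraction for 3. was done along these lines (a sketch using OpenCV; get_frame_indices is the same sampling helper used in the snippets above):
code
import os
import cv2

def extract_frames(mp4_path, out_dir, num_frames=16):
    """Dump num_frames sampled frames of a video as numbered .jpg files."""
    os.makedirs(out_dir, exist_ok=True)
    cap = cv2.VideoCapture(mp4_path)
    total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    wanted = set(get_frame_indices(num_frames, total))
    idx = saved = 0
    while True:
        ok, frame = cap.read()
        if not ok:
            break
        if idx in wanted:
            cv2.imwrite(os.path.join(out_dir, f"{saved:04d}.jpg"), frame)
            saved += 1
        idx += 1
    cap.release()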

There is another approach I have not tested: feeding file lists to the pipeline, but I do not know DALI's mechanism here. When using two file lists in the same order (for example, one file containing the paths of the frames and another containing the paths of the .npy files, with the paths at the same index position in both lists forming a paired video and audio), does the reader that loads the videos output samples at the same positions as the reader that loads the .npy files?
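
A sketch of what I mean (illustrative only, using whole videos instead of frame folders for simplicity; whether the two readers stay index-aligned is exactly my question):
code
# Two readers in one pipeline, each fed an identically ordered list of paths.
# Is sample i from "video_reader" guaranteed to pair with sample i from
# "audio_reader" when shuffling is disabled?
@pipeline_def(batch_size=4, num_threads=4, device_id=0)
def PairedListPipeline(video_files, npy_files):
    vid_bytes, _ = fn.readers.file(files=video_files, random_shuffle=False,
                                   name="video_reader")
    fbank = fn.readers.numpy(files=npy_files, random_shuffle=False,
                             name="audio_reader")
    frames = fn.experimental.decoders.video(vid_bytes, device="mixed")
    return frames, fbank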

But all of the above methods showed poor performance... I have no idea how to accelerate the loading phase during training 😮‍💨. Right now, less than 10% of the time is spent on GPU training...

I am unable to upgrade the server's hard drive to an SSD, so I would like to know how to feed paired audio and video into the pipeline more efficiently.

Looking forward to your reply, thanks!

Check for duplicates

  • I have searched the open bugs/issues and have found no duplicates for this bug report
