@@ -109,7 +109,11 @@ def get_collate_fn(
109109 wav_dir : Optional [str ] = None ,
110110 frames_per_seg : Optional [int ] = None ,
111111 ext_audio : str = ".wav" ,
112+ predict_duration : bool = False ,
112113):
114+ if predict_duration :
115+ assert frames_per_seg is None
116+
113117 def parse_item (item : Dict [str , Any ]):
114118 input_ids = item ["units" ] + 1 # 0: pad
115119 spectrogram_labels = item ["spectrogram" ]
@@ -122,6 +126,11 @@ def parse_item(item: Dict[str, Any]):
122126 wav , sr = torchaudio .load (wav_path )
123127 wav = wav .squeeze (0 )
124128
129+ if predict_duration :
130+ input_ids , durations = torch .unique_consecutive (input_ids , return_counts = True )
131+ else :
132+ durations = torch .ones_like (input_ids )
133+
125134 if frames_per_seg is not None :
126135 diff = len (input_ids ) - frames_per_seg
127136
@@ -130,30 +139,34 @@ def parse_item(item: Dict[str, Any]):
130139 input_ids = input_ids [start : start + frames_per_seg ]
131140 spectrogram_labels = spectrogram_labels [start : start + frames_per_seg ]
132141
133- return input_ids , spectrogram_labels , transcript , id , wav
142+ return input_ids , spectrogram_labels , durations , transcript , id , wav
134143
135144 def collate_fn (batch ):
136145 input_ids = []
137146 spectrogram_labels = []
147+ duration_labels = []
138148 transcripts = []
139149 names = []
140150 input_values = []
141151
142152 for item in batch :
143- units , spectrogram , transcript , id , wav = parse_item (item )
153+ units , spectrogram , durations , transcript , id , wav = parse_item (item )
144154 input_ids .append (units )
145155 spectrogram_labels .append (spectrogram )
156+ duration_labels .append (durations )
146157 transcripts .append (transcript )
147158 names .append (id )
148159 input_values .append (wav )
149160
150161 input_ids = pad_sequence (input_ids , batch_first = True )
151162 spectrogram_labels = pad_sequence (spectrogram_labels , batch_first = True , padding_value = - 100 )
163+ duration_labels = pad_sequence (duration_labels , batch_first = True )
152164 input_values = pad_sequence (input_values , batch_first = True )
153165
154166 return {
155167 "input_ids" : input_ids ,
156168 "spectrogram_labels" : spectrogram_labels ,
169+ "duration_labels" : duration_labels ,
157170 "transcripts" : transcripts ,
158171 "names" : names ,
159172 "input_values" : input_values ,
0 commit comments