89 | 89 | # parts_to_edit = [[0.84, 1.4], [1.92, 2.4], [4.26, 6.26], ] |
90 | 90 | # fix_duration = None # use origin text duration |
91 | 91 |
| 92 | +# audio_to_edit = "src/f5_tts/infer/examples/basic/basic_ref_zh.wav" |
| 93 | +# origin_text = "对,这就是我,万人敬仰的太乙真人。" |
| 94 | +# target_text = "对,这就是你,万人敬仰的李白金星。" |
| 95 | +# parts_to_edit = [[1.500, 2.784], [4.083, 6.760]] |
| 96 | +# fix_duration = [1.284, 2.677] |
| 97 | + |
92 | 98 |
93 | 99 | # -------------------------------------------------# |
94 | 100 |
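Here `parts_to_edit` lists `[start, end]` spans in seconds inside the original recording, and `fix_duration` (when set) gives the target length in seconds for each regenerated span. The frame-level rewrite below converts these seconds into mel frames; a minimal sketch of that conversion, assuming the usual F5-TTS defaults of a 24 kHz `target_sample_rate` and a `hop_length` of 256 (adjust if your config differs):

```python
# Sketch only: seconds -> mel frames under assumed defaults.
target_sample_rate = 24000  # assumed
hop_length = 256            # assumed

def sec_to_frames(t_sec: float) -> int:
    return round(t_sec * target_sample_rate / hop_length)

print(sec_to_frames(1.500))  # 141 -> start frame of the first edited span
print(sec_to_frames(1.284))  # 120 -> frames allotted by fix_duration for that span
```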
138 | 144 | if sr != target_sample_rate: |
139 | 145 | resampler = torchaudio.transforms.Resample(sr, target_sample_rate) |
140 | 146 | audio = resampler(audio) |
141 | | -offset = 0 |
142 | | -audio_ = torch.zeros(1, 0) |
143 | | -edit_mask = torch.zeros(1, 0, dtype=torch.bool) |
| 147 | + |
| 148 | +# Convert to mel spectrogram FIRST (on clean original audio) |
| 149 | +# This avoids boundary artifacts from mel windows straddling zeros and real audio |
| 150 | +audio = audio.to(device) |
| 151 | +with torch.inference_mode(): |
| 152 | + original_mel = model.mel_spec(audio) # (batch, n_mel, n_frames) |
| 153 | + original_mel = original_mel.permute(0, 2, 1) # (batch, n_frames, n_mel) |
| 154 | + |
| 155 | +# Build mel_cond and edit_mask at FRAME level |
| 156 | +# Insert zero frames in mel domain instead of zero samples in wav domain |
| 157 | +offset_frame = 0 |
| 158 | +mel_cond = torch.zeros(1, 0, n_mel_channels, device=device) |
| 159 | +edit_mask = torch.zeros(1, 0, dtype=torch.bool, device=device) |
| 160 | +fix_dur_list = fix_duration.copy() if fix_duration is not None else None |
| 161 | + |
144 | 162 | for part in parts_to_edit: |
145 | 163 | start, end = part |
146 | | - part_dur = end - start if fix_duration is None else fix_duration.pop(0) |
147 | | - part_dur = part_dur * target_sample_rate |
148 | | - start = start * target_sample_rate |
149 | | - audio_ = torch.cat((audio_, audio[:, round(offset) : round(start)], torch.zeros(1, round(part_dur))), dim=-1) |
| 164 | + part_dur_sec = end - start if fix_dur_list is None else fix_dur_list.pop(0) |
| 165 | + |
| 166 | + # Convert to frames (this is the authoritative unit) |
| 167 | + start_frame = round(start * target_sample_rate / hop_length) |
| 168 | + end_frame = round(end * target_sample_rate / hop_length) |
| 169 | + part_dur_frames = round(part_dur_sec * target_sample_rate / hop_length) |
| 170 | + |
| 171 | + # Number of frames for the kept (non-edited) region |
| 172 | + keep_frames = start_frame - offset_frame |
| 173 | + |
| 174 | + # Build mel_cond: original mel frames + zero frames for edit region |
| 175 | + mel_cond = torch.cat( |
| 176 | + ( |
| 177 | + mel_cond, |
| 178 | + original_mel[:, offset_frame:start_frame, :], |
| 179 | + torch.zeros(1, part_dur_frames, n_mel_channels, device=device), |
| 180 | + ), |
| 181 | + dim=1, |
| 182 | + ) |
150 | 183 | edit_mask = torch.cat( |
151 | 184 | ( |
152 | 185 | edit_mask, |
153 | | - torch.ones(1, round((start - offset) / hop_length), dtype=torch.bool), |
154 | | - torch.zeros(1, round(part_dur / hop_length), dtype=torch.bool), |
| 186 | + torch.ones(1, keep_frames, dtype=torch.bool, device=device), |
| 187 | + torch.zeros(1, part_dur_frames, dtype=torch.bool, device=device), |
155 | 188 | ), |
156 | 189 | dim=-1, |
157 | 190 | ) |
158 | | - offset = end * target_sample_rate |
159 | | -audio = torch.cat((audio_, audio[:, round(offset) :]), dim=-1) |
160 | | -edit_mask = F.pad(edit_mask, (0, audio.shape[-1] // hop_length - edit_mask.shape[-1] + 1), value=True) |
161 | | -audio = audio.to(device) |
162 | | -edit_mask = edit_mask.to(device) |
| 191 | + offset_frame = end_frame |
| 192 | + |
| 193 | +# Append remaining mel frames after last edit |
| 194 | +mel_cond = torch.cat((mel_cond, original_mel[:, offset_frame:, :]), dim=1) |
| 195 | +edit_mask = F.pad(edit_mask, (0, mel_cond.shape[1] - edit_mask.shape[-1]), value=True) |
163 | 196 |
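Since `mel_cond` and `edit_mask` are now built frame by frame from the same offsets, they should always cover the same number of frames, with the `False` entries lining up exactly with the zeroed edit regions. A small illustrative check, not part of the committed script, reusing the tensors defined above:

```python
# Sketch: mask and conditioning mel must agree frame for frame,
# and the masked-out (False) frames are exactly the zeroed edit regions.
assert edit_mask.shape[-1] == mel_cond.shape[1]
assert torch.all(mel_cond[~edit_mask] == 0)
```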
164 | 197 | # Text |
165 | 198 | text_list = [target_text] |
170 | 203 | print(f"text : {text_list}") |
171 | 204 | print(f"pinyin: {final_text_list}") |
172 | 205 |
173 | | -# Duration |
174 | | -ref_audio_len = 0 |
175 | | -duration = audio.shape[-1] // hop_length |
| 206 | +# Duration - use mel_cond length (not raw audio length) |
| 207 | +duration = mel_cond.shape[1] |
176 | 208 |
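By construction, the new `duration` equals the original mel length minus the frames spanned by the edited regions plus the frames allotted to them (from `fix_duration`, or the original span lengths otherwise). A hedged bookkeeping sketch that reuses the script's variables and the same rounding as the loop above (assumes `parts_to_edit` is time-ordered and non-overlapping):

```python
# Sketch: relate duration to the original mel length.
def to_frames(t_sec: float) -> int:
    return round(t_sec * target_sample_rate / hop_length)

edited = sum(to_frames(e) - to_frames(s) for s, e in parts_to_edit)
durs = fix_duration if fix_duration is not None else [e - s for s, e in parts_to_edit]
assert duration == original_mel.shape[1] - edited + sum(to_frames(d) for d in durs)
```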
177 | | -# Inference |
| 209 | +# Inference - pass mel_cond directly (not wav) |
178 | 210 | with torch.inference_mode(): |
179 | 211 | generated, trajectory = model.sample( |
180 | | - cond=audio, |
| 212 | + cond=mel_cond, # Now passing mel directly, not wav |
181 | 213 | text=final_text_list, |
182 | 214 | duration=duration, |
183 | 215 | steps=nfe_step, |
190 | 222 |
191 | 223 | # Final result |
192 | 224 | generated = generated.to(torch.float32) |
193 | | - generated = generated[:, ref_audio_len:, :] |
194 | 225 | gen_mel_spec = generated.permute(0, 2, 1) |
195 | 226 | if mel_spec_type == "vocos": |
196 | 227 | generated_wave = vocoder.decode(gen_mel_spec).cpu() |
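Since the conditioning now has exactly `duration` frames, the decoded waveform should come out close to `duration * hop_length` samples (the exact length depends on the vocoder's padding). A rough check, again reusing the script's variables:

```python
# Sketch: rough length check on the decoded audio.
expected_sec = duration * hop_length / target_sample_rate
print(f"generated {generated_wave.shape[-1] / target_sample_rate:.2f}s, expected ~{expected_sec:.2f}s")
```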