Skip to content

Commit f675562

Browse files
committed
Update progress indicators
1 parent e1f9acc commit f675562

2 files changed

Lines changed: 5 additions & 5 deletions

File tree

app.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -264,18 +264,18 @@ def predict(model, text, melody_filepath, duration, dimension, topk, topp, tempe
264264
descriptions=[text],
265265
melody_wavs=melody,
266266
melody_sample_rate=sr,
267-
progress=True, progress_callback=gr.Progress(track_tqdm=True)
267+
progress=False, progress_callback=gr.Progress(track_tqdm=True)
268268
)
269269
# All output_segments are populated, so we can break the loop or set duration to 0
270270
break
271271
else:
272272
#output = MODEL.generate(descriptions=[text], progress=False)
273273
if not output_segments:
274-
next_segment = MODEL.generate(descriptions=[text], progress=True, progress_callback=gr.Progress(track_tqdm=True))
274+
next_segment = MODEL.generate(descriptions=[text], progress=False, progress_callback=gr.Progress(track_tqdm=True))
275275
duration -= segment_duration
276276
else:
277277
last_chunk = output_segments[-1][:, :, -overlap*MODEL.sample_rate:]
278-
next_segment = MODEL.generate_continuation(last_chunk, MODEL.sample_rate, descriptions=[text], progress=True, progress_callback=gr.Progress(track_tqdm=True))
278+
next_segment = MODEL.generate_continuation(last_chunk, MODEL.sample_rate, descriptions=[text], progress=False, progress_callback=gr.Progress(track_tqdm=True))
279279
duration -= segment_duration - overlap
280280
if next_segment != None:
281281
output_segments.append(next_segment)

audiocraft/models/musicgen.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -416,10 +416,10 @@ def _progress_callback(generated_tokens: int, tokens_to_generate: int):
416416
if self._progress_callback is not None:
417417
# Note that total_gen_len might be quite wrong depending on the
418418
# codebook pattern used, but with delay it is almost accurate.
419-
self._progress_callback((generated_tokens / tokens_to_generate), f"Generated {generated_tokens}/{tokens_to_generate} seconds")
419+
self._progress_callback((generated_tokens / tokens_to_generate), f"Generated {generated_tokens: 6.2f}/{tokens_to_generate: 6.2f} seconds")
420420
if progress_callback is not None:
421421
# Update Gradio progress bar
422-
progress_callback((generated_tokens / tokens_to_generate), f"Generated {generated_tokens}/{tokens_to_generate} seconds")
422+
progress_callback((generated_tokens / tokens_to_generate), f"Generated {generated_tokens: 6.2f}/{tokens_to_generate: 6.2f} seconds")
423423
if progress:
424424
print(f'{generated_tokens: 6.2f} / {tokens_to_generate: 6.2f}', end='\r')
425425

0 commit comments

Comments
 (0)