4040import torch
4141import typer
4242import yaml
43- from pyannote .audio import Audio , Pipeline , Model
43+ from pyannote .audio import Audio , Model , Pipeline
4444from pyannote .core import Annotation
4545from pyannote .metrics .base import BaseMetric
4646from pyannote .metrics .diarization import DiarizationErrorRate , JaccardErrorRate
@@ -285,16 +285,20 @@ def download(
285285 help = "Pretrained pipeline (e.g. pyannote/speaker-diarization-community-1)"
286286 ),
287287 ],
288- token : Annotated [
289- str ,
290- typer .Argument (
291- help = "Huggingface token to be used for downloading from Huggingface hub."
288+ revision : Annotated [
289+ Optional [ str ] ,
290+ typer .Option (
291+ help = "Pretrained pipeline revision." ,
292292 ),
293- ],
293+ ] = None ,
294+ token : Annotated [
295+ Optional [str ],
296+ typer .Argument (help = "Huggingface token." ),
297+ ] = None ,
294298 cache : Annotated [
295- Path ,
299+ Optional [ Path ] ,
296300 typer .Option (
297- help = "Path to the folder where files downloaded from Huggingface hub are stored." ,
301+ help = "Path to the folder where files downloaded from Huggingface are stored." ,
298302 exists = True ,
299303 dir_okay = True ,
300304 file_okay = False ,
@@ -309,7 +313,7 @@ def download(
309313
310314 # load pretrained pipeline
311315 pretrained_pipeline = Pipeline .from_pretrained (
312- pipeline , token = token , cache_dir = cache
316+ pipeline , revision = revision , token = token , cache_dir = cache
313317 )
314318 if pretrained_pipeline is None :
315319 print (f"Could not load pretrained pipeline from { pipeline } ." )
@@ -335,7 +339,7 @@ def apply(
335339 ),
336340 ],
337341 into : Annotated [
338- Path ,
342+ Optional [ Path ] ,
339343 typer .Option (
340344 help = "Path to file or directory where results are saved." ,
341345 exists = False ,
@@ -345,27 +349,39 @@ def apply(
345349 resolve_path = True ,
346350 ),
347351 ] = None ,
348- device : Annotated [
349- Device , typer .Option (help = "Accelerator to use (CPU, CUDA, MPS)" )
350- ] = Device .AUTO ,
352+ revision : Annotated [
353+ Optional [str ],
354+ typer .Option (
355+ help = "Pretrained pipeline revision." ,
356+ ),
357+ ] = None ,
358+ token : Annotated [
359+ Optional [str ],
360+ typer .Argument (help = "Huggingface token." ),
361+ ] = None ,
351362 cache : Annotated [
352- Path ,
363+ Optional [ Path ] ,
353364 typer .Option (
354- help = "Path to the folder where files downloaded from Huggingface hub are stored." ,
365+ help = "Path to the folder where files downloaded from Huggingface are stored." ,
355366 exists = True ,
356367 dir_okay = True ,
357368 file_okay = False ,
358369 writable = True ,
359370 resolve_path = True ,
360371 ),
361372 ] = None ,
373+ device : Annotated [
374+ Device , typer .Option (help = "Accelerator to use (CPU, CUDA, MPS)" )
375+ ] = Device .AUTO ,
362376):
363377 """
364378 Apply a pretrained PIPELINE to an AUDIO file or directory
365379 """
366380
367381 # load pretrained pipeline
368- pretrained_pipeline = Pipeline .from_pretrained (pipeline , cache_dir = cache )
382+ pretrained_pipeline = Pipeline .from_pretrained (
383+ pipeline , revision = revision , token = token , cache_dir = cache
384+ )
369385 if pretrained_pipeline is None :
370386 print (f"Could not load pretrained pipeline from { pipeline } ." )
371387 raise typer .exit (code = 1 )
@@ -375,7 +391,6 @@ def apply(
375391 pretrained_pipeline .to (torch_device )
376392
377393 if audio .is_dir ():
378-
379394 if into is None or not into .is_dir ():
380395 typer .echo ("When AUDIO is a directory, INTO must also be a directory." )
381396 raise typer .exit (code = 1 )
@@ -385,7 +400,6 @@ def apply(
385400 jsons : list [Path | None ] = [into / (path .stem + ".json" ) for path in inputs ]
386401
387402 else :
388-
389403 if not (into is None or into .is_file ()):
390404 typer .echo ("When AUDIO is a file, INTO must also be a file." )
391405 raise typer .exit (code = 1 )
@@ -395,7 +409,6 @@ def apply(
395409 jsons : list [Path | None ] = [into .with_suffix (".json" ) if into else None ]
396410
397411 for current_input , current_rttm , current_json in zip (inputs , rttms , jsons ):
398-
399412 prediction = pretrained_pipeline (current_input )
400413
401414 speaker_diarization = get_diarization (prediction )
@@ -522,6 +535,27 @@ def benchmark(
522535 case_sensitive = False ,
523536 ),
524537 ] = Subset .test ,
538+ revision : Annotated [
539+ Optional [str ],
540+ typer .Option (
541+ help = "Pretrained pipeline revision." ,
542+ ),
543+ ] = None ,
544+ token : Annotated [
545+ Optional [str ],
546+ typer .Argument (help = "Huggingface token." ),
547+ ] = None ,
548+ cache : Annotated [
549+ Optional [Path ],
550+ typer .Option (
551+ help = "Path to the folder where files downloaded from Huggingface are stored." ,
552+ exists = True ,
553+ dir_okay = True ,
554+ file_okay = False ,
555+ writable = True ,
556+ resolve_path = True ,
557+ ),
558+ ] = None ,
525559 device : Annotated [
526560 Device , typer .Option (help = "Accelerator to use (CPU, CUDA, MPS)" )
527561 ] = Device .AUTO ,
@@ -538,17 +572,6 @@ def benchmark(
538572 num_speakers : Annotated [
539573 NumSpeakers , typer .Option (help = "Number of speakers (oracle or auto)" )
540574 ] = NumSpeakers .AUTO ,
541- cache : Annotated [
542- Path ,
543- typer .Option (
544- help = "Path to the folder where files downloaded from Huggingface hub are stored." ,
545- exists = True ,
546- dir_okay = True ,
547- file_okay = False ,
548- writable = True ,
549- resolve_path = True ,
550- ),
551- ] = None ,
552575 optimize : Annotated [
553576 bool ,
554577 typer .Option (
@@ -562,10 +585,7 @@ def benchmark(
562585 ),
563586 ] = False ,
564587 per_file : Annotated [
565- bool ,
566- typer .Option (
567- help = "Save one RTTM/JSON file per processed audio file."
568- )
588+ bool , typer .Option (help = "Save one RTTM/JSON file per processed audio file." )
569589 ] = False ,
570590):
571591 """
@@ -578,7 +598,9 @@ def benchmark(
578598 """
579599
580600 # load pretrained pipeline
581- pretrained_pipeline = Pipeline .from_pretrained (pipeline , cache_dir = cache )
601+ pretrained_pipeline = Pipeline .from_pretrained (
602+ pipeline , revision = revision , token = token , cache_dir = cache ,
603+ )
582604 if pretrained_pipeline is None :
583605 print (f"Could not load pretrained pipeline from { pipeline } ." )
584606 raise typer .exit (code = 1 )
@@ -808,15 +830,19 @@ def benchmark(
808830 yaml .dump ({"min_duration_off" : best_min_duration_off }, yml )
809831
810832 if not per_file :
811- optimized_rttm_file = into / f"{ benchmark_name } .OptimizedMinDurationOff.rttm"
833+ optimized_rttm_file = (
834+ into / f"{ benchmark_name } .OptimizedMinDurationOff.rttm"
835+ )
812836
813837 # make sure we don't overwrite previous results
814838 if optimized_rttm_file .exists ():
815839 raise FileExistsError (f"{ optimized_rttm_file } already exists." )
816840
817841 for file in files :
818842 if per_file :
819- optimized_rttm_file = rttm_dir / f"{ file ['uri' ]} .OptimizedMinDurationOff.rttm"
843+ optimized_rttm_file = (
844+ rttm_dir / f"{ file ['uri' ]} .OptimizedMinDurationOff.rttm"
845+ )
820846
821847 with open (optimized_rttm_file , "w" if per_file else "a" ) as rttm :
822848 file ["best_speaker_diarization" ].write_rttm (rttm )
@@ -851,11 +877,11 @@ def strip(
851877 """
852878
853879 keys = [
854- "pytorch-lightning_version" , # * pytorch-lightning needs
855- "hparams_name" , # those values to initialize
856- "hyper_parameters" , # the model architecture
857- "state_dict" , # * actual weights
858- "pyannote.audio" , # * pyannote.audio dependencies
880+ "pytorch-lightning_version" , # * pytorch-lightning needs
881+ "hparams_name" , # those values to initialize
882+ "hyper_parameters" , # the model architecture
883+ "state_dict" , # * actual weights
884+ "pyannote.audio" , # * pyannote.audio dependencies
859885 ]
860886
861887 old_checkpoint = torch .load (
0 commit comments