Skip to content

Commit 5a36140

Browse files
committed
Merge branch 'develop' of github.com:pyannote/pyannote-audio into develop
2 parents a232b09 + f577979 commit 5a36140

5 files changed

Lines changed: 109 additions & 44 deletions

File tree

CHANGELOG.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,12 @@
11
# CHANGELOG
22

3+
## Version 4.0.3 (2025-12-07)
4+
5+
- feat(cli): add `--revision` option to most CLI commands
6+
- feat(util): add `Calibration.safe_transform` method (supports NaNs as well as any shape)
7+
- fix(model): fix `Model.from_pretrained` to support `lightning` 2.6+
8+
- setup: update `pyannote-database` dependency to `6.1.1+`
9+
310
## Version 4.0.2 (2025-11-19)
411

512
- BREAKING(util): make `Binarize.__call__` return `string` tracks (instead of `int`) [@benniekiss](https://github.com/benniekiss/)

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ dependencies = [
1616
"opentelemetry-sdk>=1.34.0",
1717
"opentelemetry-exporter-otlp>=1.34.0",
1818
"pyannote-core>=6.0.1",
19-
"pyannote-database>=6.0.0",
19+
"pyannote-database>=6.1.1",
2020
"pyannote-metrics>=4.0.0",
2121
"pyannote-pipeline>=4.0.0",
2222
"pytorch-metric-learning>=2.8.1",

src/pyannote/audio/__main__.py

Lines changed: 68 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@
4040
import torch
4141
import typer
4242
import yaml
43-
from pyannote.audio import Audio, Pipeline, Model
43+
from pyannote.audio import Audio, Model, Pipeline
4444
from pyannote.core import Annotation
4545
from pyannote.metrics.base import BaseMetric
4646
from pyannote.metrics.diarization import DiarizationErrorRate, JaccardErrorRate
@@ -285,16 +285,20 @@ def download(
285285
help="Pretrained pipeline (e.g. pyannote/speaker-diarization-community-1)"
286286
),
287287
],
288-
token: Annotated[
289-
str,
290-
typer.Argument(
291-
help="Huggingface token to be used for downloading from Huggingface hub."
288+
revision: Annotated[
289+
Optional[str],
290+
typer.Option(
291+
help="Pretrained pipeline revision.",
292292
),
293-
],
293+
] = None,
294+
token: Annotated[
295+
Optional[str],
296+
typer.Argument(help="Huggingface token."),
297+
] = None,
294298
cache: Annotated[
295-
Path,
299+
Optional[Path],
296300
typer.Option(
297-
help="Path to the folder where files downloaded from Huggingface hub are stored.",
301+
help="Path to the folder where files downloaded from Huggingface are stored.",
298302
exists=True,
299303
dir_okay=True,
300304
file_okay=False,
@@ -309,7 +313,7 @@ def download(
309313

310314
# load pretrained pipeline
311315
pretrained_pipeline = Pipeline.from_pretrained(
312-
pipeline, token=token, cache_dir=cache
316+
pipeline, revision=revision, token=token, cache_dir=cache
313317
)
314318
if pretrained_pipeline is None:
315319
print(f"Could not load pretrained pipeline from {pipeline}.")
@@ -335,7 +339,7 @@ def apply(
335339
),
336340
],
337341
into: Annotated[
338-
Path,
342+
Optional[Path],
339343
typer.Option(
340344
help="Path to file or directory where results are saved.",
341345
exists=False,
@@ -345,27 +349,39 @@ def apply(
345349
resolve_path=True,
346350
),
347351
] = None,
348-
device: Annotated[
349-
Device, typer.Option(help="Accelerator to use (CPU, CUDA, MPS)")
350-
] = Device.AUTO,
352+
revision: Annotated[
353+
Optional[str],
354+
typer.Option(
355+
help="Pretrained pipeline revision.",
356+
),
357+
] = None,
358+
token: Annotated[
359+
Optional[str],
360+
typer.Argument(help="Huggingface token."),
361+
] = None,
351362
cache: Annotated[
352-
Path,
363+
Optional[Path],
353364
typer.Option(
354-
help="Path to the folder where files downloaded from Huggingface hub are stored.",
365+
help="Path to the folder where files downloaded from Huggingface are stored.",
355366
exists=True,
356367
dir_okay=True,
357368
file_okay=False,
358369
writable=True,
359370
resolve_path=True,
360371
),
361372
] = None,
373+
device: Annotated[
374+
Device, typer.Option(help="Accelerator to use (CPU, CUDA, MPS)")
375+
] = Device.AUTO,
362376
):
363377
"""
364378
Apply a pretrained PIPELINE to an AUDIO file or directory
365379
"""
366380

367381
# load pretrained pipeline
368-
pretrained_pipeline = Pipeline.from_pretrained(pipeline, cache_dir=cache)
382+
pretrained_pipeline = Pipeline.from_pretrained(
383+
pipeline, revision=revision, token=token, cache_dir=cache
384+
)
369385
if pretrained_pipeline is None:
370386
print(f"Could not load pretrained pipeline from {pipeline}.")
371387
raise typer.exit(code=1)
@@ -375,7 +391,6 @@ def apply(
375391
pretrained_pipeline.to(torch_device)
376392

377393
if audio.is_dir():
378-
379394
if into is None or not into.is_dir():
380395
typer.echo("When AUDIO is a directory, INTO must also be a directory.")
381396
raise typer.exit(code=1)
@@ -385,7 +400,6 @@ def apply(
385400
jsons: list[Path | None] = [into / (path.stem + ".json") for path in inputs]
386401

387402
else:
388-
389403
if not (into is None or into.is_file()):
390404
typer.echo("When AUDIO is a file, INTO must also be a file.")
391405
raise typer.exit(code=1)
@@ -395,7 +409,6 @@ def apply(
395409
jsons: list[Path | None] = [into.with_suffix(".json") if into else None]
396410

397411
for current_input, current_rttm, current_json in zip(inputs, rttms, jsons):
398-
399412
prediction = pretrained_pipeline(current_input)
400413

401414
speaker_diarization = get_diarization(prediction)
@@ -522,6 +535,27 @@ def benchmark(
522535
case_sensitive=False,
523536
),
524537
] = Subset.test,
538+
revision: Annotated[
539+
Optional[str],
540+
typer.Option(
541+
help="Pretrained pipeline revision.",
542+
),
543+
] = None,
544+
token: Annotated[
545+
Optional[str],
546+
typer.Argument(help="Huggingface token."),
547+
] = None,
548+
cache: Annotated[
549+
Optional[Path],
550+
typer.Option(
551+
help="Path to the folder where files downloaded from Huggingface are stored.",
552+
exists=True,
553+
dir_okay=True,
554+
file_okay=False,
555+
writable=True,
556+
resolve_path=True,
557+
),
558+
] = None,
525559
device: Annotated[
526560
Device, typer.Option(help="Accelerator to use (CPU, CUDA, MPS)")
527561
] = Device.AUTO,
@@ -538,17 +572,6 @@ def benchmark(
538572
num_speakers: Annotated[
539573
NumSpeakers, typer.Option(help="Number of speakers (oracle or auto)")
540574
] = NumSpeakers.AUTO,
541-
cache: Annotated[
542-
Path,
543-
typer.Option(
544-
help="Path to the folder where files downloaded from Huggingface hub are stored.",
545-
exists=True,
546-
dir_okay=True,
547-
file_okay=False,
548-
writable=True,
549-
resolve_path=True,
550-
),
551-
] = None,
552575
optimize: Annotated[
553576
bool,
554577
typer.Option(
@@ -562,10 +585,7 @@ def benchmark(
562585
),
563586
] = False,
564587
per_file: Annotated[
565-
bool,
566-
typer.Option(
567-
help="Save one RTTM/JSON file per processed audio file."
568-
)
588+
bool, typer.Option(help="Save one RTTM/JSON file per processed audio file.")
569589
] = False,
570590
):
571591
"""
@@ -578,7 +598,9 @@ def benchmark(
578598
"""
579599

580600
# load pretrained pipeline
581-
pretrained_pipeline = Pipeline.from_pretrained(pipeline, cache_dir=cache)
601+
pretrained_pipeline = Pipeline.from_pretrained(
602+
pipeline, revision=revision, token=token, cache_dir=cache,
603+
)
582604
if pretrained_pipeline is None:
583605
print(f"Could not load pretrained pipeline from {pipeline}.")
584606
raise typer.exit(code=1)
@@ -808,15 +830,19 @@ def benchmark(
808830
yaml.dump({"min_duration_off": best_min_duration_off}, yml)
809831

810832
if not per_file:
811-
optimized_rttm_file = into / f"{benchmark_name}.OptimizedMinDurationOff.rttm"
833+
optimized_rttm_file = (
834+
into / f"{benchmark_name}.OptimizedMinDurationOff.rttm"
835+
)
812836

813837
# make sure we don't overwrite previous results
814838
if optimized_rttm_file.exists():
815839
raise FileExistsError(f"{optimized_rttm_file} already exists.")
816840

817841
for file in files:
818842
if per_file:
819-
optimized_rttm_file = rttm_dir / f"{file['uri']}.OptimizedMinDurationOff.rttm"
843+
optimized_rttm_file = (
844+
rttm_dir / f"{file['uri']}.OptimizedMinDurationOff.rttm"
845+
)
820846

821847
with open(optimized_rttm_file, "w" if per_file else "a") as rttm:
822848
file["best_speaker_diarization"].write_rttm(rttm)
@@ -851,11 +877,11 @@ def strip(
851877
"""
852878

853879
keys = [
854-
"pytorch-lightning_version", # * pytorch-lightning needs
855-
"hparams_name", # those values to initialize
856-
"hyper_parameters", # the model architecture
857-
"state_dict", # * actual weights
858-
"pyannote.audio", # * pyannote.audio dependencies
880+
"pytorch-lightning_version", # * pytorch-lightning needs
881+
"hparams_name", # those values to initialize
882+
"hyper_parameters", # the model architecture
883+
"state_dict", # * actual weights
884+
"pyannote.audio", # * pyannote.audio dependencies
859885
]
860886

861887
old_checkpoint = torch.load(

src/pyannote/audio/core/calibration.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,34 @@ class Calibration(IsotonicRegression):
4040
def __init__(self):
4141
super().__init__(y_min=0.0, y_max=1.0, increasing="auto", out_of_bounds="clip")
4242

43+
def safe_transform(
44+
self,
45+
values: np.ndarray,
46+
nan_value: float = 2.0,
47+
) -> np.ndarray:
48+
"""Apply calibration handling NaN values and any shape gracefully
49+
50+
Parameters
51+
----------
52+
values : np.ndarray
53+
Values to calibrate
54+
nan_value : float, optional
55+
Value to use in place of NaN values during calibration. Default is 2.0.
56+
57+
Returns
58+
-------
59+
calibrated_values : np.ndarray
60+
Calibrated values
61+
"""
62+
# replace NaN values with `nan_value` so `transform()` does not fail (NaN positions are not restored afterwards)
63+
transformed = np.nan_to_num(values.reshape(-1), nan=nan_value)
64+
65+
# apply calibration
66+
transformed: np.ndarray = self.transform(transformed)
67+
68+
# recover original shape
69+
return transformed.reshape(values.shape)
70+
4371
def save(self, path: str):
4472
"""Save fitted calibration to disk
4573

src/pyannote/audio/core/model.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -599,7 +599,9 @@ def default_map_location(storage, loc):
599599
map_location = default_map_location
600600

601601
# load checkpoint using lightning
602-
loaded_checkpoint = pl_load(path_to_model_checkpoint, map_location=map_location)
602+
loaded_checkpoint = pl_load(
603+
path_to_model_checkpoint, map_location=map_location, weights_only=False
604+
)
603605

604606
# check that the checkpoint is compatible with the current version
605607
versions = loaded_checkpoint["pyannote.audio"]["versions"]
@@ -620,6 +622,7 @@ def default_map_location(storage, loc):
620622
path_to_model_checkpoint,
621623
map_location=map_location,
622624
strict=strict,
625+
weights_only=False,
623626
**kwargs,
624627
)
625628
except RuntimeError as e:
@@ -638,6 +641,7 @@ def default_map_location(storage, loc):
638641
path_to_model_checkpoint,
639642
map_location=map_location,
640643
strict=False,
644+
weights_only=False,
641645
**kwargs,
642646
)
643647
return model

0 commit comments

Comments
 (0)