
Commit

Merge branch 'develop' into fix/fix-xvector-coreml-export

hbredin authored Feb 4, 2025
2 parents efc8849 + 74705e9 commit 703f889
Showing 37 changed files with 1,387 additions and 942 deletions.
49 changes: 46 additions & 3 deletions CHANGELOG.md
@@ -2,14 +2,57 @@

## develop

### TL;DR

#### Quality of life improvements

Models can now be stored alongside their pipelines in the same repository, streamlining the gating mechanism:
- accept `pyannote/speaker-diarization-x.x` pipeline user agreement
- ~~accept `pyannote/segmentation-3.0` model user agreement~~
- ~~accept `pyannote/wespeaker-voxceleb-resnet34-LM` model user agreement~~
- load pipeline with `Pipeline.from_pretrained("pyannote/speaker-diarization-3.1", token=True)`
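
In practice, gated access now reduces to a single user agreement and a single call (a minimal sketch; the token string is a placeholder for your own Hugging Face access token):

```python
from pyannote.audio import Pipeline

# accepting the pipeline user agreement is enough: the underlying
# models are stored in the same (gated) repository
pipeline = Pipeline.from_pretrained(
    "pyannote/speaker-diarization-3.1",
    token="HUGGINGFACE_ACCESS_TOKEN_GOES_HERE",  # or token=True to reuse a cached login
)
```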

#### Improved speech separation quality

Clipping and speaker/source alignment issues in the speech separation pipeline have been fixed.

### Breaking changes

- BREAKING(hub): rename `use_auth_token` to `token` (see the migration sketch after this list)
- BREAKING(cache): rely on `huggingface_hub` caching directory (`PYANNOTE_CACHE` is no longer used)
- BREAKING(inference): `Inference` now only supports already instantiated models
- BREAKING(task): drop support for `multilabel` training in `SpeakerDiarization` task
- BREAKING(task): drop support for `warm_up` option in `SpeakerDiarization` task
- BREAKING(task): drop support for `weigh_by_cardinality` option in `SpeakerDiarization` task
- BREAKING(task): drop support for `vad_loss` option in `SpeakerDiarization` task
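
Code written against the previous API can usually be migrated mechanically. A minimal sketch of the two most common changes (the model name below is only an example):

```python
from pyannote.audio import Inference, Model

# before: inference = Inference("pyannote/embedding", use_auth_token="TOKEN")
# after: rename the parameter, and pass an already instantiated model
model = Model.from_pretrained("pyannote/embedding", token="TOKEN")
inference = Inference(model)
```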

### New features

- improve(hub): add support for pipeline repos that also include underlying models
- feat(clustering): add support for `k-means` clustering
- feat(model): add `wav2vec_frozen` option to freeze/unfreeze `wav2vec` in `SSeRiouSS` architecture
- feat(task): add support for manual optimization in `SpeakerDiarization` task
- feat(utils): add `hidden` option to `ProgressHook`
- feat(utils): add `FilterByNumberOfSpeakers` protocol files filter
- feat(core): add `Calibration` class to calibrate logits/distances into probabilities
- feat(metric): add `DetectionErrorRate`, `SegmentationErrorRate`, `DiarizationPrecision`, and `DiarizationRecall` metrics
- feat(cli): add CLI to apply (and benchmark) pretrained pipelines

### Improvements

- improve(model): improve WavLM (un)freezing support for `SSeRiouSS` architecture ([@clement-pages](https://github.com/clement-pages/))
- improve(task): improve `SpeakerDiarization` training with manual optimization ([@clement-pages](https://github.com/clement-pages/))

### Fixes

- fix(model): improve WavLM (un)freezing support for `ToTaToNet` architecture ([@clement-pages](https://github.com/clement-pages/))
- fix(separation): fix clipping issue in speech separation pipeline ([@joonaskalda](https://github.com/joonaskalda/))
- fix(separation): fix alignment between separated sources and diarization ([@Lebourdais](https://github.com/Lebourdais/) and [@clement-pages](https://github.com/clement-pages/))
- fix(separation): prevent leakage removal collar from being applied to diarization ([@clement-pages](https://github.com/clement-pages/))
- fix(separation): fix `PixIT` training with manual optimization ([@clement-pages](https://github.com/clement-pages/))
- fix(doc): fix link to pytorch ([@emmanuel-ferdman](https://github.com/emmanuel-ferdman/))
- fix(task): fix corner case with small (<9) number of validation samples ([@antoinelaurent](https://github.com/antoinelaurent/))
- fix(doc): fix default embedding in `SpeechSeparation` and `SpeakerDiarization` docstring ([@razi-tm](https://github.com/razi-tm/))

## Version 3.3.2 (2024-09-11)

12 changes: 3 additions & 9 deletions FAQ.md
@@ -25,13 +25,13 @@
<a name="does-pyannote-support-streaming-speaker-diarization"></a>
## Does pyannote support streaming speaker diarization?

pyannote does not, but [diart](https://github.com/juanmc2005/diart) (which is based on pyannote) does.

<a name="how-can-i-improve-performance"></a>
## How can I improve performance?

**Short answer:** [pyannoteAI](https://www.pyannote.ai) precision models are usually much more accurate (and faster).

**Long answer:**

1. Manually annotate dozens of conversations as precisely as possible.
@@ -40,15 +40,9 @@
4. Follow [this recipe](https://github.com/pyannote/pyannote-audio/blob/develop/tutorials/adapting_pretrained_pipeline.ipynb).
5. Enjoy.

**Also:** [I am available](https://herve.niderb.fr) for contracting to help you with that.

<a name="how-does-one-spell-and-pronounce-pyannoteaudio"></a>
## How does one spell and pronounce pyannote.audio?

📝 Written in lower case: `pyannote.audio` (or `pyannote` if you are lazy). Not `PyAnnote` nor `PyAnnotate` (sic).
📢 Pronounced like the French verb `pianoter`. `pi` like in `pi`ano, not `py` like in `py`thon.
🎹 `pianoter` means to play the piano (hence the logo 🤯).

<hr>

Generated by [FAQtory](https://github.com/willmcgugan/faqtory)
11 changes: 7 additions & 4 deletions README.md
@@ -3,7 +3,7 @@

# `pyannote.audio` speaker diarization toolkit

`pyannote.audio` is an open-source toolkit written in Python for speaker diarization. Based on the [PyTorch](https://pytorch.org) machine learning framework, it comes with state-of-the-art [pretrained models and pipelines](https://hf.co/pyannote) that can be further finetuned to your own data for even better performance.

<p align="center">
<a href="https://www.youtube.com/watch?v=37R_R82lfwA"><img src="https://img.youtube.com/vi/37R_R82lfwA/0.jpg"></a>
@@ -18,16 +18,19 @@

```python
from pyannote.audio import Pipeline
from pyannote.audio.pipelines.utils.hook import ProgressHook

pipeline = Pipeline.from_pretrained(
    "pyannote/speaker-diarization-3.1",
    token="HUGGINGFACE_ACCESS_TOKEN_GOES_HERE")

# send pipeline to GPU (when available)
import torch
pipeline.to(torch.device("cuda"))

# apply pretrained pipeline (with optional progress hook)
with ProgressHook() as hook:
    diarization = pipeline("audio.wav", hook=hook)

# print the result
for turn, _, speaker in diarization.itertracks(yield_label=True):
    print(f"start={turn.start:.1f}s stop={turn.end:.1f}s speaker_{speaker}")
```
221 changes: 221 additions & 0 deletions pyannote/audio/__main__.py
@@ -0,0 +1,221 @@
#!/usr/bin/env python
# encoding: utf-8

# MIT License
#
# Copyright (c) 2024- CNRS
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.


import sys
from contextlib import nullcontext
from enum import Enum
from pathlib import Path
from typing import Optional

import pyannote.database
import torch
import typer
from pyannote.core import Annotation
from typing_extensions import Annotated

from pyannote.audio import Pipeline


class Subset(str, Enum):
    train = "train"
    development = "development"
    test = "test"


class Device(str, Enum):
    CPU = "cpu"
    CUDA = "cuda"
    MPS = "mps"
    AUTO = "auto"


def parse_device(device: Device) -> torch.device:
    if device == Device.AUTO:
        if torch.cuda.is_available():
            device = Device.CUDA

        elif torch.backends.mps.is_available():
            device = Device.MPS

        else:
            device = Device.CPU

    return torch.device(device.value)


app = typer.Typer()


# TODO: add option to download pretrained pipeline for later use without internet


@app.command("apply")
def apply(
pipeline: Annotated[
str,
typer.Argument(
help="Pretrained pipeline (e.g. pyannote/speaker-diarization-3.1)"
),
],
audio: Annotated[
Path,
typer.Argument(
help="Path to audio file",
exists=True,
file_okay=True,
readable=True,
),
],
into: Annotated[
Path,
typer.Option(
help="Path to file where results are saved.",
exists=False,
dir_okay=False,
file_okay=True,
writable=True,
resolve_path=True,
),
] = None,
device: Annotated[
Device, typer.Option(help="Accelerator to use (CPU, CUDA, MPS)")
] = Device.AUTO,
):
"""
Apply a pretrained PIPELINE to an AUDIO file
"""

# load pretrained pipeline
pretrained_pipeline = Pipeline.from_pretrained(pipeline)

# send pipeline to device
torch_device = parse_device(device)
pretrained_pipeline.to(torch_device)

# apply pipeline to audio file
prediction: Annotation = pretrained_pipeline(audio)

# save (or print) results
with open(into, "w") if into else nullcontext(sys.stdout) as rttm:
prediction.write_rttm(rttm)


@app.command("benchmark")
def benchmark(
pipeline: Annotated[
str,
typer.Argument(
help="Pretrained pipeline (e.g. pyannote/speaker-diarization-3.1)"
),
],
protocol: Annotated[
str,
typer.Argument(help="Benchmarked protocol"),
],
into: Annotated[
Path,
typer.Argument(
help="Directory into which benchmark results are saved",
exists=True,
dir_okay=True,
file_okay=False,
writable=True,
resolve_path=True,
),
],
subset: Annotated[
Subset,
typer.Option(
help="Benchmarked subset",
case_sensitive=False,
),
] = Subset.test,
device: Annotated[
Device, typer.Option(help="Accelerator to use (CPU, CUDA, MPS)")
] = Device.AUTO,
registry: Annotated[
Optional[Path],
typer.Option(
help="Loaded registry",
exists=True,
dir_okay=False,
file_okay=True,
readable=True,
),
] = None,
):
"""
Benchmark a pretrained PIPELINE
"""

# load pretrained pipeline
pretrained_pipeline = Pipeline.from_pretrained(pipeline)

# send pipeline to device
torch_device = parse_device(device)
pretrained_pipeline.to(torch_device)

# load pipeline metric (when available)
try:
metric = pretrained_pipeline.get_metric()
except NotImplementedError:
metric = None

# load protocol from (optional) registry
if registry:
pyannote.database.registry.load_database(registry)

loaded_protocol = pyannote.database.registry.get_protocol(
protocol, {"audio": pyannote.database.FileFinder()}
)

with open(into / f"{protocol}.{subset.value}.rttm", "w") as rttm:
for file in getattr(loaded_protocol, subset.value)():
prediction: Annotation = pretrained_pipeline(file)
prediction.write_rttm(rttm)
rttm.flush()

if metric is None:
continue

groundtruth = file.get("annotation", None)
if groundtruth is None:
continue

annotated = file.get("annotated", None)
_ = metric(groundtruth, prediction, uem=annotated)

if metric is None:
return

with open(into / f"{protocol}.{subset.value}.txt", "w") as txt:
txt.write(str(metric))

print(str(metric))


if __name__ == "__main__":
    app()
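
Assuming the package exposes this module as shown, the commands above should be reachable as `python -m pyannote.audio apply pyannote/speaker-diarization-3.1 audio.wav --into audio.rttm` and `python -m pyannote.audio benchmark pyannote/speaker-diarization-3.1 <protocol> <output-directory>` (invocations inferred from the typer app above; check `--help` for the authoritative interface).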
4 changes: 1 addition & 3 deletions pyannote/audio/augmentation/mix.py
@@ -25,7 +25,7 @@

import torch
from torch import Tensor
from torch_audiomentations.augmentations.mix import Mix


class MixSpeakerDiarization(Mix):
@@ -85,7 +85,6 @@ def randomize_parameters(
        targets: Optional[Tensor] = None,
        target_rate: Optional[int] = None,
    ):
        batch_size, num_channels, num_samples = samples.shape
        snr_distribution = torch.distributions.Uniform(
            low=torch.tensor(
@@ -116,7 +115,6 @@
            batch_size, dtype=torch.int64
        )
        for n in range(max_num_speakers + 1):
            # indices of samples with exactly n speakers
            samples_with_n_speakers = torch.where(num_speakers == n)[0]
            num_samples_with_n_speakers = len(samples_with_n_speakers)