
Commit

Merge branch 'develop' into feat/skip-embedding-when-max-speakers-is-1
hbredin authored May 19, 2024
2 parents 657e8ba + f1a6db2 commit 9044840
Showing 18 changed files with 402 additions and 164 deletions.
2 changes: 1 addition & 1 deletion .github/ISSUE_TEMPLATE/config.yml
@@ -11,5 +11,5 @@ contact_links:
about: Using pyannote.audio in production? Make the most of it thanks to our consulting services.

- name: Premium models
url: https://forms.gle/eKhn7H2zTa68sMMx8
url: https://forms.office.com/e/GdqwVgkZ5C
about: We are considering selling premium models, extensions, or services around pyannote.audio.
25 changes: 24 additions & 1 deletion CHANGELOG.md
@@ -4,8 +4,24 @@

### New features

- feat(io): add option to select torchaudio `backend`

### Fixes

- fix(task): fix wrong train/development split when training with (some) meta-protocols ([#1709](https://github.com/pyannote/pyannote-audio/issues/1709))

### Improvements

- improve(io): when available, default to using `soundfile` backend
- improve(pipeline): do not extract embeddings when `max_speakers` is set to 1

## Version 3.2.0 (2024-05-08)

### New features

- feat(task): add option to cache task training metadata to speed up training (with [@clement-pages](https://github.com/clement-pages/))
- feat(model): add `receptive_field`, `num_frames` and `dimension` to models (with [@Bilal-Rahou](https://github.com/Bilal-Rahou))
- feat(model): add `fbank_only` property to `WeSpeaker` models
- feat(util): add `Powerset.permutation_mapping` to help with permutation in powerset space (with [@FrenchKrab](https://github.com/FrenchKrab))
- feat(sample): add sample file at `pyannote.audio.sample.SAMPLE_FILE`
- feat(metric): add `reduce` option to `diarization_error_rate` metric (with [@Bilal-Rahou](https://github.com/Bilal-Rahou))
@@ -15,6 +31,9 @@

- fix(task): fix random generators and their reproducibility (with [@FrenchKrab](https://github.com/FrenchKrab))
- fix(task): fix estimation of training set size (with [@FrenchKrab](https://github.com/FrenchKrab))
- fix(hook): fix `torch.Tensor` support in `ArtifactHook`
- fix(doc): fix typo in `Powerset` docstring (with [@lukasstorck](https://github.com/lukasstorck))
- fix(doc): remove mention of unsupported `numpy.ndarray` waveform (with [@Purfview](https://github.com/Purfview))

### Improvements

@@ -24,11 +43,15 @@
- improve(io): switch to `torchaudio >= 2.2.0`
- improve(doc): update tutorials (with [@clement-pages](https://github.com/clement-pages/))

## Breaking changes
### Breaking changes

- BREAKING(model): get rid of `Model.example_output` in favor of `num_frames` method, `receptive_field` property, and `dimension` property
- BREAKING(task): custom tasks need to be updated (see "Add your own task" tutorial)

### Community contributions

- community: add tutorial for offline use of `pyannote/speaker-diarization-3.1` (by [@simonottenhauskenbun](https://github.com/simonottenhauskenbun))

## Version 3.1.1 (2023-12-01)

### TL;DR
34 changes: 17 additions & 17 deletions README.md
@@ -1,5 +1,5 @@
Using `pyannote.audio` open-source toolkit in production?
Make the most of it thanks to our [consulting services](https://herve.niderb.fr/consulting.html).
Using `pyannote.audio` open-source toolkit in production?
Consider switching to [pyannoteAI](https://www.pyannote.ai) for better and faster options.

# `pyannote.audio` speaker diarization toolkit

@@ -79,21 +79,21 @@ for turn, _, speaker in diarization.itertracks(yield_label=True):
Out of the box, `pyannote.audio` speaker diarization [pipeline](https://hf.co/pyannote/speaker-diarization-3.1) v3.1 is expected to be much better (and faster) than v2.x.
Those numbers are diarization error rates (in %):

| Benchmark | [v2.1](https://hf.co/pyannote/speaker-diarization-2.1) | [v3.1](https://hf.co/pyannote/speaker-diarization-3.1) | [Premium](https://forms.gle/eKhn7H2zTa68sMMx8) |
| ---------------------- | ------ | ------ | --------- |
| [AISHELL-4](https://arxiv.org/abs/2104.03603) | 14.1 | 12.2 | 11.9 |
| [AliMeeting](https://www.openslr.org/119/) (channel 1) | 27.4 | 24.4 | 22.5 |
| [AMI](https://groups.inf.ed.ac.uk/ami/corpus/) (IHM) | 18.9 | 18.8 | 16.6 |
| [AMI](https://groups.inf.ed.ac.uk/ami/corpus/) (SDM) | 27.1 | 22.4 | 20.9 |
| [AVA-AVD](https://arxiv.org/abs/2111.14448) | 66.3 | 50.0 | 39.8 |
| [CALLHOME](https://catalog.ldc.upenn.edu/LDC2001S97) ([part 2](https://github.com/BUTSpeechFIT/CALLHOME_sublists/issues/1)) | 31.6 | 28.4 | 22.2 |
| [DIHARD 3](https://catalog.ldc.upenn.edu/LDC2022S14) ([full](https://arxiv.org/abs/2012.01477)) | 26.9 | 21.7 | 17.2 |
| [Earnings21](https://github.com/revdotcom/speech-datasets) | 17.0 | 9.4 | 9.0 |
| [Ego4D](https://arxiv.org/abs/2110.07058) (dev.) | 61.5 | 51.2 | 43.8 |
| [MSDWild](https://github.com/X-LANCE/MSDWILD) | 32.8 | 25.3 | 19.8 |
| [RAMC](https://www.openslr.org/123/) | 22.5 | 22.2 | 18.4 |
| [REPERE](https://www.islrn.org/resources/360-758-359-485-0/) (phase2) | 8.2 | 7.8 | 7.6 |
| [VoxConverse](https://github.com/joonson/voxconverse) (v0.3) | 11.2 | 11.3 | 9.4 |
| Benchmark | [v2.1](https://hf.co/pyannote/speaker-diarization-2.1) | [v3.1](https://hf.co/pyannote/speaker-diarization-3.1) | [pyannoteAI](https://www.pyannote.ai) |
| --------------------------------------------------------------------------------------------------------------------------- | ------------------------------------------------------ | ------------------------------------------------------ | ------------------------------------------------ |
| [AISHELL-4](https://arxiv.org/abs/2104.03603) | 14.1 | 12.2 | 11.9 |
| [AliMeeting](https://www.openslr.org/119/) (channel 1) | 27.4 | 24.4 | 22.5 |
| [AMI](https://groups.inf.ed.ac.uk/ami/corpus/) (IHM) | 18.9 | 18.8 | 16.6 |
| [AMI](https://groups.inf.ed.ac.uk/ami/corpus/) (SDM) | 27.1 | 22.4 | 20.9 |
| [AVA-AVD](https://arxiv.org/abs/2111.14448) | 66.3 | 50.0 | 39.8 |
| [CALLHOME](https://catalog.ldc.upenn.edu/LDC2001S97) ([part 2](https://github.com/BUTSpeechFIT/CALLHOME_sublists/issues/1)) | 31.6 | 28.4 | 22.2 |
| [DIHARD 3](https://catalog.ldc.upenn.edu/LDC2022S14) ([full](https://arxiv.org/abs/2012.01477)) | 26.9 | 21.7 | 17.2 |
| [Earnings21](https://github.com/revdotcom/speech-datasets) | 17.0 | 9.4 | 9.0 |
| [Ego4D](https://arxiv.org/abs/2110.07058) (dev.) | 61.5 | 51.2 | 43.8 |
| [MSDWild](https://github.com/X-LANCE/MSDWILD) | 32.8 | 25.3 | 19.8 |
| [RAMC](https://www.openslr.org/123/) | 22.5 | 22.2 | 18.4 |
| [REPERE](https://www.islrn.org/resources/360-758-359-485-0/) (phase2) | 8.2 | 7.8 | 7.6 |
| [VoxConverse](https://github.com/joonson/voxconverse) (v0.3) | 11.2 | 11.3 | 9.4 |

[Diarization error rate](http://pyannote.github.io/pyannote-metrics/reference.html#diarization) (in %)

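
The README hunk above iterates over speech turns with `itertracks`. For reference, a minimal sketch of running the v3.1 pipeline end to end; the audio path and the `HF_TOKEN` placeholder are illustrative, and a valid Hugging Face access token is assumed:

```python
# Minimal sketch of running the pyannote/speaker-diarization-3.1 pipeline.
# "HF_TOKEN" and "audio.wav" are placeholders, not values from this commit.
from pyannote.audio import Pipeline

pipeline = Pipeline.from_pretrained(
    "pyannote/speaker-diarization-3.1", use_auth_token="HF_TOKEN"
)

diarization = pipeline("audio.wav")

# iterate over speech turns, as in the README snippet referenced above
for turn, _, speaker in diarization.itertracks(yield_label=True):
    print(f"{turn.start:.1f}s - {turn.end:.1f}s: {speaker}")
```
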
54 changes: 44 additions & 10 deletions pyannote/audio/core/io.py
@@ -48,21 +48,41 @@
- a "IOBase" instance with "read" and "seek" support: open("audio.wav", "rb")
- a "Mapping" with any of the above as "audio" key: {"audio": ...}
- a "Mapping" with both "waveform" and "sample_rate" key:
{"waveform": (channel, time) numpy.ndarray or torch.Tensor, "sample_rate": 44100}
{"waveform": (channel, time) torch.Tensor, "sample_rate": 44100}
For last two options, an additional "channel" key can be provided as a zero-indexed
integer to load a specific channel: {"audio": "stereo.wav", "channel": 0}
"""


def get_torchaudio_info(file: AudioFile):
def get_torchaudio_info(
file: AudioFile, backend: str = None
) -> torchaudio.AudioMetaData:
"""Protocol preprocessor used to cache output of torchaudio.info
This is useful to speed future random access to this file, e.g.
in dataloaders using Audio.crop a lot....
Parameters
----------
file : AudioFile
backend : str
torchaudio backend to use. Defaults to 'soundfile' if available,
or the first available backend.
Returns
-------
info : torchaudio.AudioMetaData
Audio file metadata
"""

info = torchaudio.info(file["audio"])
if not backend:
backends = (
torchaudio.list_audio_backends()
) # e.g ['ffmpeg', 'soundfile', 'sox']
backend = "soundfile" if "soundfile" in backends else backends[0]

info = torchaudio.info(file["audio"], backend=backend)

# rewind if needed
if isinstance(file["audio"], IOBase):
@@ -82,6 +102,9 @@ class Audio:
In case of multi-channel audio, convert to single-channel audio
using one of the following strategies: select one channel at
'random' or 'downmix' by averaging all channels.
backend : str
torchaudio backend to use. Defaults to 'soundfile' if available,
or the first available backend.
Usage
-----
@@ -126,7 +149,7 @@ def validate_file(file: AudioFile) -> Mapping:
-------
validated_file : Mapping
{"audio": str, "uri": str, ...}
{"waveform": array or tensor, "sample_rate": int, "uri": str, ...}
{"waveform": tensor, "sample_rate": int, "uri": str, ...}
{"audio": file, "uri": "stream"} if `file` is an IOBase instance
Raises
@@ -148,7 +171,7 @@ def validate_file(file: AudioFile) -> Mapping:
raise ValueError(AudioFileDocString)

if "waveform" in file:
waveform: Union[np.ndarray, Tensor] = file["waveform"]
waveform: Tensor = file["waveform"]
if len(waveform.shape) != 2 or waveform.shape[0] > waveform.shape[1]:
raise ValueError(
"'waveform' must be provided as a (channel, time) torch Tensor."
@@ -179,11 +202,19 @@ def validate_file(file: AudioFile) -> Mapping:

return file

def __init__(self, sample_rate=None, mono=None):
def __init__(self, sample_rate: int = None, mono=None, backend: str = None):
super().__init__()
self.sample_rate = sample_rate
self.mono = mono

if not backend:
backends = (
torchaudio.list_audio_backends()
) # e.g ['ffmpeg', 'soundfile', 'sox']
backend = "soundfile" if "soundfile" in backends else backends[0]

self.backend = backend

def downmix_and_resample(self, waveform: Tensor, sample_rate: int) -> Tensor:
"""Downmix and resample
@@ -244,7 +275,7 @@ def get_duration(self, file: AudioFile) -> float:
if "torchaudio.info" in file:
info = file["torchaudio.info"]
else:
info = get_torchaudio_info(file)
info = get_torchaudio_info(file, backend=self.backend)

frames = info.num_frames
sample_rate = info.sample_rate
@@ -291,7 +322,7 @@ def __call__(self, file: AudioFile) -> Tuple[Tensor, int]:
sample_rate = file["sample_rate"]

elif "audio" in file:
waveform, sample_rate = torchaudio.load(file["audio"])
waveform, sample_rate = torchaudio.load(file["audio"], backend=self.backend)

# rewind if needed
if isinstance(file["audio"], IOBase):
@@ -349,7 +380,7 @@ def crop(
sample_rate = info.sample_rate

else:
info = get_torchaudio_info(file)
info = get_torchaudio_info(file, backend=self.backend)
frames = info.num_frames
sample_rate = info.sample_rate

@@ -401,7 +432,10 @@ def crop(
else:
try:
data, _ = torchaudio.load(
file["audio"], frame_offset=start_frame, num_frames=num_frames
file["audio"],
frame_offset=start_frame,
num_frames=num_frames,
backend=self.backend,
)
# rewind if needed
if isinstance(file["audio"], IOBase):
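
The `io.py` changes above add a `backend` argument to `Audio` and `get_torchaudio_info`, preferring `soundfile` when available and falling back to the first available backend otherwise. A minimal sketch of how the new option could be used; the file path and sample rate are illustrative:

```python
# Minimal sketch of the backend selection added in this diff.
import torchaudio
from pyannote.audio.core.io import Audio

# mirror the default used in the diff: prefer "soundfile", otherwise first available
backends = torchaudio.list_audio_backends()  # e.g. ['ffmpeg', 'soundfile', 'sox']
backend = "soundfile" if "soundfile" in backends else backends[0]

audio = Audio(sample_rate=16000, mono="downmix", backend=backend)
waveform, sample_rate = audio({"audio": "meeting.wav"})  # "meeting.wav" is illustrative
```
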
10 changes: 6 additions & 4 deletions pyannote/audio/core/task.py
@@ -362,12 +362,13 @@ def prepare_data(self):

if self.has_validation:
files_iter = itertools.chain(
self.protocol.train(), self.protocol.development()
zip(itertools.repeat("train"), self.protocol.train()),
zip(itertools.repeat("development"), self.protocol.development()),
)
else:
files_iter = self.protocol.train()
files_iter = zip(itertools.repeat("train"), self.protocol.train())

for file_id, file in enumerate(files_iter):
for file_id, (subset, file) in enumerate(files_iter):
# gather metadata and update metadata_unique_values so that each metadatum
# (e.g. source database or label) is represented by an integer.
metadatum = dict()
@@ -378,7 +379,8 @@
metadatum["database"] = metadata_unique_values["database"].index(
file["database"]
)
metadatum["subset"] = Subsets.index(file["subset"])

metadatum["subset"] = Subsets.index(subset)

# keep track of label scope (file, database, or global)
metadatum["scope"] = Scopes.index(file["scope"])
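
The `task.py` fix above derives each file's subset from the iterator it came from, rather than from a `file["subset"]` field. A standalone sketch of that `zip`/`itertools.repeat` pattern, with plain lists standing in for the protocol subsets:

```python
# Standalone sketch of the subset-tagging pattern used in the fix above,
# with dummy lists standing in for protocol.train() / protocol.development().
import itertools

train_files = [{"uri": "trn_1"}, {"uri": "trn_2"}]   # stand-in for protocol.train()
development_files = [{"uri": "dev_1"}]               # stand-in for protocol.development()

files_iter = itertools.chain(
    zip(itertools.repeat("train"), train_files),
    zip(itertools.repeat("development"), development_files),
)

for file_id, (subset, file) in enumerate(files_iter):
    print(file_id, subset, file["uri"])
# 0 train trn_1
# 1 train trn_2
# 2 development dev_1
```
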
77 changes: 40 additions & 37 deletions pyannote/audio/models/blocks/pooling.py
@@ -26,53 +26,53 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange


class StatsPool(nn.Module):
"""Statistics pooling
def _pool(sequences: torch.Tensor, weights: torch.Tensor) -> torch.Tensor:
"""Helper function to compute statistics pooling
Compute temporal mean and (unbiased) standard deviation
and returns their concatenation.
Assumes that weights are already interpolated to match the number of frames
in sequences and that they encode the activation of only one speaker.
Reference
---------
https://en.wikipedia.org/wiki/Weighted_arithmetic_mean
Parameters
----------
sequences : (batch, features, frames) torch.Tensor
Sequences of features.
weights : (batch, frames) torch.Tensor
(Already interpolated) weights.
Returns
-------
output : (batch, 2 * features) torch.Tensor
Concatenation of mean and (unbiased) standard deviation.
"""

def _pool(self, sequences: torch.Tensor, weights: torch.Tensor) -> torch.Tensor:
"""Helper function to compute statistics pooling
weights = weights.unsqueeze(dim=1)
# (batch, 1, frames)

Assumes that weights are already interpolated to match the number of frames
in sequences and that they encode the activation of only one speaker.
v1 = weights.sum(dim=2) + 1e-8
mean = torch.sum(sequences * weights, dim=2) / v1

Parameters
----------
sequences : (batch, features, frames) torch.Tensor
Sequences of features.
weights : (batch, frames) torch.Tensor
(Already interpolated) weights.
dx2 = torch.square(sequences - mean.unsqueeze(2))
v2 = torch.square(weights).sum(dim=2)

Returns
-------
output : (batch, 2 * features) torch.Tensor
Concatenation of mean and (unbiased) standard deviation.
"""
var = torch.sum(dx2 * weights, dim=2) / (v1 - v2 / v1 + 1e-8)
std = torch.sqrt(var)

weights = weights.unsqueeze(dim=1)
# (batch, 1, frames)
return torch.cat([mean, std], dim=1)

v1 = weights.sum(dim=2) + 1e-8
mean = torch.sum(sequences * weights, dim=2) / v1

dx2 = torch.square(sequences - mean.unsqueeze(2))
v2 = torch.square(weights).sum(dim=2)
class StatsPool(nn.Module):
"""Statistics pooling
var = torch.sum(dx2 * weights, dim=2) / (v1 - v2 / v1 + 1e-8)
std = torch.sqrt(var)
Compute temporal mean and (unbiased) standard deviation
and returns their concatenation.
return torch.cat([mean, std], dim=1)
Reference
---------
https://en.wikipedia.org/wiki/Weighted_arithmetic_mean
"""

def forward(
self, sequences: torch.Tensor, weights: Optional[torch.Tensor] = None
@@ -112,17 +112,20 @@ def forward(
has_speaker_dimension = True

# interpolate weights if needed
_, _, num_frames = sequences.shape
_, _, num_weights = weights.shape
_, _, num_frames = sequences.size()
_, num_speakers, num_weights = weights.size()
if num_frames != num_weights:
warnings.warn(
f"Mismatch between frames ({num_frames}) and weights ({num_weights}) numbers."
)
weights = F.interpolate(weights, size=num_frames, mode="nearest")

output = rearrange(
torch.vmap(self._pool, in_dims=(None, 1))(sequences, weights),
"speakers batch features -> batch speakers features",
output = torch.stack(
[
_pool(sequences, weights[:, speaker, :])
for speaker in range(num_speakers)
],
dim=1,
)

if not has_speaker_dimension:
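
The pooling refactor above moves the weighted statistics computation into a module-level `_pool` helper and replaces the `torch.vmap`/`rearrange` combination with an explicit per-speaker loop. A self-contained sketch of that computation on random tensors; shapes are illustrative:

```python
# Self-contained sketch of the weighted statistics pooling computed by _pool above.
import torch

batch, features, frames, num_speakers = 2, 8, 100, 3
sequences = torch.randn(batch, features, frames)
weights = torch.rand(batch, num_speakers, frames)

def pool(sequences: torch.Tensor, weights: torch.Tensor) -> torch.Tensor:
    # weights: (batch, frames) -> (batch, 1, frames)
    weights = weights.unsqueeze(dim=1)
    v1 = weights.sum(dim=2) + 1e-8
    mean = torch.sum(sequences * weights, dim=2) / v1
    dx2 = torch.square(sequences - mean.unsqueeze(2))
    v2 = torch.square(weights).sum(dim=2)
    # weighted, unbiased variance and standard deviation
    var = torch.sum(dx2 * weights, dim=2) / (v1 - v2 / v1 + 1e-8)
    std = torch.sqrt(var)
    return torch.cat([mean, std], dim=1)  # (batch, 2 * features)

# one call per speaker, then stack along a speaker dimension, as in the diff
output = torch.stack(
    [pool(sequences, weights[:, speaker, :]) for speaker in range(num_speakers)],
    dim=1,
)
print(output.shape)  # torch.Size([2, 3, 16])
```
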
2 changes: 2 additions & 0 deletions pyannote/audio/models/blocks/sincnet.py
@@ -122,12 +122,14 @@ def receptive_field_size(self, num_frames: int = 1) -> int:

kernel_size = [251, 3, 5, 3, 5, 3]
stride = [self.stride, 3, 1, 3, 1, 3]
padding = [0, 0, 0, 0, 0, 0]
dilation = [1, 1, 1, 1, 1, 1]

return multi_conv_receptive_field_size(
num_frames,
kernel_size=kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
)

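
The `sincnet.py` change above passes an explicit `padding` list to `multi_conv_receptive_field_size`. As an illustration only (this is not pyannote's own utility), here is one way the receptive field of such a stack of 1-d convolutions can be computed from kernel size, stride, and dilation; the first stride value of 10 is an assumed stand-in for SincNet's default, and zero padding is assumed throughout:

```python
# Illustrative re-implementation (not pyannote's multi_conv_receptive_field_size)
# of the receptive field of stacked 1-d convolutions, using the kernel/stride
# values that appear in the diff above.
def receptive_field_size(num_frames, kernel_size, stride, dilation):
    """Number of input samples covered by `num_frames` output frames."""
    size = num_frames
    # walk the layers backwards: each layer maps `size` output frames to the
    # number of input frames needed to produce them (zero padding assumed)
    for k, s, d in zip(reversed(kernel_size), reversed(stride), reversed(dilation)):
        effective_kernel = d * (k - 1) + 1
        size = (size - 1) * s + effective_kernel
    return size

kernel_size = [251, 3, 5, 3, 5, 3]
stride = [10, 3, 1, 3, 1, 3]   # 10 is an assumed stand-in for SincNet's default stride
dilation = [1, 1, 1, 1, 1, 1]

print(receptive_field_size(1, kernel_size, stride, dilation))  # samples covered by one output frame
```
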
