Merge branch 'develop' into patch-1

hbredin authored May 17, 2024
2 parents adcbf0d + cad8bea commit c89fd15
Showing 15 changed files with 340 additions and 155 deletions.
CHANGELOG.md (21 changes: 17 additions & 4 deletions)

@@ -2,19 +2,28 @@

## develop

### Fixes

- fix(task): fix wrong train/development split when training with (some) meta-protocols ([#1709](https://github.com/pyannote/pyannote-audio/issues/1709))

## Version 3.2.0 (2024-05-08)

### New features

- feat(task): add option to cache task training metadata to speed up training (with [@clement-pages](https://github.com/clement-pages/))
- feat(model): add `receptive_field`, `num_frames` and `dimension` to models (with [@Bilal-Rahou](https://github.com/Bilal-Rahou))
- feat(model): add `fbank_only` property to `WeSpeaker` models
- feat(util): add `Powerset.permutation_mapping` to help with permutation in powerset space (with [@FrenchKrab](https://github.com/FrenchKrab))
- feat(sample): add sample file at `pyannote.audio.sample.SAMPLE_FILE`
- feat(metric): add `reduce` option to `diarization_error_rate` metric (with [@Bilal-Rahou](https://github.com/Bilal-Rahou))
- feat(pipeline): add `Waveform` and `SampleRate` preprocessors

### Fixes

- fix(task): fix random generators and their reproducibility (with [@FrenchKrab](https://github.com/FrenchKrab))
- fix(task): fix estimation of training set size (with [@FrenchKrab](https://github.com/FrenchKrab))
- fix(hook): fix `torch.Tensor` support in `ArtifactHook`
- fix(doc): fix typo in `Powerset` docstring (with [@lukasstorck](https://github.com/lukasstorck))

### Improvements

@@ -24,11 +24,15 @@
- improve(doc): update tutorials (with [@clement-pages](https://github.com/clement-pages/))
- improve(io): remove the misleading mentions of `numpy.ndarray` as it's not supported (with [@Purfview](https://github.com/Purfview))

### Breaking changes

- BREAKING(model): get rid of `Model.example_output` in favor of `num_frames` method, `receptive_field` property, and `dimension` property
- BREAKING(task): custom tasks need to be updated (see "Add your own task" tutorial)

### Community contributions

- community: add tutorial for offline use of `pyannote/speaker-diarization-3.1` (by [@simonottenhauskenbun](https://github.com/simonottenhauskenbun))

## Version 3.1.1 (2023-12-01)

### TL;DR
README.md (34 changes: 17 additions & 17 deletions)

@@ -1,5 +1,5 @@
Using `pyannote.audio` open-source toolkit in production?
Consider switching to [pyannoteAI](https://www.pyannote.ai) for better and faster options.

# `pyannote.audio` speaker diarization toolkit

@@ -79,21 +79,21 @@ for turn, _, speaker in diarization.itertracks(yield_label=True):
Out of the box, `pyannote.audio` speaker diarization [pipeline](https://hf.co/pyannote/speaker-diarization-3.1) v3.1 is expected to be much better (and faster) than v2.x.
Those numbers are diarization error rates (in %):

| Benchmark | [v2.1](https://hf.co/pyannote/speaker-diarization-2.1) | [v3.1](https://hf.co/pyannote/speaker-diarization-3.1) | [pyannoteAI](https://www.pyannote.ai) |
| --------------------------------------------------------------------------------------------------------------------------- | ------------------------------------------------------ | ------------------------------------------------------ | ------------------------------------------------ |
| [AISHELL-4](https://arxiv.org/abs/2104.03603) | 14.1 | 12.2 | 11.9 |
| [AliMeeting](https://www.openslr.org/119/) (channel 1) | 27.4 | 24.4 | 22.5 |
| [AMI](https://groups.inf.ed.ac.uk/ami/corpus/) (IHM) | 18.9 | 18.8 | 16.6 |
| [AMI](https://groups.inf.ed.ac.uk/ami/corpus/) (SDM) | 27.1 | 22.4 | 20.9 |
| [AVA-AVD](https://arxiv.org/abs/2111.14448) | 66.3 | 50.0 | 39.8 |
| [CALLHOME](https://catalog.ldc.upenn.edu/LDC2001S97) ([part 2](https://github.com/BUTSpeechFIT/CALLHOME_sublists/issues/1)) | 31.6 | 28.4 | 22.2 |
| [DIHARD 3](https://catalog.ldc.upenn.edu/LDC2022S14) ([full](https://arxiv.org/abs/2012.01477)) | 26.9 | 21.7 | 17.2 |
| [Earnings21](https://github.com/revdotcom/speech-datasets) | 17.0 | 9.4 | 9.0 |
| [Ego4D](https://arxiv.org/abs/2110.07058) (dev.) | 61.5 | 51.2 | 43.8 |
| [MSDWild](https://github.com/X-LANCE/MSDWILD) | 32.8 | 25.3 | 19.8 |
| [RAMC](https://www.openslr.org/123/) | 22.5 | 22.2 | 18.4 |
| [REPERE](https://www.islrn.org/resources/360-758-359-485-0/) (phase2) | 8.2 | 7.8 | 7.6 |
| [VoxConverse](https://github.com/joonson/voxconverse) (v0.3) | 11.2 | 11.3 | 9.4 |

[Diarization error rate](http://pyannote.github.io/pyannote-metrics/reference.html#diarization) (in %)

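For context, the benchmarked pipeline is applied with the usual `pyannote.audio` API; a minimal sketch (the access token and audio path below are placeholders):

```python
import torch
from pyannote.audio import Pipeline

# load the pretrained pipeline (requires accepting the user conditions
# on Hugging Face; the token below is a placeholder)
pipeline = Pipeline.from_pretrained(
    "pyannote/speaker-diarization-3.1",
    use_auth_token="HUGGINGFACE_ACCESS_TOKEN_GOES_HERE",
)

# run on GPU when available
pipeline.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))

# apply the pipeline to an audio file and iterate over speaker turns
diarization = pipeline("audio.wav")
for turn, _, speaker in diarization.itertracks(yield_label=True):
    print(f"start={turn.start:.1f}s stop={turn.end:.1f}s speaker_{speaker}")
```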
pyannote/audio/core/task.py (10 changes: 6 additions & 4 deletions)

@@ -362,12 +362,13 @@ def prepare_data(self):

        if self.has_validation:
            files_iter = itertools.chain(
                zip(itertools.repeat("train"), self.protocol.train()),
                zip(itertools.repeat("development"), self.protocol.development()),
            )
        else:
            files_iter = zip(itertools.repeat("train"), self.protocol.train())

        for file_id, (subset, file) in enumerate(files_iter):
            # gather metadata and update metadata_unique_values so that each metadatum
            # (e.g. source database or label) is represented by an integer.
            metadatum = dict()
@@ -378,7 +379,8 @@ def prepare_data(self):
metadatum["database"] = metadata_unique_values["database"].index(
file["database"]
)
metadatum["subset"] = Subsets.index(file["subset"])

metadatum["subset"] = Subsets.index(subset)

# keep track of label scope (file, database, or global)
metadatum["scope"] = Scopes.index(file["scope"])
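The fix tags each training file with the subset it was actually drawn from, instead of trusting a `file["subset"]` key that (some) meta-protocols fill in wrongly. A standalone sketch of the pattern, with made-up file dictionaries standing in for the protocol iterators:

```python
import itertools

# made-up stand-ins for self.protocol.train() / self.protocol.development()
train_files = [{"uri": "file-a"}, {"uri": "file-b"}]
development_files = [{"uri": "file-c"}]

# pair every file with the name of the subset it came from
files_iter = itertools.chain(
    zip(itertools.repeat("train"), train_files),
    zip(itertools.repeat("development"), development_files),
)

for file_id, (subset, file) in enumerate(files_iter):
    print(file_id, subset, file["uri"])
# 0 train file-a
# 1 train file-b
# 2 development file-c
```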
pyannote/audio/models/blocks/pooling.py (77 changes: 40 additions & 37 deletions)

@@ -26,53 +26,53 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


def _pool(sequences: torch.Tensor, weights: torch.Tensor) -> torch.Tensor:
    """Helper function to compute statistics pooling

    Assumes that weights are already interpolated to match the number of frames
    in sequences and that they encode the activation of only one speaker.

    Parameters
    ----------
    sequences : (batch, features, frames) torch.Tensor
        Sequences of features.
    weights : (batch, frames) torch.Tensor
        (Already interpolated) weights.

    Returns
    -------
    output : (batch, 2 * features) torch.Tensor
        Concatenation of mean and (unbiased) standard deviation.
    """

    weights = weights.unsqueeze(dim=1)
    # (batch, 1, frames)

    v1 = weights.sum(dim=2) + 1e-8
    mean = torch.sum(sequences * weights, dim=2) / v1

    dx2 = torch.square(sequences - mean.unsqueeze(2))
    v2 = torch.square(weights).sum(dim=2)

    var = torch.sum(dx2 * weights, dim=2) / (v1 - v2 / v1 + 1e-8)
    std = torch.sqrt(var)

    return torch.cat([mean, std], dim=1)


class StatsPool(nn.Module):
    """Statistics pooling

    Compute temporal mean and (unbiased) standard deviation
    and returns their concatenation.

    Reference
    ---------
    https://en.wikipedia.org/wiki/Weighted_arithmetic_mean

    """

    def forward(
        self, sequences: torch.Tensor, weights: Optional[torch.Tensor] = None
@@ -112,17 +112,20 @@ def forward(
        has_speaker_dimension = True

        # interpolate weights if needed
        _, _, num_frames = sequences.size()
        _, num_speakers, num_weights = weights.size()
        if num_frames != num_weights:
            warnings.warn(
                f"Mismatch between frames ({num_frames}) and weights ({num_weights}) numbers."
            )
            weights = F.interpolate(weights, size=num_frames, mode="nearest")

        output = torch.stack(
            [
                _pool(sequences, weights[:, speaker, :])
                for speaker in range(num_speakers)
            ],
            dim=1,
        )

        if not has_speaker_dimension:
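To check the unbiased weighted statistics computed by `_pool`, here is the same arithmetic worked on a tiny made-up tensor (one sequence, two features, one speaker active on the last two frames only):

```python
import torch

# (batch=1, features=2, frames=4)
sequences = torch.tensor([[[1.0, 2.0, 3.0, 4.0],
                           [0.0, 0.0, 1.0, 1.0]]])
# (batch=1, frames=4): speaker active on frames 2 and 3 only
weights = torch.tensor([[0.0, 0.0, 1.0, 1.0]])

w = weights.unsqueeze(dim=1)                 # (1, 1, 4)
v1 = w.sum(dim=2) + 1e-8                     # total weight: 2
mean = torch.sum(sequences * w, dim=2) / v1  # [[3.5, 1.0]]

dx2 = torch.square(sequences - mean.unsqueeze(2))
v2 = torch.square(w).sum(dim=2)

# unbiased weighted variance: denominator is v1 - v2 / v1 = 2 - 2/2 = 1
var = torch.sum(dx2 * w, dim=2) / (v1 - v2 / v1 + 1e-8)
std = torch.sqrt(var)                        # [[0.7071, 0.0]]

print(torch.cat([mean, std], dim=1))         # [[3.5, 1.0, 0.7071, 0.0]]
```

The 0.7071 matches the unbiased sample standard deviation of {3, 4}, so restricting the statistics to the active frames behaves as expected.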
pyannote/audio/models/blocks/sincnet.py (2 changes: 2 additions & 0 deletions)

@@ -122,12 +122,14 @@ def receptive_field_size(self, num_frames: int = 1) -> int:

        kernel_size = [251, 3, 5, 3, 5, 3]
        stride = [self.stride, 3, 1, 3, 1, 3]
        padding = [0, 0, 0, 0, 0, 0]
        dilation = [1, 1, 1, 1, 1, 1]

        return multi_conv_receptive_field_size(
            num_frames,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
        )

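For intuition, the receptive field size follows the standard composition rule for stacked convolutions. The helper below is a simplified sketch of that rule, not pyannote's implementation: padding is ignored because it shifts receptive-field centers without changing their size, and `stride=10` is assumed for the sinc layer:

```python
def receptive_field_size(num_frames, kernel_size, stride, dilation):
    # walk the stack from the last layer back to the input,
    # growing the window one layer at a time
    size = num_frames
    for k, s, d in zip(reversed(kernel_size), reversed(stride), reversed(dilation)):
        size = (size - 1) * s + d * (k - 1) + 1
    return size

# SincNet's stack from the snippet above (stride=10 assumed for the sinc layer)
print(receptive_field_size(
    1,
    kernel_size=[251, 3, 5, 3, 5, 3],
    stride=[10, 3, 1, 3, 1, 3],
    dilation=[1, 1, 1, 1, 1, 1],
))  # 991 samples per output frame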
pyannote/audio/models/embedding/debug.py (32 changes: 9 additions & 23 deletions)

@@ -31,11 +31,6 @@

from pyannote.audio.core.model import Model
from pyannote.audio.core.task import Task


class SimpleEmbeddingModel(Model):
@@ -87,13 +82,10 @@ def num_frames(self, num_samples: int) -> int:
        n_fft = self.mfcc.MelSpectrogram.spectrogram.n_fft
        center = self.mfcc.MelSpectrogram.spectrogram.center

        if center:
            return 1 + num_samples // hop_length
        else:
            return 1 + (num_samples - n_fft) // hop_length

    def receptive_field_size(self, num_frames: int = 1) -> int:
        """Compute size of receptive field
@@ -111,10 +103,7 @@ def receptive_field_size(self, num_frames: int = 1) -> int:

        hop_length = self.mfcc.MelSpectrogram.spectrogram.hop_length
        n_fft = self.mfcc.MelSpectrogram.spectrogram.n_fft

        return n_fft + (num_frames - 1) * hop_length

    def receptive_field_center(self, frame: int = 0) -> int:
        """Compute center of receptive field
@@ -134,13 +123,10 @@ def receptive_field_center(self, frame: int = 0) -> int:
        n_fft = self.mfcc.MelSpectrogram.spectrogram.n_fft
        center = self.mfcc.MelSpectrogram.spectrogram.center

        if center:
            return frame * hop_length
        else:
            return frame * hop_length + n_fft // 2

    @property
    def dimension(self) -> int:
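The closed-form expressions above can be sanity-checked against torchaudio; a quick sketch with made-up STFT settings (`n_fft=400`, `hop_length=160`, `center=True`):

```python
import torch
import torchaudio.transforms as T

n_fft, hop_length, num_samples = 400, 160, 16000

mfcc = T.MFCC(
    sample_rate=16000,
    melkwargs={"n_fft": n_fft, "hop_length": hop_length, "center": True},
)
num_frames = mfcc(torch.randn(1, num_samples)).shape[-1]

# center=True pads n_fft // 2 samples on both sides, hence the simple formula
assert num_frames == 1 + num_samples // hop_length  # 101 frames
```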