diff --git a/CHANGELOG.md b/CHANGELOG.md index c2eb1c22b..3c8157462 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,19 +2,28 @@ ## develop +### Fixes + +- fix(task): fix wrong train/development split when training with (some) meta-protocols ([#1709](https://github.com/pyannote/pyannote-audio/issues/1709)) + +## Version 3.2.0 (2024-05-08) + ### New features - feat(task): add option to cache task training metadata to speed up training (with [@clement-pages](https://github.com/clement-pages/)) - feat(model): add `receptive_field`, `num_frames` and `dimension` to models (with [@Bilal-Rahou](https://github.com/Bilal-Rahou)) +- feat(model): add `fbank_only` property to `WeSpeaker` models - feat(util): add `Powerset.permutation_mapping` to help with permutation in powerset space (with [@FrenchKrab](https://github.com/FrenchKrab)) -- feat(sample): add sample file at `pyannote.audio.sample.SAMPLE_FILE` +- feat(sample): add sample file at `pyannote.audio.sample.SAMPLE_FILE` - feat(metric): add `reduce` option to `diarization_error_rate` metric (with [@Bilal-Rahou](https://github.com/Bilal-Rahou)) - feat(pipeline): add `Waveform` and `SampleRate` preprocessors ### Fixes -- fix(task): fix random generators and their reproducibility (with [@FrenchKrab](https://github.com/FrenchKrab)) -- fix(task): fix estimation of training set size (with [@FrenchKrab](https://github.com/FrenchKrab)) +- fix(task): fix random generators and their reproducibility (with [@FrenchKrab](https://github.com/FrenchKrab)) +- fix(task): fix estimation of training set size (with [@FrenchKrab](https://github.com/FrenchKrab)) +- fix(hook): fix `torch.Tensor` support in `ArtifactHook` +- fix(doc): fix typo in `Powerset` docstring (with [@lukasstorck](https://github.com/lukasstorck)) ### Improvements @@ -24,11 +33,15 @@ - improve(doc): update tutorials (with [@clement-pages](https://github.com/clement-pages/)) - improve(io): remove the misleading mentions of `numpy.ndarray` as it's not supported (with [@Purfview](https://github.com/Purfview)) -## Breaking changes +### Breaking changes - BREAKING(model): get rid of `Model.example_output` in favor of `num_frames` method, `receptive_field` property, and `dimension` property - BREAKING(task): custom tasks need to be updated (see "Add your own task" tutorial) +### Community contributions + +- community: add tutorial for offline use of `pyannote/speaker-diarization-3.1` (by [@simonottenhauskenbun](https://github.com/simonottenhauskenbun)) + ## Version 3.1.1 (2023-12-01) ### TL;DR diff --git a/README.md b/README.md index 50dfeb286..abef6de01 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,5 @@ -Using `pyannote.audio` open-source toolkit in production? -Make the most of it thanks to our [consulting services](https://herve.niderb.fr/consulting.html). +Using `pyannote.audio` open-source toolkit in production? +Consider switching to [pyannoteAI](https://www.pyannote.ai) for better and faster options. # `pyannote.audio` speaker diarization toolkit @@ -79,21 +79,21 @@ for turn, _, speaker in diarization.itertracks(yield_label=True): Out of the box, `pyannote.audio` speaker diarization [pipeline](https://hf.co/pyannote/speaker-diarization-3.1) v3.1 is expected to be much better (and faster) than v2.x. Those numbers are diarization error rates (in %): -| Benchmark | [v2.1](https://hf.co/pyannote/speaker-diarization-2.1) | [v3.1](https://hf.co/pyannote/speaker-diarization-3.1) | [Premium](https://forms.office.com/e/GdqwVgkZ5C) | -| ---------------------- | ------ | ------ | --------- | -| [AISHELL-4](https://arxiv.org/abs/2104.03603) | 14.1 | 12.2 | 11.9 | -| [AliMeeting](https://www.openslr.org/119/) (channel 1) | 27.4 | 24.4 | 22.5 | -| [AMI](https://groups.inf.ed.ac.uk/ami/corpus/) (IHM) | 18.9 | 18.8 | 16.6 | -| [AMI](https://groups.inf.ed.ac.uk/ami/corpus/) (SDM) | 27.1 | 22.4 | 20.9 | -| [AVA-AVD](https://arxiv.org/abs/2111.14448) | 66.3 | 50.0 | 39.8 | -| [CALLHOME](https://catalog.ldc.upenn.edu/LDC2001S97) ([part 2](https://github.com/BUTSpeechFIT/CALLHOME_sublists/issues/1)) | 31.6 | 28.4 | 22.2 | -| [DIHARD 3](https://catalog.ldc.upenn.edu/LDC2022S14) ([full](https://arxiv.org/abs/2012.01477)) | 26.9 | 21.7 | 17.2 | -| [Earnings21](https://github.com/revdotcom/speech-datasets) | 17.0 | 9.4 | 9.0 | -| [Ego4D](https://arxiv.org/abs/2110.07058) (dev.) | 61.5 | 51.2 | 43.8 | -| [MSDWild](https://github.com/X-LANCE/MSDWILD) | 32.8 | 25.3 | 19.8 | -| [RAMC](https://www.openslr.org/123/) | 22.5 | 22.2 | 18.4 | -| [REPERE](https://www.islrn.org/resources/360-758-359-485-0/) (phase2) | 8.2 | 7.8 | 7.6 | -| [VoxConverse](https://github.com/joonson/voxconverse) (v0.3) | 11.2 | 11.3 | 9.4 | +| Benchmark | [v2.1](https://hf.co/pyannote/speaker-diarization-2.1) | [v3.1](https://hf.co/pyannote/speaker-diarization-3.1) | [pyannoteAI](https://www.pyannote.ai) | +| --------------------------------------------------------------------------------------------------------------------------- | ------------------------------------------------------ | ------------------------------------------------------ | ------------------------------------------------ | +| [AISHELL-4](https://arxiv.org/abs/2104.03603) | 14.1 | 12.2 | 11.9 | +| [AliMeeting](https://www.openslr.org/119/) (channel 1) | 27.4 | 24.4 | 22.5 | +| [AMI](https://groups.inf.ed.ac.uk/ami/corpus/) (IHM) | 18.9 | 18.8 | 16.6 | +| [AMI](https://groups.inf.ed.ac.uk/ami/corpus/) (SDM) | 27.1 | 22.4 | 20.9 | +| [AVA-AVD](https://arxiv.org/abs/2111.14448) | 66.3 | 50.0 | 39.8 | +| [CALLHOME](https://catalog.ldc.upenn.edu/LDC2001S97) ([part 2](https://github.com/BUTSpeechFIT/CALLHOME_sublists/issues/1)) | 31.6 | 28.4 | 22.2 | +| [DIHARD 3](https://catalog.ldc.upenn.edu/LDC2022S14) ([full](https://arxiv.org/abs/2012.01477)) | 26.9 | 21.7 | 17.2 | +| [Earnings21](https://github.com/revdotcom/speech-datasets) | 17.0 | 9.4 | 9.0 | +| [Ego4D](https://arxiv.org/abs/2110.07058) (dev.) | 61.5 | 51.2 | 43.8 | +| [MSDWild](https://github.com/X-LANCE/MSDWILD) | 32.8 | 25.3 | 19.8 | +| [RAMC](https://www.openslr.org/123/) | 22.5 | 22.2 | 18.4 | +| [REPERE](https://www.islrn.org/resources/360-758-359-485-0/) (phase2) | 8.2 | 7.8 | 7.6 | +| [VoxConverse](https://github.com/joonson/voxconverse) (v0.3) | 11.2 | 11.3 | 9.4 | [Diarization error rate](http://pyannote.github.io/pyannote-metrics/reference.html#diarization) (in %) diff --git a/pyannote/audio/core/task.py b/pyannote/audio/core/task.py index 0a61e2a6f..974f43a67 100644 --- a/pyannote/audio/core/task.py +++ b/pyannote/audio/core/task.py @@ -362,12 +362,13 @@ def prepare_data(self): if self.has_validation: files_iter = itertools.chain( - self.protocol.train(), self.protocol.development() + zip(itertools.repeat("train"), self.protocol.train()), + zip(itertools.repeat("development"), self.protocol.development()), ) else: - files_iter = self.protocol.train() + files_iter = zip(itertools.repeat("train"), self.protocol.train()) - for file_id, file in enumerate(files_iter): + for file_id, (subset, file) in enumerate(files_iter): # gather metadata and update metadata_unique_values so that each metadatum # (e.g. source database or label) is represented by an integer. metadatum = dict() @@ -378,7 +379,8 @@ def prepare_data(self): metadatum["database"] = metadata_unique_values["database"].index( file["database"] ) - metadatum["subset"] = Subsets.index(file["subset"]) + + metadatum["subset"] = Subsets.index(subset) # keep track of label scope (file, database, or global) metadatum["scope"] = Scopes.index(file["scope"]) diff --git a/pyannote/audio/models/blocks/pooling.py b/pyannote/audio/models/blocks/pooling.py index 22d736a03..dc31bea8e 100644 --- a/pyannote/audio/models/blocks/pooling.py +++ b/pyannote/audio/models/blocks/pooling.py @@ -26,53 +26,53 @@ import torch import torch.nn as nn import torch.nn.functional as F -from einops import rearrange -class StatsPool(nn.Module): - """Statistics pooling +def _pool(sequences: torch.Tensor, weights: torch.Tensor) -> torch.Tensor: + """Helper function to compute statistics pooling - Compute temporal mean and (unbiased) standard deviation - and returns their concatenation. + Assumes that weights are already interpolated to match the number of frames + in sequences and that they encode the activation of only one speaker. - Reference - --------- - https://en.wikipedia.org/wiki/Weighted_arithmetic_mean + Parameters + ---------- + sequences : (batch, features, frames) torch.Tensor + Sequences of features. + weights : (batch, frames) torch.Tensor + (Already interpolated) weights. + Returns + ------- + output : (batch, 2 * features) torch.Tensor + Concatenation of mean and (unbiased) standard deviation. """ - def _pool(self, sequences: torch.Tensor, weights: torch.Tensor) -> torch.Tensor: - """Helper function to compute statistics pooling + weights = weights.unsqueeze(dim=1) + # (batch, 1, frames) - Assumes that weights are already interpolated to match the number of frames - in sequences and that they encode the activation of only one speaker. + v1 = weights.sum(dim=2) + 1e-8 + mean = torch.sum(sequences * weights, dim=2) / v1 - Parameters - ---------- - sequences : (batch, features, frames) torch.Tensor - Sequences of features. - weights : (batch, frames) torch.Tensor - (Already interpolated) weights. + dx2 = torch.square(sequences - mean.unsqueeze(2)) + v2 = torch.square(weights).sum(dim=2) - Returns - ------- - output : (batch, 2 * features) torch.Tensor - Concatenation of mean and (unbiased) standard deviation. - """ + var = torch.sum(dx2 * weights, dim=2) / (v1 - v2 / v1 + 1e-8) + std = torch.sqrt(var) - weights = weights.unsqueeze(dim=1) - # (batch, 1, frames) + return torch.cat([mean, std], dim=1) - v1 = weights.sum(dim=2) + 1e-8 - mean = torch.sum(sequences * weights, dim=2) / v1 - dx2 = torch.square(sequences - mean.unsqueeze(2)) - v2 = torch.square(weights).sum(dim=2) +class StatsPool(nn.Module): + """Statistics pooling - var = torch.sum(dx2 * weights, dim=2) / (v1 - v2 / v1 + 1e-8) - std = torch.sqrt(var) + Compute temporal mean and (unbiased) standard deviation + and returns their concatenation. - return torch.cat([mean, std], dim=1) + Reference + --------- + https://en.wikipedia.org/wiki/Weighted_arithmetic_mean + + """ def forward( self, sequences: torch.Tensor, weights: Optional[torch.Tensor] = None @@ -112,17 +112,20 @@ def forward( has_speaker_dimension = True # interpolate weights if needed - _, _, num_frames = sequences.shape - _, _, num_weights = weights.shape + _, _, num_frames = sequences.size() + _, num_speakers, num_weights = weights.size() if num_frames != num_weights: warnings.warn( f"Mismatch between frames ({num_frames}) and weights ({num_weights}) numbers." ) weights = F.interpolate(weights, size=num_frames, mode="nearest") - output = rearrange( - torch.vmap(self._pool, in_dims=(None, 1))(sequences, weights), - "speakers batch features -> batch speakers features", + output = torch.stack( + [ + _pool(sequences, weights[:, speaker, :]) + for speaker in range(num_speakers) + ], + dim=1, ) if not has_speaker_dimension: diff --git a/pyannote/audio/models/blocks/sincnet.py b/pyannote/audio/models/blocks/sincnet.py index b46549bb3..2a085201c 100644 --- a/pyannote/audio/models/blocks/sincnet.py +++ b/pyannote/audio/models/blocks/sincnet.py @@ -122,12 +122,14 @@ def receptive_field_size(self, num_frames: int = 1) -> int: kernel_size = [251, 3, 5, 3, 5, 3] stride = [self.stride, 3, 1, 3, 1, 3] + padding = [0, 0, 0, 0, 0, 0] dilation = [1, 1, 1, 1, 1, 1] return multi_conv_receptive_field_size( num_frames, kernel_size=kernel_size, stride=stride, + padding=padding, dilation=dilation, ) diff --git a/pyannote/audio/models/embedding/debug.py b/pyannote/audio/models/embedding/debug.py index a5e862a24..b09283908 100644 --- a/pyannote/audio/models/embedding/debug.py +++ b/pyannote/audio/models/embedding/debug.py @@ -31,11 +31,6 @@ from pyannote.audio.core.model import Model from pyannote.audio.core.task import Task -from pyannote.audio.utils.receptive_field import ( - conv1d_num_frames, - conv1d_receptive_field_center, - conv1d_receptive_field_size, -) class SimpleEmbeddingModel(Model): @@ -87,13 +82,10 @@ def num_frames(self, num_samples: int) -> int: n_fft = self.mfcc.MelSpectrogram.spectrogram.n_fft center = self.mfcc.MelSpectrogram.spectrogram.center - return conv1d_num_frames( - num_samples=num_samples, - kernel_size=n_fft, - stride=hop_length, - padding=n_fft // 2 if center else 0, - dilation=1, - ) + if center: + return 1 + num_samples // hop_length + else: + return 1 + (num_samples - n_fft) // hop_length def receptive_field_size(self, num_frames: int = 1) -> int: """Compute size of receptive field @@ -111,10 +103,7 @@ def receptive_field_size(self, num_frames: int = 1) -> int: hop_length = self.mfcc.MelSpectrogram.spectrogram.hop_length n_fft = self.mfcc.MelSpectrogram.spectrogram.n_fft - - return conv1d_receptive_field_size( - num_frames, kernel_size=n_fft, stride=hop_length, dilation=1 - ) + return n_fft + (num_frames - 1) * hop_length def receptive_field_center(self, frame: int = 0) -> int: """Compute center of receptive field @@ -134,13 +123,10 @@ def receptive_field_center(self, frame: int = 0) -> int: n_fft = self.mfcc.MelSpectrogram.spectrogram.n_fft center = self.mfcc.MelSpectrogram.spectrogram.center - return conv1d_receptive_field_center( - frame=frame, - kernel_size=n_fft, - stride=hop_length, - padding=n_fft // 2 if center else 0, - dilation=1, - ) + if center: + return frame * hop_length + else: + return frame * hop_length + n_fft // 2 @property def dimension(self) -> int: diff --git a/pyannote/audio/models/embedding/wespeaker/__init__.py b/pyannote/audio/models/embedding/wespeaker/__init__.py index e75779dda..be51196c1 100644 --- a/pyannote/audio/models/embedding/wespeaker/__init__.py +++ b/pyannote/audio/models/embedding/wespeaker/__init__.py @@ -25,6 +25,7 @@ from typing import Optional import torch +import torch.nn.functional as F import torchaudio.compliance.kaldi as kaldi from pyannote.audio.core.model import Model @@ -39,16 +40,33 @@ class BaseWeSpeakerResNet(Model): + """Base class for WeSpeaker's ResNet models + + Parameters + ---------- + fbank_centering_span : float, optional + Span of the fbank centering window (in seconds). + Defaults (None) to use whole input. + + See also + -------- + torchaudio.compliance.kaldi.fbank + + """ + def __init__( self, sample_rate: int = 16000, num_channels: int = 1, num_mel_bins: int = 80, - frame_length: int = 25, - frame_shift: int = 10, + frame_length: float = 25.0, # in milliseconds + frame_shift: float = 10.0, # in milliseconds + round_to_power_of_two: bool = True, + snip_edges: bool = True, dither: float = 0.0, window_type: str = "hamming", use_energy: bool = False, + fbank_centering_span: Optional[float] = None, task: Optional[Task] = None, ): super().__init__(sample_rate=sample_rate, num_channels=num_channels, task=task) @@ -60,21 +78,38 @@ def __init__( "frame_length", "frame_shift", "dither", + "round_to_power_of_two", + "snip_edges", "window_type", "use_energy", + "fbank_centering_span", ) self._fbank = partial( kaldi.fbank, num_mel_bins=self.hparams.num_mel_bins, frame_length=self.hparams.frame_length, + round_to_power_of_two=self.hparams.round_to_power_of_two, frame_shift=self.hparams.frame_shift, + snip_edges=self.hparams.snip_edges, dither=self.hparams.dither, sample_frequency=self.hparams.sample_rate, window_type=self.hparams.window_type, use_energy=self.hparams.use_energy, ) + @property + def fbank_only(self) -> bool: + """Whether to only extract fbank features""" + return getattr(self, "_fbank_only", False) + + @fbank_only.setter + def fbank_only(self, value: bool): + if hasattr(self, "receptive_field"): + del self.receptive_field + + self._fbank_only = value + def compute_fbank(self, waveforms: torch.Tensor) -> torch.Tensor: """Extract fbank features @@ -85,6 +120,7 @@ def compute_fbank(self, waveforms: torch.Tensor) -> torch.Tensor: Returns ------- fbank : (batch_size, num_frames, num_mel_bins) + fbank features Source: https://github.com/wenet-e2e/wespeaker/blob/45941e7cba2c3ea99e232d02bedf617fc71b0dad/wespeaker/bin/infer_onnx.py#L30C1-L50 """ @@ -98,11 +134,37 @@ def compute_fbank(self, waveforms: torch.Tensor) -> torch.Tensor: features = torch.vmap(self._fbank)(waveforms.to(fft_device)).to(device) - return features - torch.mean(features, dim=1, keepdim=True) + # center features with global average + if self.hparams.fbank_centering_span is None: + return features - torch.mean(features, dim=1, keepdim=True) + + # center features with running average + window_size = int(self.hparams.sample_rate * self.hparams.frame_length * 0.001) + step_size = int(self.hparams.sample_rate * self.hparams.frame_shift * 0.001) + kernel_size = conv1d_num_frames( + num_samples=int( + self.hparams.fbank_centering_span * self.hparams.sample_rate + ), + kernel_size=window_size, + stride=step_size, + padding=0, + dilation=1, + ) + return features - F.avg_pool1d( + features.transpose(1, 2), + kernel_size=2 * (kernel_size // 2) + 1, + stride=1, + padding=kernel_size // 2, + count_include_pad=False, + ).transpose(1, 2) @property def dimension(self) -> int: """Dimension of output""" + + if self.fbank_only: + return self.hparams.num_mel_bins + return self.resnet.embed_dim @lru_cache @@ -122,6 +184,8 @@ def num_frames(self, num_samples: int) -> int: window_size = int(self.hparams.sample_rate * self.hparams.frame_length * 0.001) step_size = int(self.hparams.sample_rate * self.hparams.frame_shift * 0.001) + # TODO: take round_to_power_of_two and snip_edges into account + num_frames = conv1d_num_frames( num_samples=num_samples, kernel_size=window_size, @@ -129,6 +193,10 @@ def num_frames(self, num_samples: int) -> int: padding=0, dilation=1, ) + + if self.fbank_only: + return num_frames + return self.resnet.num_frames(num_frames) def receptive_field_size(self, num_frames: int = 1) -> int: @@ -144,8 +212,13 @@ def receptive_field_size(self, num_frames: int = 1) -> int: receptive_field_size : int Receptive field size. """ + receptive_field_size = num_frames - receptive_field_size = self.resnet.receptive_field_size(receptive_field_size) + + if not self.fbank_only: + receptive_field_size = self.resnet.receptive_field_size( + receptive_field_size + ) window_size = int(self.hparams.sample_rate * self.hparams.frame_length * 0.001) step_size = int(self.hparams.sample_rate * self.hparams.frame_shift * 0.001) @@ -154,6 +227,7 @@ def receptive_field_size(self, num_frames: int = 1) -> int: num_frames=receptive_field_size, kernel_size=window_size, stride=step_size, + padding=0, dilation=1, ) @@ -171,9 +245,11 @@ def receptive_field_center(self, frame: int = 0) -> int: Index of receptive field center. """ receptive_field_center = frame - receptive_field_center = self.resnet.receptive_field_center( - frame=receptive_field_center - ) + + if not self.fbank_only: + receptive_field_center = self.resnet.receptive_field_center( + frame=receptive_field_center + ) window_size = int(self.hparams.sample_rate * self.hparams.frame_length * 0.001) step_size = int(self.hparams.sample_rate * self.hparams.frame_shift * 0.001) @@ -188,14 +264,79 @@ def receptive_field_center(self, frame: int = 0) -> int: def forward( self, waveforms: torch.Tensor, weights: Optional[torch.Tensor] = None ) -> torch.Tensor: + """Extract speaker embeddings + + Parameters + ---------- + waveforms : torch.Tensor + Batch of waveforms with shape (batch, channel, sample) + weights : (batch, frames) or (batch, speakers, frames) torch.Tensor, optional + Batch of weights passed to statistics pooling layer. + + Returns + ------- + embeddings : (batch, dimension) or (batch, speakers, dimension) torch.Tensor + Batch of embeddings. + """ + + fbank = self.compute_fbank(waveforms) + if self.fbank_only: + return fbank + + return self.resnet(fbank, weights=weights)[1] + + def forward_frames(self, waveforms: torch.Tensor) -> torch.Tensor: + """Extract frame-wise embeddings + + Parameters + ---------- + waveforms : torch.Tensor + Batch of waveforms with shape (batch, channel, sample) + + Returns + ------- + embeddings : (batch, ..., embedding_frames) torch.Tensor + Batch of frame-wise embeddings. + """ + fbank = self.compute_fbank(waveforms) + return self.resnet.forward_frames(fbank) + + def forward_embedding( + self, frames: torch.Tensor, weights: torch.Tensor = None + ) -> torch.Tensor: + """Extract speaker embeddings from frame-wise embeddings + + Parameters + ---------- + frames : torch.Tensor + Batch of frames with shape (batch, ..., embedding_frames). + weights : (batch, frames) or (batch, speakers, frames) torch.Tensor, optional + Batch of weights passed to statistics pooling layer. + + Returns + ------- + embeddings : (batch, dimension) or (batch, speakers, dimension) torch.Tensor + Batch of embeddings. + """ + return self.resnet.forward_embedding(frames, weights=weights)[1] + + def forward( + self, waveforms: torch.Tensor, weights: Optional[torch.Tensor] = None + ) -> torch.Tensor: + """Extract speaker embeddings Parameters ---------- waveforms : torch.Tensor Batch of waveforms with shape (batch, channel, sample) - weights : torch.Tensor, optional - Batch of weights with shape (batch, frame). + weights : (batch, frames) or (batch, speakers, frames) torch.Tensor, optional + Batch of weights passed to statistics pooling layer. + + Returns + ------- + embeddings : (batch, dimension) or (batch, speakers, dimension) torch.Tensor + Batch of embeddings. """ fbank = self.compute_fbank(waveforms) diff --git a/pyannote/audio/models/embedding/wespeaker/resnet.py b/pyannote/audio/models/embedding/wespeaker/resnet.py index b64dd386d..4c9d5a5f0 100644 --- a/pyannote/audio/models/embedding/wespeaker/resnet.py +++ b/pyannote/audio/models/embedding/wespeaker/resnet.py @@ -124,6 +124,7 @@ def receptive_field_size(self, num_frames: int = 1) -> int: num_frames, kernel_size=[3, 3], stride=[self.stride, 1], + padding=[1, 1], dilation=[1, 1], ) @@ -189,6 +190,7 @@ def receptive_field_size(self, num_frames: int = 1) -> int: num_frames, kernel_size=[1, 3, 1], stride=[1, self.stride, 1], + padding=[0, 1, 0], dilation=[1, 1, 1], ) @@ -305,6 +307,7 @@ def receptive_field_size(self, num_frames: int = 1) -> int: num_frames=receptive_field_size, kernel_size=3, stride=1, + padding=1, dilation=1, ) @@ -341,12 +344,64 @@ def receptive_field_center(self, frame: int = 0) -> int: return receptive_field_center - def forward(self, x: torch.Tensor, weights: Optional[torch.Tensor] = None): + def forward_frames(self, fbank: torch.Tensor) -> torch.Tensor: + """Extract frame-wise embeddings + + Parameters + ---------- + fbanks : (batch, frames, features) torch.Tensor + Batch of fbank features + + Returns + ------- + embeddings : (batch, ..., embedding_frames) torch.Tensor + Batch of frame-wise embeddings. + + """ + fbank = fbank.permute(0, 2, 1) # (B,T,F) => (B,F,T) + fbank = fbank.unsqueeze_(1) + out = F.relu(self.bn1(self.conv1(fbank))) + out = self.layer1(out) + out = self.layer2(out) + out = self.layer3(out) + out = self.layer4(out) + return out + + def forward_embedding( + self, frames: torch.Tensor, weights: torch.Tensor = None + ) -> torch.Tensor: + """Extract speaker embeddings + + Parameters + ---------- + frames : torch.Tensor + Batch of frames with shape (batch, ..., embedding_frames). + weights : (batch, frames) or (batch, speakers, frames) torch.Tensor, optional + Batch of weights passed to statistics pooling layer. + + Returns + ------- + embeddings : (batch, dimension) or (batch, speakers, dimension) torch.Tensor + Batch of embeddings. """ + stats = self.pool(frames, weights=weights) + + embed_a = self.seg_1(stats) + if self.two_emb_layer: + out = F.relu(embed_a) + out = self.seg_bn_1(out) + embed_b = self.seg_2(out) + return embed_a, embed_b + else: + return torch.tensor(0.0), embed_a + + def forward(self, fbank: torch.Tensor, weights: Optional[torch.Tensor] = None): + """Extract speaker embeddings + Parameters ---------- - x : (batch, frames, features) torch.Tensor + fbank : (batch, frames, features) torch.Tensor Batch of features weights : (batch, frames) torch.Tensor, optional Batch of weights @@ -355,10 +410,9 @@ def forward(self, x: torch.Tensor, weights: Optional[torch.Tensor] = None): ------- embedding : (batch, embedding_dim) torch.Tensor """ - x = x.permute(0, 2, 1) # (B,T,F) => (B,F,T) - - x = x.unsqueeze_(1) - out = F.relu(self.bn1(self.conv1(x))) + fbank = fbank.permute(0, 2, 1) # (B,T,F) => (B,F,T) + fbank = fbank.unsqueeze_(1) + out = F.relu(self.bn1(self.conv1(fbank))) out = self.layer1(out) out = self.layer2(out) out = self.layer3(out) diff --git a/pyannote/audio/models/embedding/xvector.py b/pyannote/audio/models/embedding/xvector.py index 00916fbd0..3161876e3 100644 --- a/pyannote/audio/models/embedding/xvector.py +++ b/pyannote/audio/models/embedding/xvector.py @@ -33,9 +33,6 @@ from pyannote.audio.models.blocks.sincnet import SincNet from pyannote.audio.utils.params import merge_dict from pyannote.audio.utils.receptive_field import ( - conv1d_num_frames, - conv1d_receptive_field_center, - conv1d_receptive_field_size, multi_conv_num_frames, multi_conv_receptive_field_center, multi_conv_receptive_field_size, @@ -115,13 +112,10 @@ def num_frames(self, num_samples: int) -> int: n_fft = self.mfcc.MelSpectrogram.spectrogram.n_fft center = self.mfcc.MelSpectrogram.spectrogram.center - num_frames = conv1d_num_frames( - num_samples, - kernel_size=n_fft, - stride=hop_length, - dilation=1, - padding=n_fft // 2 if center else 0, - ) + if center: + num_frames = 1 + num_samples // hop_length + else: + num_frames = 1 + (num_samples - n_fft) // hop_length return multi_conv_num_frames( num_frames, @@ -155,13 +149,7 @@ def receptive_field_size(self, num_frames: int = 1) -> int: hop_length = self.mfcc.MelSpectrogram.spectrogram.hop_length n_fft = self.mfcc.MelSpectrogram.spectrogram.n_fft - - return conv1d_receptive_field_size( - num_frames=receptive_field_size, - kernel_size=n_fft, - stride=hop_length, - dilation=1, - ) + return n_fft + (receptive_field_size - 1) * hop_length def receptive_field_center(self, frame: int = 0) -> int: """Compute center of receptive field @@ -189,13 +177,10 @@ def receptive_field_center(self, frame: int = 0) -> int: n_fft = self.mfcc.MelSpectrogram.spectrogram.n_fft center = self.mfcc.MelSpectrogram.spectrogram.center - return conv1d_receptive_field_center( - frame=receptive_field_center, - kernel_size=n_fft, - stride=hop_length, - padding=n_fft // 2 if center else 0, - dilation=1, - ) + if center: + return receptive_field_center * hop_length + else: + return receptive_field_center * hop_length + n_fft // 2 def forward( self, waveforms: torch.Tensor, weights: Optional[torch.Tensor] = None diff --git a/pyannote/audio/models/segmentation/SSeRiouSS.py b/pyannote/audio/models/segmentation/SSeRiouSS.py index ef550dfe1..b96464ab3 100644 --- a/pyannote/audio/models/segmentation/SSeRiouSS.py +++ b/pyannote/audio/models/segmentation/SSeRiouSS.py @@ -149,9 +149,12 @@ def __init__( self.lstm = nn.ModuleList( [ nn.LSTM( - wav2vec_dim - if i == 0 - else lstm["hidden_size"] * (2 if lstm["bidirectional"] else 1), + ( + wav2vec_dim + if i == 0 + else lstm["hidden_size"] + * (2 if lstm["bidirectional"] else 1) + ), **one_layer_lstm, ) for i in range(num_layers) @@ -246,6 +249,7 @@ def receptive_field_size(self, num_frames: int = 1) -> int: num_frames=receptive_field_size, kernel_size=conv_layer.kernel_size, stride=conv_layer.stride, + padding=conv_layer.conv.padding[0], dilation=conv_layer.conv.dilation[0], ) return receptive_field_size diff --git a/pyannote/audio/models/segmentation/debug.py b/pyannote/audio/models/segmentation/debug.py index 93c205b3d..ccac612a9 100644 --- a/pyannote/audio/models/segmentation/debug.py +++ b/pyannote/audio/models/segmentation/debug.py @@ -31,11 +31,6 @@ from pyannote.audio.core.model import Model from pyannote.audio.core.task import Task -from pyannote.audio.utils.receptive_field import ( - conv1d_num_frames, - conv1d_receptive_field_center, - conv1d_receptive_field_size, -) class SimpleSegmentationModel(Model): @@ -87,13 +82,10 @@ def num_frames(self, num_samples: int) -> int: n_fft = self.mfcc.MelSpectrogram.spectrogram.n_fft center = self.mfcc.MelSpectrogram.spectrogram.center - return conv1d_num_frames( - num_samples=num_samples, - kernel_size=n_fft, - stride=hop_length, - padding=n_fft // 2 if center else 0, - dilation=1, - ) + if center: + return 1 + num_samples // hop_length + else: + return 1 + (num_samples - n_fft) // hop_length def receptive_field_size(self, num_frames: int = 1) -> int: """Compute size of receptive field @@ -111,10 +103,7 @@ def receptive_field_size(self, num_frames: int = 1) -> int: hop_length = self.mfcc.MelSpectrogram.spectrogram.hop_length n_fft = self.mfcc.MelSpectrogram.spectrogram.n_fft - - return conv1d_receptive_field_size( - num_frames, kernel_size=n_fft, stride=hop_length, dilation=1 - ) + return n_fft + (num_frames - 1) * hop_length def receptive_field_center(self, frame: int = 0) -> int: """Compute center of receptive field @@ -134,13 +123,10 @@ def receptive_field_center(self, frame: int = 0) -> int: n_fft = self.mfcc.MelSpectrogram.spectrogram.n_fft center = self.mfcc.MelSpectrogram.spectrogram.center - return conv1d_receptive_field_center( - frame=frame, - kernel_size=n_fft, - stride=hop_length, - padding=n_fft // 2 if center else 0, - dilation=1, - ) + if center: + return frame * hop_length + else: + return frame * hop_length + n_fft // 2 @property def dimension(self) -> int: diff --git a/pyannote/audio/pipelines/utils/hook.py b/pyannote/audio/pipelines/utils/hook.py index 2a675d1c9..db6972e2e 100644 --- a/pyannote/audio/pipelines/utils/hook.py +++ b/pyannote/audio/pipelines/utils/hook.py @@ -24,6 +24,7 @@ from copy import deepcopy from typing import Any, Mapping, Optional, Text +import torch from rich.progress import ( BarColumn, Progress, @@ -75,6 +76,9 @@ def __call__( ): return + if isinstance(step_artifact, torch.Tensor): + step_artifact = step_artifact.numpy(force=True) + file.setdefault(self.file_key, dict())[step_name] = deepcopy(step_artifact) diff --git a/pyannote/audio/utils/powerset.py b/pyannote/audio/utils/powerset.py index 6a0716df9..23f921569 100644 --- a/pyannote/audio/utils/powerset.py +++ b/pyannote/audio/utils/powerset.py @@ -109,7 +109,7 @@ def to_multilabel(self, powerset: torch.Tensor, soft: bool = False) -> torch.Ten Soft predictions in "powerset" space. soft : bool, optional Return soft multi-label predictions. Defaults to False (i.e. hard predictions) - Assumes that `powerset` are "logits" (not "probabilities"). + Assumes that `powerset` are "log probabilities". Returns ------- diff --git a/pyannote/audio/utils/receptive_field.py b/pyannote/audio/utils/receptive_field.py index 0e484e4ad..420a62de0 100644 --- a/pyannote/audio/utils/receptive_field.py +++ b/pyannote/audio/utils/receptive_field.py @@ -69,7 +69,9 @@ def multi_conv_num_frames( return num_frames -def conv1d_receptive_field_size(num_frames=1, kernel_size=5, stride=1, dilation=1): +def conv1d_receptive_field_size( + num_frames=1, kernel_size=5, stride=1, padding=0, dilation=1 +): """Compute size of receptive field Parameters @@ -80,6 +82,8 @@ def conv1d_receptive_field_size(num_frames=1, kernel_size=5, stride=1, dilation= Kernel size stride : int Stride + padding : int + Padding dilation : int Dilation @@ -90,7 +94,7 @@ def conv1d_receptive_field_size(num_frames=1, kernel_size=5, stride=1, dilation= """ effective_kernel_size = 1 + (kernel_size - 1) * dilation - return effective_kernel_size + (num_frames - 1) * stride + return effective_kernel_size + (num_frames - 1) * stride - 2 * padding def multi_conv_receptive_field_size( @@ -102,11 +106,12 @@ def multi_conv_receptive_field_size( ) -> int: receptive_field_size = num_frames - for k, s, d in reversed(list(zip(kernel_size, stride, dilation))): + for k, s, p, d in reversed(list(zip(kernel_size, stride, padding, dilation))): receptive_field_size = conv1d_receptive_field_size( num_frames=receptive_field_size, kernel_size=k, stride=s, + padding=p, dilation=d, ) return receptive_field_size diff --git a/version.txt b/version.txt index 94ff29cc4..944880fa1 100644 --- a/version.txt +++ b/version.txt @@ -1 +1 @@ -3.1.1 +3.2.0