From 5d56a11fd8340f34672ed53b5e32750feee820c8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Herv=C3=A9=20BREDIN?= Date: Tue, 7 May 2024 17:16:14 +0200 Subject: [PATCH] chore: remove use of vmap in stats-pooling layer (#1706) --- pyannote/audio/models/blocks/pooling.py | 77 +++++++++++++------------ 1 file changed, 40 insertions(+), 37 deletions(-) diff --git a/pyannote/audio/models/blocks/pooling.py b/pyannote/audio/models/blocks/pooling.py index 22d736a03..dc31bea8e 100644 --- a/pyannote/audio/models/blocks/pooling.py +++ b/pyannote/audio/models/blocks/pooling.py @@ -26,53 +26,53 @@ import torch import torch.nn as nn import torch.nn.functional as F -from einops import rearrange -class StatsPool(nn.Module): - """Statistics pooling +def _pool(sequences: torch.Tensor, weights: torch.Tensor) -> torch.Tensor: + """Helper function to compute statistics pooling - Compute temporal mean and (unbiased) standard deviation - and returns their concatenation. + Assumes that weights are already interpolated to match the number of frames + in sequences and that they encode the activation of only one speaker. - Reference - --------- - https://en.wikipedia.org/wiki/Weighted_arithmetic_mean + Parameters + ---------- + sequences : (batch, features, frames) torch.Tensor + Sequences of features. + weights : (batch, frames) torch.Tensor + (Already interpolated) weights. + Returns + ------- + output : (batch, 2 * features) torch.Tensor + Concatenation of mean and (unbiased) standard deviation. """ - def _pool(self, sequences: torch.Tensor, weights: torch.Tensor) -> torch.Tensor: - """Helper function to compute statistics pooling + weights = weights.unsqueeze(dim=1) + # (batch, 1, frames) - Assumes that weights are already interpolated to match the number of frames - in sequences and that they encode the activation of only one speaker. + v1 = weights.sum(dim=2) + 1e-8 + mean = torch.sum(sequences * weights, dim=2) / v1 - Parameters - ---------- - sequences : (batch, features, frames) torch.Tensor - Sequences of features. - weights : (batch, frames) torch.Tensor - (Already interpolated) weights. + dx2 = torch.square(sequences - mean.unsqueeze(2)) + v2 = torch.square(weights).sum(dim=2) - Returns - ------- - output : (batch, 2 * features) torch.Tensor - Concatenation of mean and (unbiased) standard deviation. - """ + var = torch.sum(dx2 * weights, dim=2) / (v1 - v2 / v1 + 1e-8) + std = torch.sqrt(var) - weights = weights.unsqueeze(dim=1) - # (batch, 1, frames) + return torch.cat([mean, std], dim=1) - v1 = weights.sum(dim=2) + 1e-8 - mean = torch.sum(sequences * weights, dim=2) / v1 - dx2 = torch.square(sequences - mean.unsqueeze(2)) - v2 = torch.square(weights).sum(dim=2) +class StatsPool(nn.Module): + """Statistics pooling - var = torch.sum(dx2 * weights, dim=2) / (v1 - v2 / v1 + 1e-8) - std = torch.sqrt(var) + Compute temporal mean and (unbiased) standard deviation + and returns their concatenation. - return torch.cat([mean, std], dim=1) + Reference + --------- + https://en.wikipedia.org/wiki/Weighted_arithmetic_mean + + """ def forward( self, sequences: torch.Tensor, weights: Optional[torch.Tensor] = None @@ -112,17 +112,20 @@ def forward( has_speaker_dimension = True # interpolate weights if needed - _, _, num_frames = sequences.shape - _, _, num_weights = weights.shape + _, _, num_frames = sequences.size() + _, num_speakers, num_weights = weights.size() if num_frames != num_weights: warnings.warn( f"Mismatch between frames ({num_frames}) and weights ({num_weights}) numbers." ) weights = F.interpolate(weights, size=num_frames, mode="nearest") - output = rearrange( - torch.vmap(self._pool, in_dims=(None, 1))(sequences, weights), - "speakers batch features -> batch speakers features", + output = torch.stack( + [ + _pool(sequences, weights[:, speaker, :]) + for speaker in range(num_speakers) + ], + dim=1, ) if not has_speaker_dimension: