Skip to content

Commit

Permalink
fix: fix receptive field of MFCC-based models
Browse files: browse the repository at this point in the history
  • Loading branch information
hbredin committed May 8, 2024
1 parent fb6bc36 commit 9ab70a0
Showing 1 changed file with 9 additions and 26 deletions.
35 changes: 9 additions & 26 deletions pyannote/audio/models/embedding/xvector.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -33,9 +33,6 @@
from pyannote.audio.models.blocks.sincnet import SincNet
from pyannote.audio.utils.params import merge_dict
from pyannote.audio.utils.receptive_field import (
conv1d_num_frames,
conv1d_receptive_field_center,
conv1d_receptive_field_size,
multi_conv_num_frames,
multi_conv_receptive_field_center,
multi_conv_receptive_field_size,
Expand Down Expand Up @@ -115,13 +112,10 @@ def num_frames(self, num_samples: int) -> int:
n_fft = self.mfcc.MelSpectrogram.spectrogram.n_fft
center = self.mfcc.MelSpectrogram.spectrogram.center

num_frames = conv1d_num_frames(
num_samples,
kernel_size=n_fft,
stride=hop_length,
dilation=1,
padding=n_fft // 2 if center else 0,
)
if center:
num_frames = 1 + num_samples // hop_length
else:
num_frames = 1 + (num_samples - n_fft) // hop_length

return multi_conv_num_frames(
num_frames,
Expand Down Expand Up @@ -155,15 +149,7 @@ def receptive_field_size(self, num_frames: int = 1) -> int:

hop_length = self.mfcc.MelSpectrogram.spectrogram.hop_length
n_fft = self.mfcc.MelSpectrogram.spectrogram.n_fft
center = self.mfcc.MelSpectrogram.spectrogram.center

return conv1d_receptive_field_size(
num_frames=receptive_field_size,
kernel_size=n_fft,
stride=hop_length,
padding=n_fft // 2 if center else 0,
dilation=1,
)
return n_fft + (receptive_field_size - 1) * hop_length

def receptive_field_center(self, frame: int = 0) -> int:
"""Compute center of receptive field
Expand Down Expand Up @@ -191,13 +177,10 @@ def receptive_field_center(self, frame: int = 0) -> int:
n_fft = self.mfcc.MelSpectrogram.spectrogram.n_fft
center = self.mfcc.MelSpectrogram.spectrogram.center

return conv1d_receptive_field_center(
frame=receptive_field_center,
kernel_size=n_fft,
stride=hop_length,
padding=n_fft // 2 if center else 0,
dilation=1,
)
if center:
return receptive_field_center * hop_length
else:
return receptive_field_center * hop_length + n_fft // 2

def forward(
self, waveforms: torch.Tensor, weights: Optional[torch.Tensor] = None
Expand Down

0 comments on commit 9ab70a0

Please sign in to comment.