update joint task training_step
... and fix some bugs
clement-pages committed May 15, 2024
1 parent 106bfc5 commit d3326b1
Showing 1 changed file with 50 additions and 55 deletions.
pyannote/audio/tasks/joint_task/speaker_diarization_and_embedding.py (105 changes: 50 additions & 55 deletions)
@@ -23,6 +23,7 @@
import itertools
import math
import random
import warnings
from typing import Dict, Literal, Optional, Sequence, Union

import numpy as np
@@ -139,6 +140,23 @@ def setup(self, stage="fit"):
]
)

# if there is no file dedicated to the embedding task
if self.alpha != 1.0 and len(embedding_classes) == 0:
self.num_dia_samples = self.batch_size
self.alpha = 1.0
warnings.warn(
"No class found for the speaker embedding task. Model will be trained on the speaker diarization task only."
)

if self.alpha != 0.0 and np.sum(global_scope_mask) == len(
self.prepared_data["annotations-segments"]
):
self.num_dia_samples = 0
self.alpha = 0.0
warnings.warn(
"No segment found for the speaker diarization task. Model will be trained on the speaker embedding task only."
)

speaker_diarization = Specifications(
duration=self.duration,
resolution=Resolution.FRAME,
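Taken together, the two guards added to setup() let the joint task degrade gracefully to single-task training when one subtask has no usable data. Below is a minimal sketch of that fallback logic; resolve_subtask_mix, embedding_classes and global_scope_mask are hypothetical stand-ins for the task's internal state, and the second check assumes the mask covers every prepared segment:

import warnings
import numpy as np

def resolve_subtask_mix(alpha, batch_size, num_dia_samples,
                        embedding_classes, global_scope_mask):
    # no class usable for the embedding subtask -> diarization only
    if alpha != 1.0 and len(embedding_classes) == 0:
        warnings.warn("No class found for the speaker embedding task. "
                      "Training on speaker diarization only.")
        return 1.0, batch_size
    # every annotated segment is global-scope -> embedding only
    if alpha != 0.0 and int(np.sum(global_scope_mask)) == len(global_scope_mask):
        warnings.warn("No segment found for the speaker diarization task. "
                      "Training on speaker embedding only.")
        return 0.0, 0
    return alpha, num_dia_samples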
@@ -363,33 +381,24 @@ def train__iter__helper(self, rng: random.Random, **filters):
# get the subset of embedding database files from training files
embedding_files_ids = file_ids[np.isin(file_ids, self.embedding_files_id)]

annotated_duration = self.prepared_data["audio-annotated"][file_ids]
# set duration of files for the embedding part to zero, in order to not
# draw them for the diarization part
annotated_duration[embedding_files_ids] = 0.0
if self.num_dia_samples > 0:
annotated_duration = self.prepared_data["audio-annotated"][file_ids]
# set duration of files for the embedding part to zero, in order to not
# draw them for the diarization part
annotated_duration[embedding_files_ids] = 0.0

# test if there is at least one file for the diarization subtask
# to prevent probabilities from summing to zero
if np.any(annotated_duration != 0.0):
cum_prob_annotated_duration = np.cumsum(
annotated_duration / np.sum(annotated_duration)
)
else:
# There are only files for the embedding subtask, so only train on
# this task
self.num_dia_samples = 0.0
self.alpha = 0.0

duration = self.duration
batch_size = self.batch_size

# use original order for the first run of the shuffled classes list:
shuffled_embedding_classes = list(
self.specifications[Subtasks.index("embedding")].classes
)
# use original order for the first run of the shuffled classes list:
emb_task_classes = self.specifications[Subtasks.index("embedding")].classes[:]

sample_idx = 0
embedding_class_idx = 0

while True:
if sample_idx < self.num_dia_samples:
file_id, start_time = self.draw_diarization_chunk(
@@ -398,15 +407,16 @@
else:
# shuffle embedding classes list and go through this shuffled list
# to make sure to see all the speakers during training
if embedding_class_idx == len(shuffled_embedding_classes):
rng.shuffle(shuffled_embedding_classes)
if embedding_class_idx == len(emb_task_classes):
rng.shuffle(emb_task_classes)
embedding_class_idx = 0
klass = shuffled_embedding_classes[embedding_class_idx]
klass = emb_task_classes[embedding_class_idx]
embedding_class_idx += 1
file_id, start_time = self.draw_embedding_chunk(klass, duration)

sample = self.prepare_chunk(file_id, start_time, duration)
sample_idx = (sample_idx + 1) % self.batch_size
sample_idx = (sample_idx + 1) % batch_size

yield sample

def train__iter__(self):
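For context, the sampling policy this loop implements: the first num_dia_samples positions of each batch are drawn for the diarization subtask, and the remaining positions cycle through a reshuffled copy of the embedding class list so that every speaker is visited before any repeats. A self-contained sketch of that policy, where draw_dia_chunk and draw_emb_chunk are hypothetical stand-ins for the task's draw_diarization_chunk and draw_embedding_chunk:

import random

def interleaved_chunks(classes, num_dia_samples, batch_size,
                       draw_dia_chunk, draw_emb_chunk, seed=42):
    rng = random.Random(seed)
    emb_task_classes = classes[:]  # original order for the first pass
    class_idx = 0
    sample_idx = 0
    while True:
        if sample_idx < num_dia_samples:
            chunk = draw_dia_chunk()
        else:
            # reshuffle once the whole class list has been consumed
            if class_idx == len(emb_task_classes):
                rng.shuffle(emb_task_classes)
                class_idx = 0
            chunk = draw_emb_chunk(emb_task_classes[class_idx])
            class_idx += 1
        sample_idx = (sample_idx + 1) % batch_size
        yield chunk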
@@ -599,15 +609,12 @@ def segmentation_loss(

return seg_loss

def compute_diarization_loss(self, dia_chunks, dia_prediction, permutated_target):
def compute_diarization_loss(self, prediction, permutated_target):
"""Compute loss for the speaker diarization subtask
Parameters
----------
dia_chunks : torch.Tensor
tensor specifying the chunks assigned to the speaker diarization
task in the current batch. Shape of (batch_size,)
dia_prediction : torch.Tensor
prediction : torch.Tensor
speaker diarization output predicted by the model for the current batch.
Shape of (batch_size, num_spk, num_frames)
permutated_target: torch.Tensor
@@ -619,12 +626,8 @@ def compute_diarization_loss(self, dia_chunks, dia_prediction, permutated_target
Permutation-invariant diarization loss
"""

# Get chunks corresponding to the diarization subtask
chunks_prediction = dia_prediction[dia_chunks]
# Get the permutated reference corresponding to diarization subtask
permutated_target_dia = permutated_target[dia_chunks]
# Compute segmentation loss
dia_loss = self.segmentation_loss(chunks_prediction, permutated_target_dia)
dia_loss = self.segmentation_loss(prediction, permutated_target)
self.model.log(
"loss/train/dia",
dia_loss,
@@ -635,14 +638,11 @@
)
return dia_loss
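With the batch ordered by subtask, compute_diarization_loss now receives tensors already restricted to the diarization samples, so the former dia_chunks mask is gone. A toy illustration of the simplified contract; the real task trains on a powerset multi-class encoding, and plain multilabel binary cross-entropy is used here only to keep the sketch self-contained:

import torch
import torch.nn.functional as F

def toy_diarization_loss(prediction, permutated_target):
    # prediction: (num_dia_samples, num_frames, num_spk) logits
    # permutated_target: same shape, values in {0, 1}, already permutated
    # to best align reference speakers with predicted ones
    return F.binary_cross_entropy_with_logits(prediction,
                                              permutated_target.float())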

def compute_embedding_loss(self, emb_chunks, emb_prediction, target_emb):
def compute_embedding_loss(self, emb_prediction, target_emb):
"""Compute loss for the speaker embeddings extraction subtask
Parameters
----------
emb_chunks : torch.Tensor
tensor specifying the chunks assigned to the speaker embeddings extraction
task in the current batch. Shape of (batch_size,)
emb_prediction : torch.Tensor
speaker embeddings predicted by the model for the current batch.
Shape of (batch_size * num_spk, embedding_dim)
@@ -656,9 +656,9 @@ def compute_embedding_loss(self, emb_chunks, emb_prediction, target_emb):
"""

# Get speaker representations from the embedding subtask
embeddings = rearrange(emb_prediction[emb_chunks], "b s e -> (b s) e")
embeddings = rearrange(emb_prediction, "b s e -> (b s) e")
# Get corresponding target label
targets = rearrange(target_emb[emb_chunks], "b s -> (b s)")
targets = rearrange(target_emb, "b s -> (b s)")
# compute loss only on global-scope speaker embeddings
valid_emb = targets != -1
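The flattening above boils down to: merge the batch and speaker axes, then drop every position labeled -1 (speakers outside the global scope). A hedged sketch, assuming a generic classification head loss_fn(embeddings, labels) in place of the task's actual embedding loss:

import torch
from einops import rearrange

def toy_embedding_loss(emb_prediction, target_emb, loss_fn):
    # emb_prediction: (batch_size, num_spk, embedding_dim)
    # target_emb: (batch_size, num_spk) integer labels, -1 = ignore
    embeddings = rearrange(emb_prediction, "b s e -> (b s) e")
    targets = rearrange(target_emb, "b s -> (b s)")
    valid = targets != -1  # keep only global-scope speakers
    return loss_fn(embeddings[valid], targets[valid])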

@@ -707,6 +707,8 @@ def training_step(self, batch, batch_idx: int):
# drop samples that contain too many speakers
num_speakers: torch.Tensor = torch.sum(torch.any(target_dia, dim=1), dim=1)
keep: torch.Tensor = num_speakers <= self.max_speakers_per_chunk

num_remaining_dia_samples = torch.sum(keep[: self.num_dia_samples])
target_dia = target_dia[keep]
target_emb = target_emb[keep]
waveform = waveform[keep]
@@ -726,36 +728,29 @@
torch.arange(target_emb.shape[0]).unsqueeze(1), permut_map
]

# filter out the speakers in the reference that were not found by the diarization
# part of the model, to not compute the embedding loss on these speakers:
# active_spk_mask = torch.any(rearrange(dia_multilabel, "b f s -> b s f"), dim=2)
# (batch_size, num_spk)
# emb_prediction = emb_prediction[active_spk_mask]
# (num_active_spk_found_in_all_the_chunks, emb_size)
# permutated_target_emb = permutated_target_emb[permutated_target_emb != 1]
# (num_active_spk_found,)

permutated_target_powerset = self.model.powerset.to_powerset(
permutated_target_dia.float()
)
# get embedding and diarization chunk positions in the current batch
emb_chunks = batch["meta"]["scope"] == 2 # global scope for embedding task
dia_chunks = (
batch["meta"]["scope"] < 2
) # file and database scope for diarization task

dia_prediction = dia_prediction[:num_remaining_dia_samples]
permutated_target_powerset = permutated_target_powerset[
:num_remaining_dia_samples
]

dia_loss = torch.tensor(0)
# if the batch contains diarization subtask chunks, then compute the diarization loss on these chunks:
if dia_chunks.any():
if self.alpha != 0.0 and torch.any(keep[: self.num_dia_samples]):
dia_loss = self.compute_diarization_loss(
dia_chunks, dia_prediction, permutated_target_powerset
dia_prediction, permutated_target_powerset
)

emb_loss = torch.tensor(0)
# if the batch contains embedding subtask chunks, then compute the embedding loss on these chunks:
if emb_chunks.any():
if self.alpha != 1.0 and torch.any(keep[self.num_dia_samples :]):
emb_prediction = emb_prediction[num_remaining_dia_samples:]
permutated_target_emb = permutated_target_emb[num_remaining_dia_samples:]
emb_loss = self.compute_embedding_loss(
emb_chunks, emb_prediction, permutated_target_emb
emb_prediction, permutated_target_emb
)

loss = alpha * dia_loss + (1 - alpha) * emb_loss
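Because train__iter__helper packs diarization samples at the front of every batch, the old boolean scope masks reduce to a single slice boundary. A condensed sketch of the new control flow, with dia_loss_fn and emb_loss_fn standing in for the two compute_*_loss methods and all tensors assumed to be already filtered by keep:

import torch

def toy_training_step(dia_prediction, emb_prediction, dia_target, emb_target,
                      keep, num_dia_samples, alpha, dia_loss_fn, emb_loss_fn):
    # keep: boolean mask over the original batch; predictions and targets
    # are assumed to have already been filtered with it
    num_remaining = int(torch.sum(keep[:num_dia_samples]))
    dia_loss = torch.tensor(0.0)
    if alpha != 0.0 and num_remaining > 0:
        dia_loss = dia_loss_fn(dia_prediction[:num_remaining],
                               dia_target[:num_remaining])
    emb_loss = torch.tensor(0.0)
    if alpha != 1.0 and bool(torch.any(keep[num_dia_samples:])):
        emb_loss = emb_loss_fn(emb_prediction[num_remaining:],
                               emb_target[num_remaining:])
    return alpha * dia_loss + (1 - alpha) * emb_loss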
