Break taxonomy_model_tf into a separate file, and add classifier extraction method.

PiperOrigin-RevId: 681139037
sdenton4 authored and copybara-github committed Oct 1, 2024
1 parent 8816bf2 commit 0938ec6
Showing 5 changed files with 655 additions and 523 deletions.
chirp/inference/scann_search_lib.py (27 changes: 12 additions & 15 deletions)
@@ -20,9 +20,9 @@

  from absl import logging
  from chirp import audio_utils
- from chirp.inference.embed_lib import load_embedding_config
- from chirp.inference.tf_examples import get_example_parser
- from chirp.projects.zoo import models
+ from chirp.inference import embed_lib
+ from chirp.inference import tf_examples
+ from chirp.projects.zoo import taxonomy_model_tf
  from etils import epath
  from ml_collections import config_dict
  import numpy as np
@@ -36,11 +36,10 @@ class AudioSearchResult:
    """Results from SCANN search.

    Attributes:
-     index: Index for the searcher ndarray.
-     distance: The nearest neighbor distance calculated by scann searcher.
-     filename: The filename of the source audio file.
-     timestamp_offset_s: Timestamp offset in seconds for the audio file.
+     index: Index for the searcher ndarray.
+     distance: The nearest neighbor distance calculated by scann searcher.
+     filename: The filename of the source audio file.
+     timestamp_offset_s: Timestamp offset in seconds for the audio file.
    """

    index: int
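
A minimal sketch of consuming results shaped like this follows; the field types beyond index: int are inferred from the docstring above, and print_results is a hypothetical helper, not part of this diff:

    from chirp.inference import scann_search_lib

    def print_results(results: list[scann_search_lib.AudioSearchResult]) -> None:
      # Each result names the searcher row, the neighbor distance, and where
      # the matching window sits in the source audio.
      for r in results:
        print(f"{r.index}: {r.filename} @ {r.timestamp_offset_s:.1f}s "
              f"(distance={r.distance:.3f})")
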
@@ -54,7 +53,7 @@ def create_searcher(
      embeddings_glob: str,
      output_dir: str,
      num_neighbors: int = 10,
-     embedding_shape: tuple = (12, 1, 1280),
+     embedding_shape: tuple[int, ...] = (12, 1, 1280),
      distance_measure: str = "squared_l2",
      embedding_list_filename="embedding_list.txt",
      timestamps_list_filename="timestamps_list.txt",
@@ -70,7 +69,7 @@ def create_searcher(
    shape can be slightly shorter because of the remainder chunk when dividing.

    Args:
-     embedding_glob: Path the directory containing audio embeddings produced by
+     embeddings_glob: Path the directory containing audio embeddings produced by
        the embedding model that matches the embedding_shape.
      output_dir: Output directory path to save the scann artifacts.
      num_neighbors: Number of neighbors for scann search.
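
The remainder-chunk caveat in this docstring is easy to gloss over; here is a worked example with assumed file lengths (the 5 s frame width matching the default (12, 1, 1280) embedding_shape, which corresponds to a 60 s file, is our reading of the defaults above):

    # Assumed numbers for illustration only.
    window_size_s = 5.0
    full_file_frames = int(60.0 // window_size_s)   # 12 frames: fills (12, 1, 1280)
    short_file_frames = int(58.0 // window_size_s)  # 11 frames: the trailing 3 s
                                                    # remainder chunk is dropped, so
                                                    # the shape comes up short.
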
@@ -112,10 +111,10 @@ def create_searcher(
    ds = tf.data.TFRecordDataset(
        embeddings_files, num_parallel_reads=tf.data.AUTOTUNE
    )
-   parser = get_example_parser()
+   parser = tf_examples.get_example_parser()
    ds = ds.map(parser)

-   embedding_config = load_embedding_config(embeddings_glob)
+   embedding_config = embed_lib.load_embedding_config(embeddings_glob)
    hop_size_s = embedding_config.embed_fn_config.model_config.hop_size_s

    # These will be saved to output files.
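
Lifted out of the function, the updated read path amounts to the following standalone sketch; the glob value is a placeholder, and the gfile-based file listing is an assumption rather than code shown in this diff:

    import tensorflow as tf
    from chirp.inference import embed_lib
    from chirp.inference import tf_examples

    embeddings_glob = "/tmp/embeddings/embeddings-*"  # placeholder path
    embeddings_files = tf.io.gfile.glob(embeddings_glob)  # assumed listing step

    # Parse the serialized embeddings and recover the hop size they were
    # computed with, exactly as the hunk above does.
    ds = tf.data.TFRecordDataset(
        embeddings_files, num_parallel_reads=tf.data.AUTOTUNE
    )
    ds = ds.map(tf_examples.get_example_parser())
    embedding_config = embed_lib.load_embedding_config(embeddings_glob)
    hop_size_s = embedding_config.embed_fn_config.model_config.hop_size_s
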
@@ -173,7 +172,6 @@ def embed_query_audio(
      sample_rate: int = 32000,
      window_size_s: float = 5.0,
      hop_size_s: float = 5.0,
-     embedding_hidden_dims: int = 1280,
  ) -> np.ndarray:
    """Embeds the audio query through embedding the model.
@@ -183,7 +181,6 @@
      sample_rate: Sampling rate for the model.
      window_size_s: Window size of the model in seconds.
      hop_size_s: Hop size for processing longer audio files.
-     embedding_hidden_dims: Embedding model's hidden dimension size.

    Returns:
      Query audio embedding as numpy array.
@@ -197,7 +194,7 @@
        "window_size_s": window_size_s,
        "hop_size_s": hop_size_s,
    })
-   embedding_model = models.TaxonomyModelTF.from_config(config)
+   embedding_model = taxonomy_model_tf.TaxonomyModelTF.from_config(config)

    outputs = embedding_model.embed(np.array(query_audio))
chirp/inference/tests/embed_test.py (290 changes: 1 addition & 289 deletions)
@@ -33,6 +33,7 @@
  from chirp.inference.search import search
  from chirp.models import metrics
  from chirp.projects.zoo import models
+ from chirp.projects.zoo import taxonomy_model_tf
  from chirp.projects.zoo import zoo_interface
  from chirp.taxonomy import namespace
  from etils import epath
@@ -574,94 +575,6 @@ def test_handcrafted_features(self):
      # four summary statistics for each, giving a total of 80 output channels.
      self.assertSequenceEqual([5, 1, 80], outputs.embeddings.shape)

-   def test_sep_embed_wrapper(self):
-     """Check that the joint-model wrapper works as intended."""
-     separator = models.PlaceholderModel(
-         sample_rate=22050,
-         make_embeddings=False,
-         make_logits=False,
-         make_separated_audio=True,
-     )
-
-     embeddor = models.PlaceholderModel(
-         sample_rate=22050,
-         make_embeddings=True,
-         make_logits=True,
-         make_separated_audio=False,
-     )
-     fake_config = config_dict.ConfigDict()
-     sep_embed = models.SeparateEmbedModel(
-         sample_rate=22050,
-         taxonomy_model_tf_config=fake_config,
-         separator_model_tf_config=fake_config,
-         separation_model=separator,
-         embedding_model=embeddor,
-     )
-     audio = np.zeros(5 * 22050, np.float32)
-
-     outputs = sep_embed.embed(audio)
-     # The PlaceholderModel produces one embedding per second, and we have
-     # five seconds of audio, with two separated channels, plus the channel
-     # for the raw audio.
-     # Note that this checks that the sample-rate conversion between the
-     # separation model and embedding model has worked correctly.
-     self.assertSequenceEqual(
-         outputs.embeddings.shape, [5, 3, embeddor.embedding_size]
-     )
-     # The Sep+Embed model takes the max logits over the channel dimension.
-     self.assertSequenceEqual(
-         outputs.logits['label'].shape, [5, len(embeddor.class_list.classes)]
-     )
-
-   def test_pooled_embeddings(self):
-     outputs = zoo_interface.InferenceOutputs(
-         embeddings=np.zeros([10, 2, 8]), batched=False
-     )
-     batched_outputs = zoo_interface.InferenceOutputs(
-         embeddings=np.zeros([3, 10, 2, 8]), batched=True
-     )
-
-     # Check that no-op is no-op.
-     non_pooled = outputs.pooled_embeddings('', '')
-     self.assertSequenceEqual(non_pooled.shape, outputs.embeddings.shape)
-     batched_non_pooled = batched_outputs.pooled_embeddings('', '')
-     self.assertSequenceEqual(
-         batched_non_pooled.shape, batched_outputs.embeddings.shape
-     )
-
-     for pooling_method in zoo_interface.POOLING_METHODS:
-       if pooling_method == 'squeeze':
-         # The 'squeeze' pooling method throws an exception if axis size is > 1.
-         with self.assertRaises(ValueError):
-           outputs.pooled_embeddings(pooling_method, '')
-         continue
-       elif pooling_method == 'flatten':
-         # Concatenates over the target axis.
-         time_pooled = outputs.pooled_embeddings(pooling_method, '')
-         self.assertSequenceEqual(time_pooled.shape, [2, 80])
-         continue
-
-       time_pooled = outputs.pooled_embeddings(pooling_method, '')
-       self.assertSequenceEqual(time_pooled.shape, [2, 8])
-       batched_time_pooled = batched_outputs.pooled_embeddings(
-           pooling_method, ''
-       )
-       self.assertSequenceEqual(batched_time_pooled.shape, [3, 2, 8])
-
-       channel_pooled = outputs.pooled_embeddings('', pooling_method)
-       self.assertSequenceEqual(channel_pooled.shape, [10, 8])
-       batched_channel_pooled = batched_outputs.pooled_embeddings(
-           '', pooling_method
-       )
-       self.assertSequenceEqual(batched_channel_pooled.shape, [3, 10, 8])
-
-       both_pooled = outputs.pooled_embeddings(pooling_method, pooling_method)
-       self.assertSequenceEqual(both_pooled.shape, [8])
-       batched_both_pooled = batched_outputs.pooled_embeddings(
-           pooling_method, pooling_method
-       )
-       self.assertSequenceEqual(batched_both_pooled.shape, [3, 8])
-
    def test_beam_pipeline(self):
      """Check that we can write embeddings to TFRecord file."""
      test_wav_path = os.fspath(
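
The deleted test_pooled_embeddings above pins down the pooling shape contract for InferenceOutputs; a minimal numpy sketch of the same arithmetic, using mean pooling as a stand-in for the non-special pooling methods:

    import numpy as np

    # Unbatched embeddings are laid out [time, channels, features]: [10, 2, 8].
    embeddings = np.zeros([10, 2, 8])
    time_pooled = embeddings.mean(axis=0)       # [2, 8]
    channel_pooled = embeddings.mean(axis=1)    # [10, 8]
    both_pooled = embeddings.mean(axis=(0, 1))  # [8]
    # 'flatten' concatenates over the pooled axis instead of reducing it:
    flattened = np.moveaxis(embeddings, 0, 1).reshape(2, -1)  # [2, 80]
    assert flattened.shape == (2, 80)
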
@@ -700,207 +613,6 @@ def test_beam_pipeline(self):

      print(metrics)

-   @parameterized.product(
-       model_return_type=('tuple', 'dict'),
-       batchable=(True, False),
-   )
-   def test_taxonomy_model_tf(self, model_return_type, batchable):
-     class FakeModelFn:
-       output_depths = {'label': 3, 'embedding': 256}
-
-       def infer_tf(self, audio_array):
-         outputs = {
-             k: np.zeros([audio_array.shape[0], d], dtype=np.float32)
-             for k, d in self.output_depths.items()
-         }
-         if model_return_type == 'tuple':
-           # Published Perch models v1 through v4 returned a tuple, not a dict.
-           return outputs['label'], outputs['embedding']
-         return outputs
-
-     class_list = {
-         'label': namespace.ClassList('fake', ['alpha', 'beta', 'delta'])
-     }
-     wrapped_model = models.TaxonomyModelTF(
-         sample_rate=32000,
-         model_path='/dev/null',
-         window_size_s=5.0,
-         hop_size_s=5.0,
-         model=FakeModelFn(),
-         class_list=class_list,
-         batchable=batchable,
-     )
-
-     # Check that a single frame of audio is handled properly.
-     outputs = wrapped_model.embed(np.zeros([5 * 32000], dtype=np.float32))
-     self.assertFalse(outputs.batched)
-     self.assertSequenceEqual(outputs.embeddings.shape, [1, 1, 256])
-     self.assertSequenceEqual(outputs.logits['label'].shape, [1, 3])
-
-     # Check that multi-frame audio is handled properly.
-     outputs = wrapped_model.embed(np.zeros([20 * 32000], dtype=np.float32))
-     self.assertFalse(outputs.batched)
-     self.assertSequenceEqual(outputs.embeddings.shape, [4, 1, 256])
-     self.assertSequenceEqual(outputs.logits['label'].shape, [4, 3])
-
-     # Check that a batch of single frame of audio is handled properly.
-     outputs = wrapped_model.batch_embed(
-         np.zeros([10, 5 * 32000], dtype=np.float32)
-     )
-     self.assertTrue(outputs.batched)
-     self.assertSequenceEqual(outputs.embeddings.shape, [10, 1, 1, 256])
-     self.assertSequenceEqual(outputs.logits['label'].shape, [10, 1, 3])
-
-     # Check that a batch of multi-frame audio is handled properly.
-     outputs = wrapped_model.batch_embed(
-         np.zeros([2, 20 * 32000], dtype=np.float32)
-     )
-     self.assertTrue(outputs.batched)
-     self.assertSequenceEqual(outputs.embeddings.shape, [2, 4, 1, 256])
-     self.assertSequenceEqual(outputs.logits['label'].shape, [2, 4, 3])
-
-   def test_whale_model(self):
-     # prereq
-     class FakeModel(tf_keras.Model):
-       """Fake implementation of the humpback_whale SavedModel API.
-
-       The use of `tf_keras` as opposed to `tf.keras` is intentional; the models
-       this fakes were exported using "the pure-TensorFlow implementation of
-       Keras."
-       """
-
-       def __init__(self):
-         super().__init__()
-         self._sample_rate = 10000
-         self._classes = ['Mn']
-         self._embedder = tf_keras.layers.Dense(32)
-         self._classifier = tf_keras.layers.Dense(len(self._classes))
-
-       def call(self, spectrograms, training=False):
-         logits = self.logits(spectrograms)
-         return tf.nn.sigmoid(logits)
-
-       @tf.function(
-           input_signature=[tf.TensorSpec([None, None, 1], tf.dtypes.float32)]
-       )
-       def front_end(self, waveform):
-         return tf.math.abs(
-             tf.signal.stft(
-                 tf.squeeze(waveform, -1),
-                 frame_length=1024,
-                 frame_step=300,
-                 fft_length=128,
-             )[..., 1:]
-         )
-
-       @tf.function(
-           input_signature=[tf.TensorSpec([None, 128, 64], tf.dtypes.float32)]
-       )
-       def features(self, spectrogram):
-         return self._embedder(tf.math.reduce_mean(spectrogram, axis=-2))
-
-       @tf.function(
-           input_signature=[tf.TensorSpec([None, 128, 64], tf.dtypes.float32)]
-       )
-       def logits(self, spectrogram):
-         features = self.features(spectrogram)
-         return self._classifier(features)
-
-       @tf.function(
-           input_signature=[
-               tf.TensorSpec([None, None, 1], tf.dtypes.float32),
-               tf.TensorSpec([], tf.dtypes.int64),
-           ]
-       )
-       def score(self, waveform, context_step_samples):
-         spectrogram = self.front_end(waveform)
-         windows = tf.signal.frame(
-             spectrogram, frame_length=128, frame_step=128, axis=1
-         )
-         shape = tf.shape(windows)
-         batch_size = shape[0]
-         num_windows = shape[1]
-         frame_length = shape[2]
-         tf.debugging.assert_equal(frame_length, 128)
-         channels_len = shape[3]
-         logits = self.logits(
-             tf.reshape(
-                 windows, (batch_size * num_windows, frame_length, channels_len)
-             )
-         )
-         return {'score': tf.nn.sigmoid(logits)}
-
-       @tf.function(input_signature=[])
-       def metadata(self):
-         return {
-             'input_sample_rate': tf.constant(
-                 self._sample_rate, tf.dtypes.int64
-             ),
-             'context_width_samples': tf.constant(39124, tf.dtypes.int64),
-             'class_names': tf.constant(self._classes),
-         }
-
-     # setup
-     fake_model = FakeModel()
-     batch_size = 2
-     duration_seconds = 10
-     sample_rate = fake_model.metadata()['input_sample_rate']
-     waveform = np.random.randn(
-         batch_size,
-         sample_rate * duration_seconds,
-     )
-     expected_frames = int(10 / 3.9124) + 1
-     # Call the model to avoid "forward pass of the model is not defined" on
-     # save.
-     spectrograms = fake_model.front_end(waveform[:, :, np.newaxis])
-     fake_model(spectrograms[:, :128, :])
-     model_path = os.path.join(tempfile.gettempdir(), 'whale_model')
-     fake_model.save(
-         model_path,
-         signatures={
-             'score': fake_model.score,
-             'metadata': fake_model.metadata,
-             'serving_default': fake_model.score,
-             'front_end': fake_model.front_end,
-             'features': fake_model.features,
-             'logits': fake_model.logits,
-         },
-     )
-
-     with self.subTest('from_url'):
-       # invoke
-       model = models.GoogleWhaleModel.load_humpback_model(model_path)
-       outputs = model.batch_embed(waveform)
-
-       # verify
-       self.assertTrue(outputs.batched)
-       self.assertSequenceEqual(
-           outputs.embeddings.shape, [batch_size, expected_frames, 1, 32]
-       )
-       self.assertSequenceEqual(
-           outputs.logits['humpback'].shape, [batch_size, expected_frames, 1]
-       )
-
-     with self.subTest('from_config'):
-       # invoke
-       config = config_dict.ConfigDict()
-       config.model_url = model_path
-       config.sample_rate = float(sample_rate)
-       config.window_size_s = 3.9124
-       config.peak_norm = 0.02
-       model = models.GoogleWhaleModel.from_config(config)
-       # Let's check the regular embed this time.
-       outputs = model.embed(waveform[0])
-
-       # verify
-       self.assertFalse(outputs.batched)
-       self.assertSequenceEqual(
-           outputs.embeddings.shape, [expected_frames, 1, 32]
-       )
-       self.assertSequenceEqual(
-           outputs.logits['multispecies_whale'].shape, [expected_frames, 1]
-       )
-

  if __name__ == '__main__':
    absltest.main()
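
One line from the deleted whale-model test deserves unpacking: expected_frames = int(10 / 3.9124) + 1. A worked sketch of that arithmetic follows; reading the + 1 as a padded trailing window is our interpretation, not something the diff states:

    # From the fake model's metadata: 39124 context samples at 10 kHz.
    context_width_samples = 39124
    sample_rate = 10000
    window_size_s = context_width_samples / sample_rate  # 3.9124 s
    duration_seconds = 10
    # Two full 3.9124 s windows fit in 10 s of audio; the + 1 presumably
    # covers the trailing partial window.
    expected_frames = int(duration_seconds / window_size_s) + 1  # 3
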