From 0938ec683164c5439d971b7ed10e1394b0345893 Mon Sep 17 00:00:00 2001 From: Tom Denton Date: Tue, 1 Oct 2024 13:15:25 -0700 Subject: [PATCH] Break taxonomy_model_tf into a separate file, and add classifier extraction method. PiperOrigin-RevId: 681139037 --- chirp/inference/scann_search_lib.py | 27 +- chirp/inference/tests/embed_test.py | 290 +------------------- chirp/projects/zoo/models.py | 224 +--------------- chirp/projects/zoo/taxonomy_model_tf.py | 300 +++++++++++++++++++++ chirp/projects/zoo/zoo_test.py | 337 ++++++++++++++++++++++++ 5 files changed, 655 insertions(+), 523 deletions(-) create mode 100644 chirp/projects/zoo/taxonomy_model_tf.py create mode 100644 chirp/projects/zoo/zoo_test.py diff --git a/chirp/inference/scann_search_lib.py b/chirp/inference/scann_search_lib.py index 37f9ab0f..adcf5017 100644 --- a/chirp/inference/scann_search_lib.py +++ b/chirp/inference/scann_search_lib.py @@ -20,9 +20,9 @@ from absl import logging from chirp import audio_utils -from chirp.inference.embed_lib import load_embedding_config -from chirp.inference.tf_examples import get_example_parser -from chirp.projects.zoo import models +from chirp.inference import embed_lib +from chirp.inference import tf_examples +from chirp.projects.zoo import taxonomy_model_tf from etils import epath from ml_collections import config_dict import numpy as np @@ -36,11 +36,10 @@ class AudioSearchResult: """Results from SCANN search. Attributes: - - index: Index for the searcher ndarray. - distance: The nearest neighbor distance calculated by scann searcher. - filename: The filename of the source audio file. - timestamp_offset_s: Timestamp offset in seconds for the audio file. + index: Index for the searcher ndarray. + distance: The nearest neighbor distance calculated by scann searcher. + filename: The filename of the source audio file. + timestamp_offset_s: Timestamp offset in seconds for the audio file. """ index: int @@ -54,7 +53,7 @@ def create_searcher( embeddings_glob: str, output_dir: str, num_neighbors: int = 10, - embedding_shape: tuple = (12, 1, 1280), + embedding_shape: tuple[int, ...] = (12, 1, 1280), distance_measure: str = "squared_l2", embedding_list_filename="embedding_list.txt", timestamps_list_filename="timestamps_list.txt", @@ -70,7 +69,7 @@ def create_searcher( shape can be slightly shorter because of the remainder chunk when dividing. Args: - embedding_glob: Path the directory containing audio embeddings produced by + embeddings_glob: Path the directory containing audio embeddings produced by the embedding model that matches the embedding_shape. output_dir: Output directory path to save the scann artifacts. num_neighbors: Number of neighbors for scann search. @@ -112,10 +111,10 @@ def create_searcher( ds = tf.data.TFRecordDataset( embeddings_files, num_parallel_reads=tf.data.AUTOTUNE ) - parser = get_example_parser() + parser = tf_examples.get_example_parser() ds = ds.map(parser) - embedding_config = load_embedding_config(embeddings_glob) + embedding_config = embed_lib.load_embedding_config(embeddings_glob) hop_size_s = embedding_config.embed_fn_config.model_config.hop_size_s # These will be saved to output files. @@ -173,7 +172,6 @@ def embed_query_audio( sample_rate: int = 32000, window_size_s: float = 5.0, hop_size_s: float = 5.0, - embedding_hidden_dims: int = 1280, ) -> np.ndarray: """Embeds the audio query through embedding the model. @@ -183,7 +181,6 @@ def embed_query_audio( sample_rate: Sampling rate for the model. window_size_s: Window size of the model in seconds. hop_size_s: Hop size for processing longer audio files. - embedding_hidden_dims: Embedding model's hidden dimension size. Returns: Query audio embedding as numpy array. @@ -197,7 +194,7 @@ def embed_query_audio( "window_size_s": window_size_s, "hop_size_s": hop_size_s, }) - embedding_model = models.TaxonomyModelTF.from_config(config) + embedding_model = taxonomy_model_tf.TaxonomyModelTF.from_config(config) outputs = embedding_model.embed(np.array(query_audio)) diff --git a/chirp/inference/tests/embed_test.py b/chirp/inference/tests/embed_test.py index 5a940c1f..6e40c7ed 100644 --- a/chirp/inference/tests/embed_test.py +++ b/chirp/inference/tests/embed_test.py @@ -33,6 +33,7 @@ from chirp.inference.search import search from chirp.models import metrics from chirp.projects.zoo import models +from chirp.projects.zoo import taxonomy_model_tf from chirp.projects.zoo import zoo_interface from chirp.taxonomy import namespace from etils import epath @@ -574,94 +575,6 @@ def test_handcrafted_features(self): # four summary statistics for each, giving a total of 80 output channels. self.assertSequenceEqual([5, 1, 80], outputs.embeddings.shape) - def test_sep_embed_wrapper(self): - """Check that the joint-model wrapper works as intended.""" - separator = models.PlaceholderModel( - sample_rate=22050, - make_embeddings=False, - make_logits=False, - make_separated_audio=True, - ) - - embeddor = models.PlaceholderModel( - sample_rate=22050, - make_embeddings=True, - make_logits=True, - make_separated_audio=False, - ) - fake_config = config_dict.ConfigDict() - sep_embed = models.SeparateEmbedModel( - sample_rate=22050, - taxonomy_model_tf_config=fake_config, - separator_model_tf_config=fake_config, - separation_model=separator, - embedding_model=embeddor, - ) - audio = np.zeros(5 * 22050, np.float32) - - outputs = sep_embed.embed(audio) - # The PlaceholderModel produces one embedding per second, and we have - # five seconds of audio, with two separated channels, plus the channel - # for the raw audio. - # Note that this checks that the sample-rate conversion between the - # separation model and embedding model has worked correctly. - self.assertSequenceEqual( - outputs.embeddings.shape, [5, 3, embeddor.embedding_size] - ) - # The Sep+Embed model takes the max logits over the channel dimension. - self.assertSequenceEqual( - outputs.logits['label'].shape, [5, len(embeddor.class_list.classes)] - ) - - def test_pooled_embeddings(self): - outputs = zoo_interface.InferenceOutputs( - embeddings=np.zeros([10, 2, 8]), batched=False - ) - batched_outputs = zoo_interface.InferenceOutputs( - embeddings=np.zeros([3, 10, 2, 8]), batched=True - ) - - # Check that no-op is no-op. - non_pooled = outputs.pooled_embeddings('', '') - self.assertSequenceEqual(non_pooled.shape, outputs.embeddings.shape) - batched_non_pooled = batched_outputs.pooled_embeddings('', '') - self.assertSequenceEqual( - batched_non_pooled.shape, batched_outputs.embeddings.shape - ) - - for pooling_method in zoo_interface.POOLING_METHODS: - if pooling_method == 'squeeze': - # The 'squeeze' pooling method throws an exception if axis size is > 1. - with self.assertRaises(ValueError): - outputs.pooled_embeddings(pooling_method, '') - continue - elif pooling_method == 'flatten': - # Concatenates over the target axis. - time_pooled = outputs.pooled_embeddings(pooling_method, '') - self.assertSequenceEqual(time_pooled.shape, [2, 80]) - continue - - time_pooled = outputs.pooled_embeddings(pooling_method, '') - self.assertSequenceEqual(time_pooled.shape, [2, 8]) - batched_time_pooled = batched_outputs.pooled_embeddings( - pooling_method, '' - ) - self.assertSequenceEqual(batched_time_pooled.shape, [3, 2, 8]) - - channel_pooled = outputs.pooled_embeddings('', pooling_method) - self.assertSequenceEqual(channel_pooled.shape, [10, 8]) - batched_channel_pooled = batched_outputs.pooled_embeddings( - '', pooling_method - ) - self.assertSequenceEqual(batched_channel_pooled.shape, [3, 10, 8]) - - both_pooled = outputs.pooled_embeddings(pooling_method, pooling_method) - self.assertSequenceEqual(both_pooled.shape, [8]) - batched_both_pooled = batched_outputs.pooled_embeddings( - pooling_method, pooling_method - ) - self.assertSequenceEqual(batched_both_pooled.shape, [3, 8]) - def test_beam_pipeline(self): """Check that we can write embeddings to TFRecord file.""" test_wav_path = os.fspath( @@ -700,207 +613,6 @@ def test_beam_pipeline(self): print(metrics) - @parameterized.product( - model_return_type=('tuple', 'dict'), - batchable=(True, False), - ) - def test_taxonomy_model_tf(self, model_return_type, batchable): - class FakeModelFn: - output_depths = {'label': 3, 'embedding': 256} - - def infer_tf(self, audio_array): - outputs = { - k: np.zeros([audio_array.shape[0], d], dtype=np.float32) - for k, d in self.output_depths.items() - } - if model_return_type == 'tuple': - # Published Perch models v1 through v4 returned a tuple, not a dict. - return outputs['label'], outputs['embedding'] - return outputs - - class_list = { - 'label': namespace.ClassList('fake', ['alpha', 'beta', 'delta']) - } - wrapped_model = models.TaxonomyModelTF( - sample_rate=32000, - model_path='/dev/null', - window_size_s=5.0, - hop_size_s=5.0, - model=FakeModelFn(), - class_list=class_list, - batchable=batchable, - ) - - # Check that a single frame of audio is handled properly. - outputs = wrapped_model.embed(np.zeros([5 * 32000], dtype=np.float32)) - self.assertFalse(outputs.batched) - self.assertSequenceEqual(outputs.embeddings.shape, [1, 1, 256]) - self.assertSequenceEqual(outputs.logits['label'].shape, [1, 3]) - - # Check that multi-frame audio is handled properly. - outputs = wrapped_model.embed(np.zeros([20 * 32000], dtype=np.float32)) - self.assertFalse(outputs.batched) - self.assertSequenceEqual(outputs.embeddings.shape, [4, 1, 256]) - self.assertSequenceEqual(outputs.logits['label'].shape, [4, 3]) - - # Check that a batch of single frame of audio is handled properly. - outputs = wrapped_model.batch_embed( - np.zeros([10, 5 * 32000], dtype=np.float32) - ) - self.assertTrue(outputs.batched) - self.assertSequenceEqual(outputs.embeddings.shape, [10, 1, 1, 256]) - self.assertSequenceEqual(outputs.logits['label'].shape, [10, 1, 3]) - - # Check that a batch of multi-frame audio is handled properly. - outputs = wrapped_model.batch_embed( - np.zeros([2, 20 * 32000], dtype=np.float32) - ) - self.assertTrue(outputs.batched) - self.assertSequenceEqual(outputs.embeddings.shape, [2, 4, 1, 256]) - self.assertSequenceEqual(outputs.logits['label'].shape, [2, 4, 3]) - - def test_whale_model(self): - # prereq - class FakeModel(tf_keras.Model): - """Fake implementation of the humpback_whale SavedModel API. - - The use of `tf_keras` as opposed to `tf.keras` is intentional; the models - this fakes were exported using "the pure-TensorFlow implementation of - Keras." - """ - - def __init__(self): - super().__init__() - self._sample_rate = 10000 - self._classes = ['Mn'] - self._embedder = tf_keras.layers.Dense(32) - self._classifier = tf_keras.layers.Dense(len(self._classes)) - - def call(self, spectrograms, training=False): - logits = self.logits(spectrograms) - return tf.nn.sigmoid(logits) - - @tf.function( - input_signature=[tf.TensorSpec([None, None, 1], tf.dtypes.float32)] - ) - def front_end(self, waveform): - return tf.math.abs( - tf.signal.stft( - tf.squeeze(waveform, -1), - frame_length=1024, - frame_step=300, - fft_length=128, - )[..., 1:] - ) - - @tf.function( - input_signature=[tf.TensorSpec([None, 128, 64], tf.dtypes.float32)] - ) - def features(self, spectrogram): - return self._embedder(tf.math.reduce_mean(spectrogram, axis=-2)) - - @tf.function( - input_signature=[tf.TensorSpec([None, 128, 64], tf.dtypes.float32)] - ) - def logits(self, spectrogram): - features = self.features(spectrogram) - return self._classifier(features) - - @tf.function( - input_signature=[ - tf.TensorSpec([None, None, 1], tf.dtypes.float32), - tf.TensorSpec([], tf.dtypes.int64), - ] - ) - def score(self, waveform, context_step_samples): - spectrogram = self.front_end(waveform) - windows = tf.signal.frame( - spectrogram, frame_length=128, frame_step=128, axis=1 - ) - shape = tf.shape(windows) - batch_size = shape[0] - num_windows = shape[1] - frame_length = shape[2] - tf.debugging.assert_equal(frame_length, 128) - channels_len = shape[3] - logits = self.logits( - tf.reshape( - windows, (batch_size * num_windows, frame_length, channels_len) - ) - ) - return {'score': tf.nn.sigmoid(logits)} - - @tf.function(input_signature=[]) - def metadata(self): - return { - 'input_sample_rate': tf.constant( - self._sample_rate, tf.dtypes.int64 - ), - 'context_width_samples': tf.constant(39124, tf.dtypes.int64), - 'class_names': tf.constant(self._classes), - } - - # setup - fake_model = FakeModel() - batch_size = 2 - duration_seconds = 10 - sample_rate = fake_model.metadata()['input_sample_rate'] - waveform = np.random.randn( - batch_size, - sample_rate * duration_seconds, - ) - expected_frames = int(10 / 3.9124) + 1 - # Call the model to avoid "forward pass of the model is not defined" on - # save. - spectrograms = fake_model.front_end(waveform[:, :, np.newaxis]) - fake_model(spectrograms[:, :128, :]) - model_path = os.path.join(tempfile.gettempdir(), 'whale_model') - fake_model.save( - model_path, - signatures={ - 'score': fake_model.score, - 'metadata': fake_model.metadata, - 'serving_default': fake_model.score, - 'front_end': fake_model.front_end, - 'features': fake_model.features, - 'logits': fake_model.logits, - }, - ) - - with self.subTest('from_url'): - # invoke - model = models.GoogleWhaleModel.load_humpback_model(model_path) - outputs = model.batch_embed(waveform) - - # verify - self.assertTrue(outputs.batched) - self.assertSequenceEqual( - outputs.embeddings.shape, [batch_size, expected_frames, 1, 32] - ) - self.assertSequenceEqual( - outputs.logits['humpback'].shape, [batch_size, expected_frames, 1] - ) - - with self.subTest('from_config'): - # invoke - config = config_dict.ConfigDict() - config.model_url = model_path - config.sample_rate = float(sample_rate) - config.window_size_s = 3.9124 - config.peak_norm = 0.02 - model = models.GoogleWhaleModel.from_config(config) - # Let's check the regular embed this time. - outputs = model.embed(waveform[0]) - - # verify - self.assertFalse(outputs.batched) - self.assertSequenceEqual( - outputs.embeddings.shape, [expected_frames, 1, 32] - ) - self.assertSequenceEqual( - outputs.logits['multispecies_whale'].shape, [expected_frames, 1] - ) - if __name__ == '__main__': absltest.main() diff --git a/chirp/projects/zoo/models.py b/chirp/projects/zoo/models.py index 057fc066..70335786 100644 --- a/chirp/projects/zoo/models.py +++ b/chirp/projects/zoo/models.py @@ -22,6 +22,7 @@ from absl import logging from chirp.models import frontend from chirp.models import handcrafted_features +from chirp.projects.zoo import taxonomy_model_tf from chirp.projects.zoo import zoo_interface from chirp.taxonomy import namespace from chirp.taxonomy import namespace_db @@ -32,21 +33,11 @@ import tensorflow.compat.v1 as tf1 import tensorflow_hub as hub -PERCH_TF_HUB_URL = ( - 'https://www.kaggle.com/models/google/' - 'bird-vocalization-classifier/frameworks/TensorFlow2/' - 'variations/bird-vocalization-classifier/versions' -) - -SURFPERCH_TF_HUB_URL = ( - 'https://www.kaggle.com/models/google/surfperch/TensorFlow2/1' -) - def model_class_map() -> dict[str, Any]: """Get the mapping of model keys to classes.""" return { - 'taxonomy_model_tf': TaxonomyModelTF, + 'taxonomy_model_tf': taxonomy_model_tf.TaxonomyModelTF, 'separator_model_tf': SeparatorModelTF, 'birb_separator_model_tf1': BirbSepModelTF1, 'birdnet': BirdNet, @@ -90,7 +81,7 @@ def get_preset_model_config(preset_name): model_config.hop_size_s = 5.0 model_config.sample_rate = 32000 model_config.tfhub_version = 1 - model_config.tfhub_path = SURFPERCH_TF_HUB_URL + model_config.tfhub_path = taxonomy_model_tf.SURFPERCH_TF_HUB_URL model_config.model_path = '' elif preset_name.startswith('birdnet'): model_key = 'birdnet' @@ -140,7 +131,7 @@ class SeparateEmbedModel(zoo_interface.EmbeddingModel): separator_model_tf_config: config_dict.ConfigDict taxonomy_model_tf_config: config_dict.ConfigDict separation_model: 'SeparatorModelTF' - embedding_model: 'TaxonomyModelTF' + embedding_model: taxonomy_model_tf.TaxonomyModelTF embed_raw: bool = True @classmethod @@ -148,7 +139,7 @@ def from_config(cls, config: config_dict.ConfigDict) -> 'SeparateEmbedModel': separation_model = SeparatorModelTF.from_config( config.separator_model_tf_config ) - embedding_model = TaxonomyModelTF.from_config( + embedding_model = taxonomy_model_tf.TaxonomyModelTF.from_config( config.taxonomy_model_tf_config ) return cls( @@ -312,211 +303,6 @@ def batch_embed( return zoo_interface.batch_embed_from_embed_fn(self.embed, audio_batch) -@dataclasses.dataclass -class TaxonomyModelTF(zoo_interface.EmbeddingModel): - """Taxonomy SavedModel. - - Attributes: - model_path: Path to model files. - window_size_s: Window size for framing audio in seconds. TODO(tomdenton): - Ideally this should come from a model metadata file. - hop_size_s: Hop size for inference. - model: Loaded TF SavedModel. - class_list: Loaded class_list for the model's output logits. - batchable: Whether the model supports batched input. - target_peak: Peak normalization value. - """ - - model_path: str - window_size_s: float - hop_size_s: float - model: Any # TF SavedModel - class_list: dict[str, namespace.ClassList] - batchable: bool - target_peak: float | None = 0.25 - tfhub_version: int | None = None - - @classmethod - def is_batchable(cls, model: Any) -> bool: - sig = model.signatures['serving_default'] - return sig.inputs[0].shape[0] is None - - @classmethod - def load_class_lists(cls, csv_glob): - class_lists = {} - for csv_path in csv_glob: - with csv_path.open('r') as f: - key = csv_path.name.replace('.csv', '') - class_lists[key] = namespace.ClassList.from_csv(f) - return class_lists - - @classmethod - def from_tfhub(cls, config: config_dict.ConfigDict) -> 'TaxonomyModelTF': - if not hasattr(config, 'tfhub_version') or config.tfhub_version is None: - raise ValueError('tfhub_version is required to load from TFHub.') - if config.model_path: - raise ValueError( - 'Exactly one of tfhub_version and model_path should be set.' - ) - if hasattr(config, 'tfhub_path'): - tfhub_path = config.tfhub_path - del config.tfhub_path - else: - tfhub_path = PERCH_TF_HUB_URL - - if tfhub_path == PERCH_TF_HUB_URL and config.tfhub_version in (5, 6, 7): - # Due to SNAFUs uploading the new model version to KaggleModels, - # some version numbers were skipped. - raise ValueError('TFHub version 5, 6, and 7 do not exist.') - - model_url = f'{tfhub_path}/{config.tfhub_version}' - # This model behaves exactly like the usual saved_model. - model = hub.load(model_url) - - # Check whether the model support polymorphic batch shape. - batchable = cls.is_batchable(model) - - # Get the labels CSV from TFHub. - model_path = hub.resolve(model_url) - class_lists_glob = (epath.Path(model_path) / 'assets').glob('*.csv') - class_lists = cls.load_class_lists(class_lists_glob) - return cls( - model=model, class_list=class_lists, batchable=batchable, **config - ) - - @classmethod - def load_version( - cls, tfhub_version: int, hop_size_s: float = 5.0 - ) -> 'TaxonomyModelTF': - cfg = config_dict.ConfigDict({ - 'model_path': '', - 'sample_rate': 32000, - 'window_size_s': 5.0, - 'hop_size_s': hop_size_s, - 'target_peak': 0.25, - 'tfhub_version': tfhub_version, - }) - return cls.from_tfhub(cfg) - - @classmethod - def load_surfperch_version( - cls, tfhub_version: int, hop_size_s: float = 5.0 - ) -> 'TaxonomyModelTF': - """Load a model from TFHub.""" - cfg = config_dict.ConfigDict({ - 'model_path': '', - 'sample_rate': 32000, - 'window_size_s': 5.0, - 'hop_size_s': hop_size_s, - 'target_peak': 0.25, - 'tfhub_version': tfhub_version, - 'tfhub_path': SURFPERCH_TF_HUB_URL, - }) - return cls.from_tfhub(cfg) - - @classmethod - def from_config(cls, config: config_dict.ConfigDict) -> 'TaxonomyModelTF': - logging.info('Loading taxonomy model...') - - if hasattr(config, 'tfhub_version') and config.tfhub_version is not None: - return cls.from_tfhub(config) - - base_path = epath.Path(config.model_path) - if (base_path / 'saved_model.pb').exists() and ( - base_path / 'assets' - ).exists(): - # This looks like a downloaded TFHub model. - model_path = base_path - class_lists_glob = (epath.Path(model_path) / 'assets').glob('*.csv') - else: - # Probably a savedmodel distributed directly. - model_path = base_path / 'savedmodel' - class_lists_glob = epath.Path(base_path).glob('*.csv') - - model = tf.saved_model.load(model_path) - class_lists = cls.load_class_lists(class_lists_glob) - - # Check whether the model support polymorphic batch shape. - batchable = cls.is_batchable(model) - return cls( - model=model, class_list=class_lists, batchable=batchable, **config - ) - - def embed(self, audio_array: np.ndarray) -> zoo_interface.InferenceOutputs: - return zoo_interface.embed_from_batch_embed_fn( - self.batch_embed, audio_array - ) - - def _nonbatchable_batch_embed(self, audio_batch: np.ndarray): - """Embed a batch of audio with an old non-batchable model.""" - all_logits = [] - all_embeddings = [] - for audio in audio_batch: - outputs = self.model.infer_tf(audio[np.newaxis, :]) - if hasattr(outputs, 'keys'): - embedding = outputs.pop('embedding') - logits = outputs.pop('label') - else: - # Assume the model output is always a (logits, embedding) twople. - logits, embedding = outputs - all_logits.append(logits) - all_embeddings.append(embedding) - all_logits = np.stack(all_logits, axis=0) - all_embeddings = np.stack(all_embeddings, axis=0) - return { - 'embedding': all_embeddings, - 'label': all_logits, - } - - def batch_embed( - self, audio_batch: np.ndarray[Any, Any] - ) -> zoo_interface.InferenceOutputs: - framed_audio = self.frame_audio( - audio_batch, self.window_size_s, self.hop_size_s - ) - framed_audio = self.normalize_audio(framed_audio, self.target_peak) - rebatched_audio = framed_audio.reshape([-1, framed_audio.shape[-1]]) - - if not self.batchable: - outputs = self._nonbatchable_batch_embed(rebatched_audio) - else: - outputs = self.model.infer_tf(rebatched_audio) - - frontend_output = None - if hasattr(outputs, 'keys'): - # Dictionary-type outputs. Arrange appropriately. - embeddings = outputs.pop('embedding') - if 'frontend' in outputs: - frontend_output = outputs.pop('frontend') - # Assume remaining outputs are all logits. - logits = outputs - elif len(outputs) == 2: - # Assume logits, embeddings outputs. - label_logits, embeddings = outputs - logits = {'label': label_logits} - else: - raise ValueError('Unexpected outputs type.') - - for k, v in logits.items(): - logits[k] = np.reshape(v, framed_audio.shape[:2] + (v.shape[-1],)) - # Unbatch and add channel dimension. - embeddings = np.reshape( - embeddings, - framed_audio.shape[:2] - + ( - 1, - embeddings.shape[-1], - ), - ) - return zoo_interface.InferenceOutputs( - embeddings=embeddings, - logits=logits, - separated_audio=None, - batched=True, - frontend=frontend_output, - ) - - @dataclasses.dataclass class SeparatorModelTF(zoo_interface.EmbeddingModel): """Separator SavedModel. diff --git a/chirp/projects/zoo/taxonomy_model_tf.py b/chirp/projects/zoo/taxonomy_model_tf.py new file mode 100644 index 00000000..da9e2fcb --- /dev/null +++ b/chirp/projects/zoo/taxonomy_model_tf.py @@ -0,0 +1,300 @@ +# coding=utf-8 +# Copyright 2024 The Perch Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Perch Taxonomy Model.""" + +import dataclasses +from typing import Any + +from absl import logging +from chirp.projects.zoo import zoo_interface +from chirp.taxonomy import namespace +from etils import epath +from ml_collections import config_dict +import numpy as np +import tensorflow as tf +import tensorflow_hub as hub + + +PERCH_TF_HUB_URL = ( + 'https://www.kaggle.com/models/google/' + 'bird-vocalization-classifier/frameworks/TensorFlow2/' + 'variations/bird-vocalization-classifier/versions' +) + +SURFPERCH_TF_HUB_URL = ( + 'https://www.kaggle.com/models/google/surfperch/TensorFlow2/1' +) + + +@dataclasses.dataclass +class TaxonomyModelTF(zoo_interface.EmbeddingModel): + """Taxonomy SavedModel. + + Attributes: + model_path: Path to model files. + window_size_s: Window size for framing audio in seconds. TODO(tomdenton): + Ideally this should come from a model metadata file. + hop_size_s: Hop size for inference. + model: Loaded TF SavedModel. + class_list: Loaded class_list for the model's output logits. + batchable: Whether the model supports batched input. + target_peak: Peak normalization value. + """ + + model_path: str + window_size_s: float + hop_size_s: float + model: Any # TF SavedModel + class_list: dict[str, namespace.ClassList] + batchable: bool + target_peak: float | None = 0.25 + tfhub_version: int | None = None + + @classmethod + def is_batchable(cls, model: Any) -> bool: + sig = model.signatures['serving_default'] + return sig.inputs[0].shape[0] is None + + @classmethod + def load_class_lists(cls, csv_glob): + class_lists = {} + for csv_path in csv_glob: + with csv_path.open('r') as f: + key = csv_path.name.replace('.csv', '') + class_lists[key] = namespace.ClassList.from_csv(f) + return class_lists + + @classmethod + def from_tfhub(cls, config: config_dict.ConfigDict) -> 'TaxonomyModelTF': + if not hasattr(config, 'tfhub_version') or config.tfhub_version is None: + raise ValueError('tfhub_version is required to load from TFHub.') + if config.model_path: + raise ValueError( + 'Exactly one of tfhub_version and model_path should be set.' + ) + if hasattr(config, 'tfhub_path'): + tfhub_path = config.tfhub_path + del config.tfhub_path + else: + tfhub_path = PERCH_TF_HUB_URL + + if tfhub_path == PERCH_TF_HUB_URL and config.tfhub_version in (5, 6, 7): + # Due to SNAFUs uploading the new model version to KaggleModels, + # some version numbers were skipped. + raise ValueError('TFHub version 5, 6, and 7 do not exist.') + + model_url = f'{tfhub_path}/{config.tfhub_version}' + # This model behaves exactly like the usual saved_model. + model = hub.load(model_url) + + # Check whether the model support polymorphic batch shape. + batchable = cls.is_batchable(model) + + # Get the labels CSV from TFHub. + model_path = hub.resolve(model_url) + config.model_path = model_path + class_lists_glob = (epath.Path(model_path) / 'assets').glob('*.csv') + class_lists = cls.load_class_lists(class_lists_glob) + return cls( + model=model, + class_list=class_lists, + batchable=batchable, + **config, + ) + + @classmethod + def load_version( + cls, tfhub_version: int, hop_size_s: float = 5.0 + ) -> 'TaxonomyModelTF': + cfg = config_dict.ConfigDict({ + 'model_path': '', + 'sample_rate': 32000, + 'window_size_s': 5.0, + 'hop_size_s': hop_size_s, + 'target_peak': 0.25, + 'tfhub_version': tfhub_version, + }) + return cls.from_tfhub(cfg) + + @classmethod + def load_surfperch_version( + cls, tfhub_version: int, hop_size_s: float = 5.0 + ) -> 'TaxonomyModelTF': + """Load a model from TFHub.""" + cfg = config_dict.ConfigDict({ + 'model_path': '', + 'sample_rate': 32000, + 'window_size_s': 5.0, + 'hop_size_s': hop_size_s, + 'target_peak': 0.25, + 'tfhub_version': tfhub_version, + 'tfhub_path': SURFPERCH_TF_HUB_URL, + }) + return cls.from_tfhub(cfg) + + @classmethod + def from_config(cls, config: config_dict.ConfigDict) -> 'TaxonomyModelTF': + logging.info('Loading taxonomy model...') + + if hasattr(config, 'tfhub_version') and config.tfhub_version is not None: + return cls.from_tfhub(config) + + base_path = epath.Path(config.model_path) + if (base_path / 'saved_model.pb').exists() and ( + base_path / 'assets' + ).exists(): + # This looks like a downloaded TFHub model. + model_path = base_path + class_lists_glob = (epath.Path(model_path) / 'assets').glob('*.csv') + else: + # Probably a savedmodel distributed directly. + model_path = base_path / 'savedmodel' + class_lists_glob = epath.Path(base_path).glob('*.csv') + + model = tf.saved_model.load(model_path) + class_lists = cls.load_class_lists(class_lists_glob) + + # Check whether the model support polymorphic batch shape. + batchable = cls.is_batchable(model) + return cls( + model=model, class_list=class_lists, batchable=batchable, **config + ) + + def get_classifier_head(self, classes: list[str]): + """Extract a classifier head for the desired subset of classes.""" + if self.tfhub_version is not None: + # This is a model loaded from TFHub. + # We need to extract the weights and biases from the saved model. + vars_filepath = f'{self.model_path}/variables/variables' + else: + vars_filepath = f'{self.model_path}/savedmodel/variables/variables' + + def _get_weights_and_bias(num_classes: int): + weights = None + bias = None + for vname, vshape in tf.train.list_variables(vars_filepath): + if len(vshape) == 1 and vshape[-1] == num_classes: + if bias is None: + bias = tf.train.load_variable(vars_filepath, vname) + else: + raise ValueError('Multiple possible biases for class list.') + if len(vshape) == 2 and vshape[-1] == num_classes: + if weights is None: + weights = tf.train.load_variable(vars_filepath, vname) + else: + raise ValueError('Multiple possible weights for class list.') + if hasattr(weights, 'numpy'): + weights = weights.numpy() + if hasattr(bias, 'numpy'): + bias = bias.numpy() + return weights, bias + + class_wts = {} + for logit_key in self.class_list: + num_classes = len(self.class_list[logit_key].classes) + weights, bias = _get_weights_and_bias(num_classes) + if weights is None or bias is None: + raise ValueError( + f'No weights or bias found for {logit_key} {num_classes}' + ) + for i, k in enumerate(self.class_list[logit_key].classes): + class_wts[k] = weights[:, i], bias[i] + + wts = [] + biases = [] + found_classes = [] + for target_class in classes: + if target_class not in class_wts: + continue + wts.append(class_wts[target_class][0]) + biases.append(class_wts[target_class][1]) + found_classes.append(target_class) + print(f'Found classes: {found_classes}') + return found_classes, np.stack(wts, axis=-1), np.stack(biases, axis=-1) + + def embed(self, audio_array: np.ndarray) -> zoo_interface.InferenceOutputs: + return zoo_interface.embed_from_batch_embed_fn( + self.batch_embed, audio_array + ) + + def _nonbatchable_batch_embed(self, audio_batch: np.ndarray): + """Embed a batch of audio with an old non-batchable model.""" + all_logits = [] + all_embeddings = [] + for audio in audio_batch: + outputs = self.model.infer_tf(audio[np.newaxis, :]) + if hasattr(outputs, 'keys'): + embedding = outputs.pop('embedding') + logits = outputs.pop('label') + else: + # Assume the model output is always a (logits, embedding) twople. + logits, embedding = outputs + all_logits.append(logits) + all_embeddings.append(embedding) + all_logits = np.stack(all_logits, axis=0) + all_embeddings = np.stack(all_embeddings, axis=0) + return { + 'embedding': all_embeddings, + 'label': all_logits, + } + + def batch_embed( + self, audio_batch: np.ndarray[Any, Any] + ) -> zoo_interface.InferenceOutputs: + framed_audio = self.frame_audio( + audio_batch, self.window_size_s, self.hop_size_s + ) + framed_audio = self.normalize_audio(framed_audio, self.target_peak) + rebatched_audio = framed_audio.reshape([-1, framed_audio.shape[-1]]) + + if not self.batchable: + outputs = self._nonbatchable_batch_embed(rebatched_audio) + else: + outputs = self.model.infer_tf(rebatched_audio) + + frontend_output = None + if hasattr(outputs, 'keys'): + # Dictionary-type outputs. Arrange appropriately. + embeddings = outputs.pop('embedding') + if 'frontend' in outputs: + frontend_output = outputs.pop('frontend') + # Assume remaining outputs are all logits. + logits = outputs + elif len(outputs) == 2: + # Assume logits, embeddings outputs. + label_logits, embeddings = outputs + logits = {'label': label_logits} + else: + raise ValueError('Unexpected outputs type.') + + for k, v in logits.items(): + logits[k] = np.reshape(v, framed_audio.shape[:2] + (v.shape[-1],)) + # Unbatch and add channel dimension. + embeddings = np.reshape( + embeddings, + framed_audio.shape[:2] + + ( + 1, + embeddings.shape[-1], + ), + ) + return zoo_interface.InferenceOutputs( + embeddings=embeddings, + logits=logits, + separated_audio=None, + batched=True, + frontend=frontend_output, + ) diff --git a/chirp/projects/zoo/zoo_test.py b/chirp/projects/zoo/zoo_test.py new file mode 100644 index 00000000..fa51803a --- /dev/null +++ b/chirp/projects/zoo/zoo_test.py @@ -0,0 +1,337 @@ +# coding=utf-8 +# Copyright 2024 The Perch Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for mass-embedding functionality.""" + +import os +import tempfile + +from chirp.projects.zoo import models +from chirp.projects.zoo import taxonomy_model_tf +from chirp.projects.zoo import zoo_interface +from chirp.taxonomy import namespace +from ml_collections import config_dict +import numpy as np +import tensorflow as tf +import tf_keras + +from absl.testing import absltest +from absl.testing import parameterized + + +class ZooTest(parameterized.TestCase): + + def test_handcrafted_features(self): + model = models.HandcraftedFeaturesModel.beans_baseline() + + audio = np.zeros([5 * 32000], dtype=np.float32) + outputs = model.embed(audio) + # Five frames because we have 5s of audio with window 1.0 and hope 1.0. + # Beans aggrregation with mfccs creates 20 MFCC channels, and then computes + # four summary statistics for each, giving a total of 80 output channels. + self.assertSequenceEqual([5, 1, 80], outputs.embeddings.shape) + + def test_sep_embed_wrapper(self): + """Check that the joint-model wrapper works as intended.""" + separator = models.PlaceholderModel( + sample_rate=22050, + make_embeddings=False, + make_logits=False, + make_separated_audio=True, + ) + + embeddor = models.PlaceholderModel( + sample_rate=22050, + make_embeddings=True, + make_logits=True, + make_separated_audio=False, + ) + fake_config = config_dict.ConfigDict() + sep_embed = models.SeparateEmbedModel( + sample_rate=22050, + taxonomy_model_tf_config=fake_config, + separator_model_tf_config=fake_config, + separation_model=separator, + embedding_model=embeddor, + ) + audio = np.zeros(5 * 22050, np.float32) + + outputs = sep_embed.embed(audio) + # The PlaceholderModel produces one embedding per second, and we have + # five seconds of audio, with two separated channels, plus the channel + # for the raw audio. + # Note that this checks that the sample-rate conversion between the + # separation model and embedding model has worked correctly. + self.assertSequenceEqual( + outputs.embeddings.shape, [5, 3, embeddor.embedding_size] + ) + # The Sep+Embed model takes the max logits over the channel dimension. + self.assertSequenceEqual( + outputs.logits['label'].shape, [5, len(embeddor.class_list.classes)] + ) + + def test_pooled_embeddings(self): + outputs = zoo_interface.InferenceOutputs( + embeddings=np.zeros([10, 2, 8]), batched=False + ) + batched_outputs = zoo_interface.InferenceOutputs( + embeddings=np.zeros([3, 10, 2, 8]), batched=True + ) + + # Check that no-op is no-op. + non_pooled = outputs.pooled_embeddings('', '') + self.assertSequenceEqual(non_pooled.shape, outputs.embeddings.shape) + batched_non_pooled = batched_outputs.pooled_embeddings('', '') + self.assertSequenceEqual( + batched_non_pooled.shape, batched_outputs.embeddings.shape + ) + + for pooling_method in zoo_interface.POOLING_METHODS: + if pooling_method == 'squeeze': + # The 'squeeze' pooling method throws an exception if axis size is > 1. + with self.assertRaises(ValueError): + outputs.pooled_embeddings(pooling_method, '') + continue + elif pooling_method == 'flatten': + # Concatenates over the target axis. + time_pooled = outputs.pooled_embeddings(pooling_method, '') + self.assertSequenceEqual(time_pooled.shape, [2, 80]) + continue + + time_pooled = outputs.pooled_embeddings(pooling_method, '') + self.assertSequenceEqual(time_pooled.shape, [2, 8]) + batched_time_pooled = batched_outputs.pooled_embeddings( + pooling_method, '' + ) + self.assertSequenceEqual(batched_time_pooled.shape, [3, 2, 8]) + + channel_pooled = outputs.pooled_embeddings('', pooling_method) + self.assertSequenceEqual(channel_pooled.shape, [10, 8]) + batched_channel_pooled = batched_outputs.pooled_embeddings( + '', pooling_method + ) + self.assertSequenceEqual(batched_channel_pooled.shape, [3, 10, 8]) + + both_pooled = outputs.pooled_embeddings(pooling_method, pooling_method) + self.assertSequenceEqual(both_pooled.shape, [8]) + batched_both_pooled = batched_outputs.pooled_embeddings( + pooling_method, pooling_method + ) + self.assertSequenceEqual(batched_both_pooled.shape, [3, 8]) + + @parameterized.product( + model_return_type=('tuple', 'dict'), + batchable=(True, False), + ) + def test_taxonomy_model_tf(self, model_return_type, batchable): + class FakeModelFn: + output_depths = {'label': 3, 'embedding': 256} + + def infer_tf(self, audio_array): + outputs = { + k: np.zeros([audio_array.shape[0], d], dtype=np.float32) + for k, d in self.output_depths.items() + } + if model_return_type == 'tuple': + # Published Perch models v1 through v4 returned a tuple, not a dict. + return outputs['label'], outputs['embedding'] + return outputs + + class_list = { + 'label': namespace.ClassList('fake', ['alpha', 'beta', 'delta']) + } + wrapped_model = taxonomy_model_tf.TaxonomyModelTF( + sample_rate=32000, + model_path='/dev/null', + window_size_s=5.0, + hop_size_s=5.0, + model=FakeModelFn(), + class_list=class_list, + batchable=batchable, + ) + + # Check that a single frame of audio is handled properly. + outputs = wrapped_model.embed(np.zeros([5 * 32000], dtype=np.float32)) + self.assertFalse(outputs.batched) + self.assertSequenceEqual(outputs.embeddings.shape, [1, 1, 256]) + self.assertSequenceEqual(outputs.logits['label'].shape, [1, 3]) + + # Check that multi-frame audio is handled properly. + outputs = wrapped_model.embed(np.zeros([20 * 32000], dtype=np.float32)) + self.assertFalse(outputs.batched) + self.assertSequenceEqual(outputs.embeddings.shape, [4, 1, 256]) + self.assertSequenceEqual(outputs.logits['label'].shape, [4, 3]) + + # Check that a batch of single frame of audio is handled properly. + outputs = wrapped_model.batch_embed( + np.zeros([10, 5 * 32000], dtype=np.float32) + ) + self.assertTrue(outputs.batched) + self.assertSequenceEqual(outputs.embeddings.shape, [10, 1, 1, 256]) + self.assertSequenceEqual(outputs.logits['label'].shape, [10, 1, 3]) + + # Check that a batch of multi-frame audio is handled properly. + outputs = wrapped_model.batch_embed( + np.zeros([2, 20 * 32000], dtype=np.float32) + ) + self.assertTrue(outputs.batched) + self.assertSequenceEqual(outputs.embeddings.shape, [2, 4, 1, 256]) + self.assertSequenceEqual(outputs.logits['label'].shape, [2, 4, 3]) + + def test_whale_model(self): + # prereq + class FakeModel(tf_keras.Model): + """Fake implementation of the humpback_whale SavedModel API. + + The use of `tf_keras` as opposed to `tf.keras` is intentional; the models + this fakes were exported using "the pure-TensorFlow implementation of + Keras." + """ + + def __init__(self): + super().__init__() + self._sample_rate = 10000 + self._classes = ['Mn'] + self._embedder = tf_keras.layers.Dense(32) + self._classifier = tf_keras.layers.Dense(len(self._classes)) + + def call(self, spectrograms, training=False): + logits = self.logits(spectrograms) + return tf.nn.sigmoid(logits) + + @tf.function( + input_signature=[tf.TensorSpec([None, None, 1], tf.dtypes.float32)] + ) + def front_end(self, waveform): + return tf.math.abs( + tf.signal.stft( + tf.squeeze(waveform, -1), + frame_length=1024, + frame_step=300, + fft_length=128, + )[..., 1:] + ) + + @tf.function( + input_signature=[tf.TensorSpec([None, 128, 64], tf.dtypes.float32)] + ) + def features(self, spectrogram): + return self._embedder(tf.math.reduce_mean(spectrogram, axis=-2)) + + @tf.function( + input_signature=[tf.TensorSpec([None, 128, 64], tf.dtypes.float32)] + ) + def logits(self, spectrogram): + features = self.features(spectrogram) + return self._classifier(features) + + @tf.function( + input_signature=[ + tf.TensorSpec([None, None, 1], tf.dtypes.float32), + tf.TensorSpec([], tf.dtypes.int64), + ] + ) + def score(self, waveform, context_step_samples): + spectrogram = self.front_end(waveform) + windows = tf.signal.frame( + spectrogram, frame_length=128, frame_step=128, axis=1 + ) + shape = tf.shape(windows) + batch_size = shape[0] + num_windows = shape[1] + frame_length = shape[2] + tf.debugging.assert_equal(frame_length, 128) + channels_len = shape[3] + logits = self.logits( + tf.reshape( + windows, (batch_size * num_windows, frame_length, channels_len) + ) + ) + return {'score': tf.nn.sigmoid(logits)} + + @tf.function(input_signature=[]) + def metadata(self): + return { + 'input_sample_rate': tf.constant( + self._sample_rate, tf.dtypes.int64 + ), + 'context_width_samples': tf.constant(39124, tf.dtypes.int64), + 'class_names': tf.constant(self._classes), + } + + # setup + fake_model = FakeModel() + batch_size = 2 + duration_seconds = 10 + sample_rate = fake_model.metadata()['input_sample_rate'] + waveform = np.random.randn( + batch_size, + sample_rate * duration_seconds, + ) + expected_frames = int(10 / 3.9124) + 1 + # Call the model to avoid "forward pass of the model is not defined" on + # save. + spectrograms = fake_model.front_end(waveform[:, :, np.newaxis]) + fake_model(spectrograms[:, :128, :]) + model_path = os.path.join(tempfile.gettempdir(), 'whale_model') + fake_model.save( + model_path, + signatures={ + 'score': fake_model.score, + 'metadata': fake_model.metadata, + 'serving_default': fake_model.score, + 'front_end': fake_model.front_end, + 'features': fake_model.features, + 'logits': fake_model.logits, + }, + ) + + with self.subTest('from_url'): + # invoke + model = models.GoogleWhaleModel.load_humpback_model(model_path) + outputs = model.batch_embed(waveform) + + # verify + self.assertTrue(outputs.batched) + self.assertSequenceEqual( + outputs.embeddings.shape, [batch_size, expected_frames, 1, 32] + ) + self.assertSequenceEqual( + outputs.logits['humpback'].shape, [batch_size, expected_frames, 1] + ) + + with self.subTest('from_config'): + # invoke + config = config_dict.ConfigDict() + config.model_url = model_path + config.sample_rate = float(sample_rate) + config.window_size_s = 3.9124 + config.peak_norm = 0.02 + model = models.GoogleWhaleModel.from_config(config) + # Let's check the regular embed this time. + outputs = model.embed(waveform[0]) + + # verify + self.assertFalse(outputs.batched) + self.assertSequenceEqual( + outputs.embeddings.shape, [expected_frames, 1, 32] + ) + self.assertSequenceEqual( + outputs.logits['multispecies_whale'].shape, [expected_frames, 1] + ) + + +if __name__ == '__main__': + absltest.main()