diff --git a/chirp/projects/multicluster/embed_public.ipynb b/chirp/inference/active_learning.ipynb similarity index 69% rename from chirp/projects/multicluster/embed_public.ipynb rename to chirp/inference/active_learning.ipynb index a71a127a..bc45696c 100644 --- a/chirp/projects/multicluster/embed_public.ipynb +++ b/chirp/inference/active_learning.ipynb @@ -33,26 +33,21 @@ "\n", "# Global imports\n", "import collections\n", - "import os\n", + "import json\n", + "from ml_collections import config_dict\n", "import numpy as np\n", "import tensorflow as tf\n", "from etils import epath\n", "import matplotlib.pyplot as plt\n", - "import tqdm\n", "\n", - "use_tf_gpu = True #@param\n", - "if not use_tf_gpu:\n", - " tf.config.experimental.set_visible_devices([], \"GPU\")\n", + "from chirp.inference import colab_utils\n", + "colab_utils.initialize(use_tf_gpu=True, disable_warnings=True)\n", "\n", "# Chirp imports\n", - "from chirp import audio_utils\n", - "from chirp import path_utils\n", - "from chirp.preprocessing import pipeline\n", - "from chirp.models import frontend\n", "from chirp.models import metrics\n", - "from chirp.inference import models\n", "from chirp.inference import tf_examples\n", "from chirp.projects.bootstrap import bootstrap\n", + "from chirp.projects.bootstrap import display\n", "from chirp.projects.bootstrap import search\n", "from chirp.projects.multicluster import classify\n", "from chirp.projects.multicluster import data_lib\n" @@ -68,32 +63,16 @@ "source": [ "#@title Configure data and model locations. { vertical-output: true }\n", "\n", - "# Path to TFRecords of unlabeled embeddings.\n", - "unlabeled_embeddings_path = '' #@param\n", - "embeddings_glob = epath.Path(unlabeled_embeddings_path) / '*'\n", - "\n", - "# Hop-size used when creating the embeddings dataset.\n", - "embedding_hop_size_s = 5.0 #@param\n", - "# Number of folder name levels in embedding file id's.\n", - "file_id_depth = 1 #@param\n", - "\n", - "# Globs for source audio files represented in the unlabeled embeddings.\n", - "# e.g., /data/project_audio/*/*.wav\n", - "audio_globs = [] #@param\n", + "# Path containing TFRecords of unlabeled embeddings.\n", + "# We will load the model which was used to compute the embeddings automatically.\n", + "embeddings_path = '/tmp/embeddings' #@param\n", "\n", "# Path to the labeled wav data.\n", "# Should be in 'folder-of-folders' format - a folder with sub-folders for\n", "# each class of interest.\n", "# Audio in sub-folders should be wav files.\n", "# Audio should ideally be 5s audio clips, but the system is quite forgiving.\n", - "labeled_data_path = '' #@param\n", - "\n", - "model_choice = 'perch' #@param['perch', 'birdnet']\n", - "# Path to the folder contianing the perch model, which you can get at:\n", - "# https://tfhub.dev/google/bird-vocalization-classifier\n", - "perch_path = '' #@param\n", - "# Path to a local copy of a BirdNet TFLite file.\n", - "birdnet_path = '' #@param\n" + "labeled_data_path = '' #@param\n" ] }, { @@ -106,40 +85,16 @@ "source": [ "#@title Load the model. 
{ vertical-output: true }\n", "\n", - "# Create the config and load the model given the provided information.\n", - "if model_choice == 'perch':\n", - " model_key='taxonomy_model_tf'\n", - " model_config = {\n", - " 'model_path': perch_path,\n", - " 'window_size_s': 5.0,\n", - " 'hop_size_s': embedding_hop_size_s,\n", - " 'sample_rate': 32000\n", - " }\n", - "elif model_choice == 'birdnet':\n", - " model_key='birdnet'\n", - " model_config = {\n", - " 'window_size_s': 3.0,\n", - " 'hop_size_s': embedding_hop_size_s,\n", - " 'sample_rate': 48000,\n", - " }\n", - "else:\n", - " raise ValueError(f'unknown model choice {model_choice=}')\n", + "# Get relevant info from the embedding configuration.\n", + "embeddings_path = epath.Path(embeddings_path)\n", + "with (embeddings_path / 'config.json').open() as f:\n", + " embedding_config = config_dict.ConfigDict(json.loads(f.read()))\n", + "embeddings_glob = embeddings_path / 'embeddings-*'\n", + "embedding_hop_size_s = embedding_config.embed_fn_config.model_config.hop_size_s\n", "\n", - "config = bootstrap.BootstrapConfig(\n", - " # Path to pre-generated embeddings TFRecord files.\n", - " embeddings_glob=embeddings_glob,\n", - " embedding_hop_size_s=embedding_hop_size_s,\n", - " file_id_depth=file_id_depth,\n", - " # Globs for audio files represented in the embeddings.\n", - " audio_globs=audio_globs,\n", - "\n", - " # Path for storing annotated examples.\n", - " annotated_path=labeled_data_path,\n", - "\n", - " # Embedding model info.\n", - " # Needs to match the model used for the embeddings DB, of course...\n", - " model_key=model_key,\n", - " model_config=model_config)\n", + "config = bootstrap.BootstrapConfig.load_from_embedding_config(\n", + " embeddings_path=embeddings_path,\n", + " annotated_path=labeled_data_path)\n", "project_state = bootstrap.BootstrapState(config)\n", "embedding_model = project_state.embedding_model" ] @@ -164,7 +119,7 @@ "#@title Load+Embed the Labeled Dataset. { vertical-output: true }\n", "\n", "# Time-pooling strategy for examples longer than the model's window size.\n", - "time_pooling = 'mean' #@param\n", + "time_pooling = 'mean' #@param\n", "\n", "merged = data_lib.MergedDataset(config.annotated_path,\n", " embedding_model,\n", @@ -188,10 +143,10 @@ "#@title Train linear model over embeddings. { vertical-output: true }\n", "\n", "# Number of random training examples to choose form each class.\n", - "example_per_class = 128 #@param\n", + "example_per_class = 128 #@param\n", "\n", "# Number of random re-trainings. Allows judging model stability.\n", - "num_seeds = 1 #@param\n", + "num_seeds = 1 #@param\n", "\n", "# Classifier training hyperparams.\n", "# These should be good defaults.\n", @@ -248,24 +203,26 @@ "#@title Run model on target unlabeled data. 
{ vertical-output: true }\n", "\n", "# Choose the target class to work with.\n", - "target_class = '' #@param\n", + "target_class = '' #@param\n", "# Choose a target logit; will display results close to the target.\n", - "target_logit = 2.0 #@param\n", + "target_logit = 2.0 #@param\n", "# Number of results to display.\n", - "num_results = 25 #@param\n", + "num_results = 25 #@param\n", "\n", "# Create the embeddings dataset.\n", - "embeddings_ds = tf_examples.create_embeddings_dataset(unlabeled_embeddings_path)\n", + "embeddings_ds = tf_examples.create_embeddings_dataset(\n", + " embeddings_path, file_glob='embeddings-*')\n", "target_class_idx = merged.labels.index(target_class)\n", "results, all_logits = search.classifer_search_embeddings_parallel(\n", - " embeddings_ds, model, target_class_idx, hop_size_s=5.0,\n", + " embeddings_ds, model, target_class_idx, hop_size_s=embedding_hop_size_s,\n", " target_logit=target_logit, top_k=num_results\n", ")\n", "\n", "# Plot the histogram of logits.\n", "_, ys, _ = plt.hist(all_logits, bins=128, density=True)\n", "plt.xlabel(f'{target_class} logit')\n", - "plt.ylabel(f'density')\n", + "plt.ylabel('density')\n", + "# plt.yscale('log')\n", "plt.plot([target_logit, target_logit], [0.0, np.max(ys)], 'r:')\n", "plt.show()\n" ] @@ -282,7 +239,7 @@ "\n", "display_labels = merged.labels\n", "\n", - "extra_labels = [] #@param\n", + "extra_labels = [] #@param\n", "for label in extra_labels:\n", " if label not in merged.labels:\n", " display_labels += (label,)\n", @@ -309,6 +266,58 @@ "results.write_labeled_data(\n", " config.annotated_path, embedding_model.sample_rate)" ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "WkpWsvQ9DGYl" + }, + "outputs": [], + "source": [ + "#@title Write classifier inference CSV. 
{ vertical-output: true }\n", + "\n", + "threshold = 1.0 #@param\n", + "output_filepath = '/tmp/inference.csv' #@param\n", + "\n", + "# Create the embeddings dataset.\n", + "embeddings_ds = tf_examples.create_embeddings_dataset(\n", + " embeddings_path, file_glob='embeddings-*')\n", + "\n", + "def classify_batch(batch):\n", + " \"\"\"Classify a batch of embeddings.\"\"\"\n", + " emb = batch[tf_examples.EMBEDDING]\n", + " emb_shape = tf.shape(emb)\n", + " flat_emb = tf.reshape(emb, [-1, emb_shape[-1]])\n", + " logits = model(flat_emb)\n", + " logits = tf.reshape(\n", + " logits, [emb_shape[0], emb_shape[1], tf.shape(logits)[-1]])\n", + " # Take the maximum logit over channels.\n", + " logits = tf.reduce_max(logits, axis=-2)\n", + " batch['logits'] = logits\n", + " return batch\n", + "\n", + "inference_ds = tf_examples.create_embeddings_dataset(\n", + " embeddings_path, file_glob='embeddings-*')\n", + "inference_ds = inference_ds.map(\n", + " classify_batch, num_parallel_calls=tf.data.AUTOTUNE\n", + ")\n", + "\n", + "with open(output_filepath, 'w') as f:\n", + " # Write column headers.\n", + " headers = ['filename', 'timestamp_s', 'label', 'logit']\n", + " f.write(', '.join(headers) + '\\n')\n", + " for ex in inference_ds.as_numpy_iterator():\n", + " for t in range(ex['logits'].shape[0]):\n", + " for i, label in enumerate(merged.class_names):\n", + " if ex['logits'][t, i] \u003e threshold:\n", + " offset = ex['timestamp_s'] + t * config.embedding_hop_size_s\n", + " logit = '{:.2f}'.format(ex['logits'][t, i])\n", + " row = [ex['filename'].decode('utf-8'),\n", + " '{:.2f}'.format(offset),\n", + " label, logit]\n", + " f.write(', '.join(row) + '\\n')\n" + ] } ], "metadata": { @@ -317,7 +326,7 @@ "build_target": "", "kind": "local" }, - "name": "embed_notebook.ipynb", + "name": "active_learning.ipynb", "private_outputs": true, "provenance": [ { diff --git a/chirp/inference/colab_utils.py b/chirp/inference/colab_utils.py new file mode 100644 index 00000000..2e71d5b6 --- /dev/null +++ b/chirp/inference/colab_utils.py @@ -0,0 +1,45 @@ +# coding=utf-8 +# Copyright 2023 The Chirp Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Helper functions for user-facing colab notebooks.""" + +import warnings + +from absl import logging +from chirp import config_utils +from chirp.configs import config_globals +from chirp.inference import embed_lib +import tensorflow as tf + + +def initialize(use_tf_gpu: bool = True, disable_warnings: bool = True): + """Apply notebook conveniences. + + Args: + use_tf_gpu: If True, allows GPU use and sets Tensorflow to 'memory growth' + mode (instead of reserving all available GPU memory at once). If False, + Tensorflow is restricted to CPU operation. Must run before any TF + computations to be effective. + disable_warnings: If True, disables printed warnings from library code. 
+ """ + if disable_warnings: + logging.set_verbosity(logging.ERROR) + warnings.filterwarnings('ignore') + + if not use_tf_gpu: + tf.config.experimental.set_visible_devices([], 'GPU') + else: + for gpu in tf.config.list_physical_devices('GPU'): + tf.config.experimental.set_memory_growth(gpu, True) diff --git a/chirp/inference/embed_audio.ipynb b/chirp/inference/embed_audio.ipynb index 0f760abc..5330a1c6 100644 --- a/chirp/inference/embed_audio.ipynb +++ b/chirp/inference/embed_audio.ipynb @@ -25,18 +25,14 @@ "#@title Imports. { vertical-output: true }\n", "\n", "# Global imports\n", - "import collections\n", - "import os\n", + "from etils import epath\n", "import numpy as np\n", "import tensorflow as tf\n", - "from etils import epath\n", - "import matplotlib.pyplot as plt\n", "import tqdm\n", + "from chirp.inference import colab_utils\n", + "colab_utils.initialize(use_tf_gpu=True, disable_warnings=True)\n", "\n", - "use_tf_gpu = True #@param\n", - "if not use_tf_gpu:\n", - " tf.config.experimental.set_visible_devices([], \"GPU\")\n", - "\n", + "from chirp import audio_utils\n", "from chirp import config_utils\n", "from chirp.configs import config_globals\n", "from chirp.inference import embed_lib\n", @@ -54,29 +50,51 @@ "#@title Configuration. { vertical-output: true }\n", "\n", "# Name of base configuration file in `chirp/inference/configs`\n", - "config_key = 'raw_soundscapes' #@param\n", + "config_key = 'raw_soundscapes' #@param\n", "config = embed_lib.get_config(config_key)\n", "config = config_utils.parse_config(config, config_globals.get_globals())\n", "\n", "# Here we adjust the input and output targets.\n", - "config.output_dir = '/tmp/testrun' #@param\n", - "config.source_file_patterns = [''] #@param\n", + "config.output_dir = '' #@param\n", + "config.source_file_patterns = [''] #@param\n", "\n", - "# Location of Perch model\n", - "model_path = '' #@param\n", + "# Define the model\n", + "model_choice = 'perch'\n", + "# For Perch, the directory containing the model.\n", + "# For BirdNET, point to the specific tflite file.\n", + "model_path = '' #@param\n", "config.embed_fn_config.model_config.model_path = model_path\n", + "if model_choice == 'perch':\n", + " config.embed_fn_config.model_config.window_size_s = 5.0\n", + " config.embed_fn_config.model_config.hop_size_s = 5.0\n", + " config.embed_fn_config.model_config.sample_rate = 32000\n", + "elif model_choice == 'birdnet':\n", + " config.embed_fn_config.model_config.window_size_s = 3.0\n", + " config.embed_fn_config.model_config.hop_size_s = 3.0\n", + " config.embed_fn_config.model_config.sample_rate = 16000\n", + "\n", + "# Only write embeddings to reduce size.\n", + "config.embed_fn_config.write_embeddings = True\n", + "config.embed_fn_config.write_logits = False\n", + "config.embed_fn_config.write_separated_audio = False\n", + "config.embed_fn_config.write_raw_audio = False\n", + "\n", "\n", "# Embedding windows are broken up into groups, typically one minute in length.\n", "# This lets us limit input size to the model, track progres and\n", "# recover from failures more easily.\n", - "config.shard_len_s = 60 #@param\n", - "config.num_shards_per_file = 10 #@param\n", + "config.shard_len_s = 60 #@param\n", + "config.num_shards_per_file = 1 #@param\n", "\n", "# Number of parent directories to include in the filename.\n", "config.embed_fn_config.file_id_depth = 1\n", "\n", "# Number of TF Record files to create.\n", - "tf_record_shards = 10 #@param" + "config.tf_record_shards = 10 #@param\n", + "\n", + "# Speech filter threshold 
for YamNet.\n", + "# Set to a value between 0 and 1, or -1 to disable.\n", + "config.embed_fn_config.speech_filter_threshold = -1.0\n" ] }, { @@ -104,7 +122,7 @@ "# Set up the embedding function, including loading models.\n", "embed_fn = embed_lib.EmbedFn(**config.embed_fn_config)\n", "print('\\n\\nLoading model(s)...')\n", - "%time embed_fn.setup()\n", + "embed_fn.setup()\n", "\n", "print('\\n\\nTest-run of model...')\n", "# We run the test twice - the first run optimizes the execution, and\n", @@ -112,10 +130,8 @@ "window_size_s = config.embed_fn_config.model_config.window_size_s\n", "sr = config.embed_fn_config.model_config.sample_rate\n", "z = np.zeros([int(sr * window_size_s)])\n", - "print(' Cold-start timing:')\n", - "%time unused = embed_fn.embedding_model.embed(z)\n", - "print(' Typical run timing:')\n", - "%time unused = embed_fn.embedding_model.embed(z)\n" + "embed_fn.embedding_model.embed(z)\n", + "print('Setup complete!')" ] }, { @@ -126,13 +142,16 @@ }, "outputs": [], "source": [ - "#@title Run embedding. { vertical-output: true }\n", + "#@title Run embedding. (safe) { vertical-output: true }\n", + "\n", + "# Loads audio files one-by-one using methods which will tend not to fail\n", + "# if the target files have minor problems (eg, wrong length metadata).\n", "\n", "embed_fn.min_audio_s = 1.0\n", "record_file = (output_dir / 'embeddings.tfrecord').as_posix()\n", "succ, fail = 0, 0\n", - "with EmbeddingsTFRecordMultiWriter(\n", - " output_dir=output_dir, num_files=tf_record_shards) as file_writer:\n", + "with tf_examples.EmbeddingsTFRecordMultiWriter(\n", + " output_dir=output_dir, num_files=config.tf_record_shards) as file_writer:\n", " for source_info in tqdm.tqdm(source_infos):\n", " examples = embed_fn.process(source_info=source_info)\n", " if examples is None:\n", @@ -153,6 +172,63 @@ " print(ex['embedding'].shape)\n", " break" ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "5JbRna2tnGj5" + }, + "outputs": [], + "source": [ + "#@title Run embedding. 
(fast) { vertical-output: true }\n", + "\n", + "# Uses multiple threads to load audio before embedding.\n", + "# This tends to be faster, but can fail if any audio files are corrupt.\n", + "\n", + "embed_fn.min_audio_s = 1.0\n", + "record_file = (output_dir / 'embeddings.tfrecord').as_posix()\n", + "succ, fail = 0, 0\n", + "\n", + "audio_iterator = audio_utils.multi_load_audio_window(\n", + " filepaths=[s.filepath for s in source_infos],\n", + " offsets=[s.shard_num * s.shard_len_s for s in source_infos],\n", + " sample_rate=config.embed_fn_config.model_config.sample_rate,\n", + " window_size_s=config.shard_len_s,\n", + ")\n", + "with tf_examples.EmbeddingsTFRecordMultiWriter(\n", + " output_dir=output_dir, num_files=config.tf_record_shards) as file_writer:\n", + " for source_info, audio in tqdm.tqdm(\n", + " zip(source_infos, audio_iterator), total=len(source_infos)):\n", + " file_id = source_info.file_id(config.embed_fn_config.file_id_depth)\n", + " offset_s = source_info.shard_num * source_info.shard_len_s\n", + " example = embed_fn.audio_to_example(file_id, offset_s, audio)\n", + " if example is None:\n", + " fail += 1\n", + " continue\n", + " file_writer.write(example.SerializeToString())\n", + " succ += 1\n", + " file_writer.flush()\n", + "print(f'\\n\\nSuccessfully processed {succ} source_infos, failed {fail} times.')\n", + "\n", + "fns = [fn for fn in output_dir.glob('embeddings-*')]\n", + "ds = tf.data.TFRecordDataset(fns)\n", + "parser = tf_examples.get_example_parser()\n", + "ds = ds.map(parser)\n", + "for ex in ds.as_numpy_iterator():\n", + " print(ex['filename'])\n", + " print(ex['embedding'].shape)\n", + " break" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "8Yi4nL7JtNvI" + }, + "outputs": [], + "source": [] } ], "metadata": { diff --git a/chirp/inference/embed_lib.py b/chirp/inference/embed_lib.py index 76f0071d..d16abe7b 100644 --- a/chirp/inference/embed_lib.py +++ b/chirp/inference/embed_lib.py @@ -16,6 +16,7 @@ """Create embeddings for an audio corpus.""" import dataclasses +import json import os from typing import Any, Sequence @@ -47,6 +48,12 @@ class SourceInfo: shard_num: int shard_len_s: float + def file_id(self, file_id_depth: int) -> str: + file_id = epath.Path( + *epath.Path(self.filepath).parts[-(file_id_depth + 1) :] + ).as_posix() + return file_id + def create_source_infos( source_file_patterns: Sequence[str], @@ -190,6 +197,25 @@ def _log_exception(self, source_info, exception, counter_name): exception, ) + def audio_to_example( + self, file_id: str, timestamp_offset_s: float, audio: np.ndarray + ) -> tf.train.Example: + """Embed audio and create a TFExample.""" + if self.embedding_model is None: + raise ValueError('Embedding model undefined.') + model_outputs = self.embedding_model.embed(audio) + example = tf_examples.model_outputs_to_tf_example( + model_outputs=model_outputs, + file_id=file_id, + audio=audio, + timestamp_offset_s=timestamp_offset_s, + write_raw_audio=self.write_raw_audio, + write_separated_audio=self.write_separated_audio, + write_embeddings=self.write_embeddings, + write_logits=self.write_logits, + ) + return example + @beam.typehints.with_output_types(Any) def process(self, source_info: SourceInfo, crop_s: float = -1.0): """Process a source. @@ -202,9 +228,7 @@ def process(self, source_info: SourceInfo, crop_s: float = -1.0): Returns: A TFExample. 
""" - file_id = epath.Path( - *epath.Path(source_info.filepath).parts[-(self.file_id_depth + 1) :] - ).as_posix() + file_id = source_info.file_id(self.file_id_depth) logging.info('...loading audio (%s)', source_info.filepath) timestamp_offset_s = source_info.shard_num * source_info.shard_len_s @@ -255,17 +279,7 @@ def process(self, source_info: SourceInfo, crop_s: float = -1.0): logging.info( '...creating embeddings (%s / %d)', file_id, timestamp_offset_s ) - model_outputs = self.embedding_model.embed(audio) - example = tf_examples.model_outputs_to_tf_example( - model_outputs=model_outputs, - file_id=file_id, - audio=audio, - timestamp_offset_s=timestamp_offset_s, - write_raw_audio=self.write_raw_audio, - write_separated_audio=self.write_separated_audio, - write_embeddings=self.write_embeddings, - write_logits=self.write_logits, - ) + example = self.audio_to_example(file_id, timestamp_offset_s, audio) beam.metrics.Metrics.counter('beaminference', 'examples_processed').inc() return [example] @@ -299,6 +313,14 @@ def maybe_write_config(parsed_config, output_dir): f.write(config_json) +def load_embedding_config(embeddings_path): + """Loads the configuration to generate unlabeled embeddings.""" + embeddings_path = epath.Path(embeddings_path) + with (embeddings_path / 'config.json').open() as f: + embedding_config = config_dict.ConfigDict(json.loads(f.read())) + return embedding_config + + def build_run_pipeline(base_pipeline, output_dir, source_infos, embed_fn): """Create and run a beam pipeline.""" _ = ( diff --git a/chirp/inference/models.py b/chirp/inference/models.py index f36fd300..60bd8d62 100644 --- a/chirp/inference/models.py +++ b/chirp/inference/models.py @@ -38,9 +38,11 @@ def model_class_map() -> dict[str, Any]: return { 'taxonomy_model_tf': TaxonomyModelTF, 'separator_model_tf': SeparatorModelTF, + 'birb_separator_model_tf1': BirbSepModelTF1, 'birdnet': BirdNet, 'placeholder_model': PlaceholderModel, 'separate_embed_model': SeparateEmbedModel, + 'tfhub_model': TFHubModel, } diff --git a/chirp/inference/search_embeddings.ipynb b/chirp/inference/search_embeddings.ipynb new file mode 100644 index 00000000..e05fe795 --- /dev/null +++ b/chirp/inference/search_embeddings.ipynb @@ -0,0 +1,283 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "TbsWuOMtug8-" + }, + "outputs": [], + "source": [ + "#@title Imports. { vertical-output: true }\n", + "\n", + "# Disable annoying warnings.\n", + "\n", + "# Global imports\n", + "import json\n", + "from ml_collections import config_dict\n", + "import numpy as np\n", + "from etils import epath\n", + "import matplotlib.pyplot as plt\n", + "\n", + "from chirp.inference import colab_utils\n", + "colab_utils.initialize(use_tf_gpu=True, disable_warnings=True)\n", + "\n", + "from chirp import audio_utils\n", + "from chirp.inference import models\n", + "from chirp.projects.bootstrap import bootstrap\n", + "from chirp.projects.bootstrap import search\n", + "from chirp.projects.bootstrap import display\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "ZAqbraNjuxYr" + }, + "outputs": [], + "source": [ + "#@title Configuration and Setup. 
{ vertical-output: true }\n", + "\n", + "# Path to embeddings of unlabeled data.\n", + "embeddings_path = '/tmp/embeddings' #@param\n", + "\n", + "# Path for storing annotated examples.\n", + "labeled_data_path = '/tmp/labeled_data' #@param\n", + "\n", + "separation_model_key = '' #@param\n", + "separation_model_path = '' #@param\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "uyvwPYCY3sb5" + }, + "outputs": [], + "source": [ + "#@title Setup. { vertical-output: true }\n", + "\n", + "# Get relevant info from the embedding configuration.\n", + "embeddings_path = epath.Path(embeddings_path)\n", + "with (embeddings_path / 'config.json').open() as f:\n", + " embedding_config = config_dict.ConfigDict(json.loads(f.read()))\n", + "embeddings_glob = embeddings_path / 'embeddings-*'\n", + "\n", + "# Extract the embedding model config from the embedding_config.\n", + "if embedding_config.embed_fn_config.model_key == 'separate_embed_model':\n", + " # If a separation model was applied, get the embedding model config only.\n", + " model_key = 'taxonomy_model_tf_config'\n", + " model_config = embedding_config.embed_fn_config.model_config.taxonomy_model_tf_config\n", + "else:\n", + " model_key = embedding_config.embed_fn_config.model_key\n", + " model_config = embedding_config.embed_fn_config.model_config\n", + "\n", + "config = bootstrap.BootstrapConfig.load_from_embedding_config(\n", + " embeddings_path=embeddings_path,\n", + " annotated_path=labeled_data_path)\n", + "project_state = bootstrap.BootstrapState(config)\n", + "\n", + "# Load separation model.\n", + "if separation_model_key == 'separator_model_tf':\n", + " separator = models.model_class_map()[separation_model_key](\n", + " sample_rate=32000,\n", + " model_path=separation_model_path,\n", + " frame_size=32000,\n", + " )\n", + "else:\n", + " print('No separation model loaded.')\n", + " separator = None" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "NM_i9Hur2mo8" + }, + "source": [ + "## Query Creation" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "C7hW22722pCW" + }, + "outputs": [], + "source": [ + "#@title Load query audio. 
{ vertical-output: true }\n", + "\n", + "# Point to an audio file of your choice.\n", + "audio_path = 'gs://chirp-public-bucket/notela-blog-post/yetvir-soundscape.mp3' #@param\n", + "# Muck around with manual selection of the query start time...\n", + "start_s = 1 #@param\n", + "\n", + "window_s = config.model_config['window_size_s']\n", + "sample_rate = config.model_config['sample_rate']\n", + "audio = audio_utils.load_audio(audio_path, sample_rate)\n", + "\n", + "# Display the full file.\n", + "display.plot_audio_melspec(audio, sample_rate)\n", + "\n", + "# Display the selected window.\n", + "print('-' * 80)\n", + "print('Selected audio window:')\n", + "# TODO(tomdenton): Pad or shift if too close to the end of the file.\n", + "st = int(start_s * sample_rate)\n", + "end = int(st + window_s * sample_rate)\n", + "audio_window = audio[st:end]\n", + "display.plot_audio_melspec(audio_window, sample_rate)\n", + "\n", + "query_audio = audio_window\n", + "sep_outputs = None" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "aPVa1xFRvy1X" + }, + "outputs": [], + "source": [ + "#@title Separate the target audio window { vertical-output: true }\n", + "\n", + "if separator is not None:\n", + " sep_outputs = separator.embed(audio_window)\n", + "\n", + " for c in range(sep_outputs.separated_audio.shape[0]):\n", + " print(f'Channel {c}')\n", + " display.plot_audio_melspec(sep_outputs.separated_audio[c, :], sample_rate)\n", + " print('-' * 80)\n", + "else:\n", + " sep_outputs = None\n", + " print('No separation model loaded.')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "Hr42XJFnw_w6" + }, + "outputs": [], + "source": [ + "#@title Select the query channel. { vertical-output: true }\n", + "\n", + "query_label = 'my_label' #@param\n", + "query_channel = -1 #@param\n", + "\n", + "if query_channel \u003c 0 or sep_outputs is None:\n", + " query_audio = audio_window\n", + "else:\n", + " query_audio = sep_outputs.separated_audio[query_channel].copy()\n", + "\n", + "display.plot_audio_melspec(query_audio, sample_rate)\n", + "\n", + "outputs = project_state.embedding_model.embed(query_audio)\n", + "query = outputs.pooled_embeddings('first', 'first')\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "KII-L3uoyKT8" + }, + "outputs": [], + "source": [ + "#@title Run Top-K Search. 
{ vertical-output: true }\n", + "\n", + "# Target distance for search results.\n", + "# This lets us try to hone in on a 'classifier boundary' instead of just\n", + "# looking at the closest matches.\n", + "target_dist = 0 #@param\n", + "\n", + "# Number of search results to capture.\n", + "top_k = 10 #@param\n", + "\n", + "ds = project_state.create_embeddings_dataset()\n", + "results, all_distances = search.search_embeddings_parallel(\n", + " ds, query[np.newaxis, np.newaxis, :], hop_size_s=model_config.hop_size_s,\n", + " top_k=top_k, target_dist=target_dist)\n", + "\n", + "# Plot histogram of distances\n", + "_, ys, _ = plt.hist(all_distances, bins=128, density=True)\n", + "hit_distances = [r.distance for r in results.search_results]\n", + "plt.scatter(hit_distances, np.zeros_like(hit_distances), marker='|',\n", + " color='r', alpha=0.5)\n", + "\n", + "plt.xlabel('distance')\n", + "plt.ylabel('density')\n", + "if target_dist \u003e 0:\n", + " plt.plot([target_dist, target_dist], [0.0, np.max(ys)], 'r:')\n", + "min_dist = np.min(all_distances)\n", + "plt.plot([min_dist, min_dist], [0.0, np.max(ys)], 'g:')\n", + "\n", + "plt.show()\n", + "\n", + "# Compute the proportion of files with min_dist \u003c target_dist\n", + "hit_percentage = (np.sum(\n", + " [d \u003c target_dist for d in all_distances]) / all_distances.shape[0])\n", + "print(f'file min_dist\u003ctarget percentage : {hit_percentage:5.3f}')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "x8bsc-wLooAw" + }, + "outputs": [], + "source": [ + "#@title Display results. { vertical-output: true }\n", + "\n", + "display.display_search_results(\n", + " results, sample_rate, project_state.source_map,\n", + " checkbox_labels=[query_label, 'unknown'],\n", + " max_workers=5)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "ubzz56yqosXZ" + }, + "outputs": [], + "source": [ + "#@title Write annotated examples. { vertical-output: true }\n", + "\n", + "results.write_labeled_data(config.annotated_path,\n", + " project_state.embedding_model.sample_rate)" + ] + } + ], + "metadata": { + "colab": { + "last_runtime": { + "build_target": "", + "kind": "local" + }, + "private_outputs": true, + "provenance": [ + { + "file_id": "1HQNRQL-pQu-9kuZzKI9Fkiy6R7jYGKir", + "timestamp": 1689436763547 + } + ] + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/chirp/projects/bootstrap/bootstrap.py b/chirp/projects/bootstrap/bootstrap.py index 8dac7657..e70d888d 100644 --- a/chirp/projects/bootstrap/bootstrap.py +++ b/chirp/projects/bootstrap/bootstrap.py @@ -48,16 +48,9 @@ def create_embeddings_dataset(self): """Create a TF Dataset of the embeddings.""" if self.embeddings_dataset: return self.embeddings_dataset - if '*' not in self.config.embeddings_glob: - ds = tf_examples.create_embeddings_dataset(self.config.embeddings_glob) - else: - # find the first segment with a *. 
- dirs = self.config.embeddings_glob.split('/') - has_wildcard = ['*' in d for d in dirs] - first_wildcard = has_wildcard.index(True) - dirs = '/'.join(dirs[:first_wildcard]) - glob = '/'.join(dirs[first_wildcard:]) - ds = tf_examples.create_embeddings_dataset(dirs, glob) + ds = tf_examples.create_embeddings_dataset( + self.config.embeddings_path, 'embeddings-*' + ) self.embeddings_dataset = ds return ds @@ -86,15 +79,40 @@ class BootstrapConfig: """Configuration for Search Bootstrap project.""" # Embeddings dataset info. - embeddings_glob: str - embedding_hop_size_s: float - file_id_depth: int - audio_globs: Sequence[str] | None + embeddings_path: str # Annotations info. - # TODO(tomdenton): Write handling for the annotated data. annotated_path: str - # Model info. Should match the model used for creating embeddings. - model_key: str - model_config: config_dict.ConfigDict + # The following are populated automatically from the embedding config. + embedding_hop_size_s: float | None = None + file_id_depth: int | None = None + audio_globs: Sequence[str] | None = None + model_key: str | None = None + model_config: config_dict.ConfigDict | None = None + + @classmethod + def load_from_embedding_config( + cls, embeddings_path: str, annotated_path: str + ): + """Instantiate from a configuration written alongside embeddings.""" + embedding_config = embed_lib.load_embedding_config(embeddings_path) + embed_fn_config = embedding_config.embed_fn_config + + # Extract the embedding model config from the embedding_config. + if embed_fn_config.model_key == 'separate_embed_model': + # If a separation model was applied, get the embedding model config only. + model_key = 'taxonomy_model_tf_config' + model_config = embed_fn_config.model_config.taxonomy_model_tf_config + else: + model_key = embed_fn_config.model_key + model_config = embed_fn_config.model_config + return BootstrapConfig( + embeddings_path=embeddings_path, + annotated_path=annotated_path, + model_key=model_key, + model_config=model_config, + embedding_hop_size_s=model_config.hop_size_s, + file_id_depth=embed_fn_config.file_id_depth, + audio_globs=embedding_config.source_file_patterns, + ) diff --git a/chirp/projects/bootstrap/display.py b/chirp/projects/bootstrap/display.py index 13399fb9..47f561d2 100644 --- a/chirp/projects/bootstrap/display.py +++ b/chirp/projects/bootstrap/display.py @@ -107,8 +107,23 @@ def display_search_results( print(f'offset: {offset_s:6.2f}') print(f'distance: {(r.distance + results.distance_offset):6.2f}') label_widgets = [] + + def button_callback(x): + x.value = not x.value + if x.value: + x.button_style = 'success' + else: + x.button_style = '' + for lbl in checkbox_labels: - check = ipywidgets.Checkbox(description=lbl, value=False) + check = ipywidgets.Button( + description=lbl, + disabled=False, + button_style='', + ) + check.value = False + check.on_click(button_callback) + label_widgets.append(check) ipy_display(check) # Attach audio and widgets to the SearchResult. 
diff --git a/chirp/projects/bootstrap/search.py b/chirp/projects/bootstrap/search.py index a6d81100..40364e22 100644 --- a/chirp/projects/bootstrap/search.py +++ b/chirp/projects/bootstrap/search.py @@ -15,6 +15,7 @@ """Tools for searching an embeddings dataset.""" +import collections import dataclasses from typing import Any, Callable, List, Sequence @@ -92,6 +93,7 @@ def sort(self): def write_labeled_data(self, labeled_data_path: str, sample_rate: int): """Write labeled results to the labeled data collection.""" labeled_data_path = epath.Path(labeled_data_path) + counts = collections.defaultdict(int) for r in self.search_results: labels = [ch.description for ch in r.label_widgets if ch.value] if not labels: @@ -104,6 +106,9 @@ def write_labeled_data(self, labeled_data_path: str, sample_rate: int): output_path.mkdir(parents=True, exist_ok=True) output_filepath = output_path / output_filename wavfile.write(output_filepath, sample_rate, r.audio) + counts[label] += 1 + for label, count in counts.items(): + print(f'Wrote {count} examples for label {label}') @dataclasses.dataclass @@ -213,9 +218,12 @@ def classifer_search_embeddings_parallel( def classify_batch(batch): emb = batch[tf_examples.EMBEDDING] - # This seems to 'just work' when the classifier input shape is [None, D] - # and the embeddings shape is [B, C, D]. - logits = embeddings_classifier(emb) + emb_shape = tf.shape(emb) + flat_emb = tf.reshape(emb, [-1, emb_shape[-1]]) + logits = embeddings_classifier(flat_emb) + logits = tf.reshape( + logits, [emb_shape[0], emb_shape[1], tf.shape(logits)[-1]] + ) # Restrict to target class. logits = logits[..., target_index] # Take the maximum logit over channels.
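
The flatten-classify-reshape pattern introduced above in `classify_batch` (and mirrored in the notebook's classifier-inference cell) replaces the removed code that relied on the classifier broadcasting over a [batch, channels, dim] input. A minimal standalone sketch of the same pattern, assuming a Keras `Dense` layer stands in for the trained linear classifier and using made-up shapes:

import tensorflow as tf

# Stand-in for the trained linear classifier; sizes here are illustrative.
num_labels = 4
embedding_dim = 1280
classifier = tf.keras.layers.Dense(num_labels)

# Example embeddings with shape [time_steps, channels, embedding_dim].
emb = tf.random.normal([12, 5, embedding_dim])

# Flatten the time and channel axes so the classifier sees [N, embedding_dim].
emb_shape = tf.shape(emb)
flat_emb = tf.reshape(emb, [-1, emb_shape[-1]])
logits = classifier(flat_emb)

# Restore [time_steps, channels, num_labels], then keep the strongest channel
# for each time step and label.
logits = tf.reshape(logits, [emb_shape[0], emb_shape[1], tf.shape(logits)[-1]])
logits = tf.reduce_max(logits, axis=-2)
print(logits.shape)  # (12, 4)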
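
The inference CSV written by the new notebook cell uses the columns filename, timestamp_s, label, logit, joined with ', '. A sketch of reading it back for downstream analysis, assuming pandas is available (not imported in the notebooks themselves); skipinitialspace handles the ', ' separator used when writing:

import pandas as pd

# Hypothetical path matching the notebook's default output_filepath.
df = pd.read_csv('/tmp/inference.csv', skipinitialspace=True)

# Summarize detections per label above the chosen threshold.
print(df.groupby('label')['logit'].describe())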