Inference improvements.
* Add a search notebook.
* Move all user-facing notebooks under 'inference'.
* Rename the active learning notebook to active_learning.ipynb.
* Automatically populate BootstrapConfig using the embedding run's config.json (see the sketch after this list).
* Add a missing model key to inference.models.model_class_map.
* Make explicit the option to disable speech filtering in the embedding notebook.
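
A minimal sketch of the new config-loading flow, matching the notebook diff below (the paths are illustrative placeholders):

```python
from chirp.projects.bootstrap import bootstrap

# The embedding run writes a config.json next to its TFRecords; the new
# classmethod reads it, so the model key, hop size, etc. no longer need
# to be re-entered by hand.
config = bootstrap.BootstrapConfig.load_from_embedding_config(
    embeddings_path='/tmp/embeddings',    # folder holding embeddings-* TFRecords
    annotated_path='/tmp/labeled_data',   # folder-of-folders of labeled wavs
)
project_state = bootstrap.BootstrapState(config)
embedding_model = project_state.embedding_model
```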

PiperOrigin-RevId: 548262946
sdenton4 authored and copybara-github committed Jul 25, 2023
1 parent 02ad072 commit a05a268
Showing 9 changed files with 612 additions and 134 deletions.
@@ -33,26 +33,21 @@
"\n",
"# Global imports\n",
"import collections\n",
"import os\n",
"import json\n",
"from ml_collections import config_dict\n",
"import numpy as np\n",
"import tensorflow as tf\n",
"from etils import epath\n",
"import matplotlib.pyplot as plt\n",
"import tqdm\n",
"\n",
"use_tf_gpu = True #@param\n",
"if not use_tf_gpu:\n",
" tf.config.experimental.set_visible_devices([], \"GPU\")\n",
"from chirp.inference import colab_utils\n",
"colab_utils.initialize(use_tf_gpu=True, disable_warnings=True)\n",
"\n",
"# Chirp imports\n",
"from chirp import audio_utils\n",
"from chirp import path_utils\n",
"from chirp.preprocessing import pipeline\n",
"from chirp.models import frontend\n",
"from chirp.models import metrics\n",
"from chirp.inference import models\n",
"from chirp.inference import tf_examples\n",
"from chirp.projects.bootstrap import bootstrap\n",
"from chirp.projects.bootstrap import display\n",
"from chirp.projects.bootstrap import search\n",
"from chirp.projects.multicluster import classify\n",
"from chirp.projects.multicluster import data_lib\n"
@@ -68,32 +63,16 @@
"source": [
"#@title Configure data and model locations. { vertical-output: true }\n",
"\n",
"# Path to TFRecords of unlabeled embeddings.\n",
"unlabeled_embeddings_path = '' #@param\n",
"embeddings_glob = epath.Path(unlabeled_embeddings_path) / '*'\n",
"\n",
"# Hop-size used when creating the embeddings dataset.\n",
"embedding_hop_size_s = 5.0 #@param\n",
"# Number of folder name levels in embedding file id's.\n",
"file_id_depth = 1 #@param\n",
"\n",
"# Globs for source audio files represented in the unlabeled embeddings.\n",
"# e.g., /data/project_audio/*/*.wav\n",
"audio_globs = [] #@param\n",
"# Path containing TFRecords of unlabeled embeddings.\n",
"# We will load the model which was used to compute the embeddings automatically.\n",
"embeddings_path = '/tmp/embeddings' #@param\n",
"\n",
"# Path to the labeled wav data.\n",
"# Should be in 'folder-of-folders' format - a folder with sub-folders for\n",
"# each class of interest.\n",
"# Audio in sub-folders should be wav files.\n",
"# Audio should ideally be 5s audio clips, but the system is quite forgiving.\n",
"labeled_data_path = '' #@param\n",
"\n",
"model_choice = 'perch' #@param['perch', 'birdnet']\n",
"# Path to the folder contianing the perch model, which you can get at:\n",
"# https://tfhub.dev/google/bird-vocalization-classifier\n",
"perch_path = '' #@param\n",
"# Path to a local copy of a BirdNet TFLite file.\n",
"birdnet_path = '' #@param\n"
"labeled_data_path = '' #@param\n"
]
},
{
@@ -106,40 +85,16 @@
"source": [
"#@title Load the model. { vertical-output: true }\n",
"\n",
"# Create the config and load the model given the provided information.\n",
"if model_choice == 'perch':\n",
" model_key='taxonomy_model_tf'\n",
" model_config = {\n",
" 'model_path': perch_path,\n",
" 'window_size_s': 5.0,\n",
" 'hop_size_s': embedding_hop_size_s,\n",
" 'sample_rate': 32000\n",
" }\n",
"elif model_choice == 'birdnet':\n",
" model_key='birdnet'\n",
" model_config = {\n",
" 'window_size_s': 3.0,\n",
" 'hop_size_s': embedding_hop_size_s,\n",
" 'sample_rate': 48000,\n",
" }\n",
"else:\n",
" raise ValueError(f'unknown model choice {model_choice=}')\n",
"# Get relevant info from the embedding configuration.\n",
"embeddings_path = epath.Path(embeddings_path)\n",
"with (embeddings_path / 'config.json').open() as f:\n",
" embedding_config = config_dict.ConfigDict(json.loads(f.read()))\n",
"embeddings_glob = embeddings_path / 'embeddings-*'\n",
"embedding_hop_size_s = embedding_config.embed_fn_config.model_config.hop_size_s\n",
"\n",
"config = bootstrap.BootstrapConfig(\n",
" # Path to pre-generated embeddings TFRecord files.\n",
" embeddings_glob=embeddings_glob,\n",
" embedding_hop_size_s=embedding_hop_size_s,\n",
" file_id_depth=file_id_depth,\n",
" # Globs for audio files represented in the embeddings.\n",
" audio_globs=audio_globs,\n",
"\n",
" # Path for storing annotated examples.\n",
" annotated_path=labeled_data_path,\n",
"\n",
" # Embedding model info.\n",
" # Needs to match the model used for the embeddings DB, of course...\n",
" model_key=model_key,\n",
" model_config=model_config)\n",
"config = bootstrap.BootstrapConfig.load_from_embedding_config(\n",
" embeddings_path=embeddings_path,\n",
" annotated_path=labeled_data_path)\n",
"project_state = bootstrap.BootstrapState(config)\n",
"embedding_model = project_state.embedding_model"
]
@@ -164,7 +119,7 @@
"#@title Load+Embed the Labeled Dataset. { vertical-output: true }\n",
"\n",
"# Time-pooling strategy for examples longer than the model's window size.\n",
"time_pooling = 'mean' #@param\n",
"time_pooling = 'mean' #@param\n",
"\n",
"merged = data_lib.MergedDataset(config.annotated_path,\n",
" embedding_model,\n",
@@ -188,10 +143,10 @@
"#@title Train linear model over embeddings. { vertical-output: true }\n",
"\n",
"# Number of random training examples to choose form each class.\n",
"example_per_class = 128 #@param\n",
"example_per_class = 128 #@param\n",
"\n",
"# Number of random re-trainings. Allows judging model stability.\n",
"num_seeds = 1 #@param\n",
"num_seeds = 1 #@param\n",
"\n",
"# Classifier training hyperparams.\n",
"# These should be good defaults.\n",
@@ -248,24 +203,26 @@
"#@title Run model on target unlabeled data. { vertical-output: true }\n",
"\n",
"# Choose the target class to work with.\n",
"target_class = '' #@param\n",
"target_class = '' #@param\n",
"# Choose a target logit; will display results close to the target.\n",
"target_logit = 2.0 #@param\n",
"target_logit = 2.0 #@param\n",
"# Number of results to display.\n",
"num_results = 25 #@param\n",
"num_results = 25 #@param\n",
"\n",
"# Create the embeddings dataset.\n",
"embeddings_ds = tf_examples.create_embeddings_dataset(unlabeled_embeddings_path)\n",
"embeddings_ds = tf_examples.create_embeddings_dataset(\n",
" embeddings_path, file_glob='embeddings-*')\n",
"target_class_idx = merged.labels.index(target_class)\n",
"results, all_logits = search.classifer_search_embeddings_parallel(\n",
" embeddings_ds, model, target_class_idx, hop_size_s=5.0,\n",
" embeddings_ds, model, target_class_idx, hop_size_s=embedding_hop_size_s,\n",
" target_logit=target_logit, top_k=num_results\n",
")\n",
"\n",
"# Plot the histogram of logits.\n",
"_, ys, _ = plt.hist(all_logits, bins=128, density=True)\n",
"plt.xlabel(f'{target_class} logit')\n",
"plt.ylabel(f'density')\n",
"plt.ylabel('density')\n",
"# plt.yscale('log')\n",
"plt.plot([target_logit, target_logit], [0.0, np.max(ys)], 'r:')\n",
"plt.show()\n"
]
@@ -282,7 +239,7 @@
"\n",
"display_labels = merged.labels\n",
"\n",
"extra_labels = [] #@param\n",
"extra_labels = [] #@param\n",
"for label in extra_labels:\n",
" if label not in merged.labels:\n",
" display_labels += (label,)\n",
@@ -309,6 +266,58 @@
"results.write_labeled_data(\n",
" config.annotated_path, embedding_model.sample_rate)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "WkpWsvQ9DGYl"
},
"outputs": [],
"source": [
"#@title Write classifier inference CSV. { vertical-output: true }\n",
"\n",
"threshold = 1.0 #@param\n",
"output_filepath = '/tmp/inference.csv' #@param\n",
"\n",
"# Create the embeddings dataset.\n",
"embeddings_ds = tf_examples.create_embeddings_dataset(\n",
" embeddings_path, file_glob='embeddings-*')\n",
"\n",
"def classify_batch(batch):\n",
" \"\"\"Classify a batch of embeddings.\"\"\"\n",
" emb = batch[tf_examples.EMBEDDING]\n",
" emb_shape = tf.shape(emb)\n",
" flat_emb = tf.reshape(emb, [-1, emb_shape[-1]])\n",
" logits = model(flat_emb)\n",
" logits = tf.reshape(\n",
" logits, [emb_shape[0], emb_shape[1], tf.shape(logits)[-1]])\n",
" # Take the maximum logit over channels.\n",
" logits = tf.reduce_max(logits, axis=-2)\n",
" batch['logits'] = logits\n",
" return batch\n",
"\n",
"inference_ds = tf_examples.create_embeddings_dataset(\n",
" embeddings_path, file_glob='embeddings-*')\n",
"inference_ds = inference_ds.map(\n",
" classify_batch, num_parallel_calls=tf.data.AUTOTUNE\n",
")\n",
"\n",
"with open(output_filepath, 'w') as f:\n",
" # Write column headers.\n",
" headers = ['filename', 'timestamp_s', 'label', 'logit']\n",
" f.write(', '.join(headers) + '\\n')\n",
" for ex in inference_ds.as_numpy_iterator():\n",
" for t in range(ex['logits'].shape[0]):\n",
" for i, label in enumerate(merged.class_names):\n",
" if ex['logits'][t, i] \u003e threshold:\n",
" offset = ex['timestamp_s'] + t * config.embedding_hop_size_s\n",
" logit = '{:.2f}'.format(ex['logits'][t, i])\n",
" row = [ex['filename'].decode('utf-8'),\n",
" '{:.2f}'.format(offset),\n",
" label, logit]\n",
" f.write(', '.join(row) + '\\n')\n"
]
}
],
"metadata": {
@@ -317,7 +326,7 @@
"build_target": "",
"kind": "local"
},
"name": "embed_notebook.ipynb",
"name": "active_learning.ipynb",
"private_outputs": true,
"provenance": [
{
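For reference, the new CSV-writing cell above emits rows in this shape (filenames and values hypothetical; columns are joined with ', ' exactly as in the code):

```
filename, timestamp_s, label, logit
site_01/recording_001.wav, 25.00, amerob, 1.43
site_01/recording_001.wav, 30.00, amerob, 2.07
```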
45 changes: 45 additions & 0 deletions chirp/inference/colab_utils.py
@@ -0,0 +1,45 @@
# coding=utf-8
# Copyright 2023 The Chirp Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Helper functions for user-facing colab notebooks."""

import warnings

from absl import logging
from chirp import config_utils
from chirp.configs import config_globals
from chirp.inference import embed_lib
import tensorflow as tf


def initialize(use_tf_gpu: bool = True, disable_warnings: bool = True):
"""Apply notebook conveniences.
Args:
use_tf_gpu: If True, allows GPU use and sets Tensorflow to 'memory growth'
mode (instead of reserving all available GPU memory at once). If False,
Tensorflow is restricted to CPU operation. Must run before any TF
computations to be effective.
disable_warnings: If True, disables printed warnings from library code.
"""
if disable_warnings:
logging.set_verbosity(logging.ERROR)
warnings.filterwarnings('ignore')

if not use_tf_gpu:
tf.config.experimental.set_visible_devices([], 'GPU')
else:
for gpu in tf.config.list_physical_devices('GPU'):
tf.config.experimental.set_memory_growth(gpu, True)
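
A usage sketch mirroring how the updated notebooks call the helper (the #@param toggle is illustrative):

```python
from chirp.inference import colab_utils

use_tf_gpu = True  #@param
# Must run before any TF computation to take effect.
colab_utils.initialize(use_tf_gpu=use_tf_gpu, disable_warnings=True)
```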
