From 8816bf219dbbfc08dce10d74313e72357e011d34 Mon Sep 17 00:00:00 2001 From: Tom Denton Date: Tue, 1 Oct 2024 11:33:07 -0700 Subject: [PATCH] Agile2 agile modeling notebook. PiperOrigin-RevId: 681101214 --- .../projects/agile2/2_agile_modeling_v2.ipynb | 314 ++++++++++++++++++ chirp/projects/agile2/audio_loader.py | 69 ++++ chirp/projects/agile2/classifier.py | 2 +- chirp/projects/agile2/classifier_data.py | 21 +- .../agile2/tests/classifier_data_test.py | 23 ++ .../agile2/tests/embedding_display_test.py | 18 +- chirp/projects/hoplite/score_functions.py | 31 +- 7 files changed, 460 insertions(+), 18 deletions(-) create mode 100644 chirp/projects/agile2/2_agile_modeling_v2.ipynb create mode 100644 chirp/projects/agile2/audio_loader.py diff --git a/chirp/projects/agile2/2_agile_modeling_v2.ipynb b/chirp/projects/agile2/2_agile_modeling_v2.ipynb new file mode 100644 index 00000000..2294f84b --- /dev/null +++ b/chirp/projects/agile2/2_agile_modeling_v2.ipynb @@ -0,0 +1,314 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "AXQAcreKedWU" + }, + "outputs": [], + "source": [ + "#@title Imports. { vertical-output: true }\n", + "\n", + "import numpy as np\n", + "from matplotlib import pyplot as plt\n", + "\n", + "from chirp.projects.agile2 import audio_loader\n", + "from chirp.projects.agile2 import classifier\n", + "from chirp.projects.agile2 import classifier_data\n", + "from chirp.projects.agile2 import embedding_display\n", + "from chirp.projects.hoplite import brutalism\n", + "from chirp.projects.hoplite import score_functions\n", + "from chirp.projects.hoplite import sqlite_impl\n", + "from chirp.projects.zoo import models\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "fz63sWKEedWU" + }, + "outputs": [], + "source": [ + "#@title Load model and connect to database. 
{ vertical-output: true }\n", + "\n", + "#@markdown Location of database containing audio embeddings.\n", + "db_path = '' #@param {type:'string'}\n", + "#@markdown Identifier (eg, name) to attach to labels produced during validation.\n", + "annotator_id = 'linnaeus' #@param {type:'string'}\n", + "\n", + "db = sqlite_impl.SQLiteGraphSearchDB.create(db_path)\n", + "db_model_config = db.get_metadata('model_config')\n", + "embed_config = db.get_metadata('embed_config')\n", + "model_class = models.model_class_map()[db_model_config.model_key]\n", + "embedding_model = model_class.from_config(db_model_config.model_config)\n", + "\n", + "audio_filepath_loader = audio_loader.make_filepath_loader(\n", + " audio_globs=embed_config.audio_globs,\n", + " window_size_s=embedding_model.window_size_s,\n", + " sample_rate_hz=embedding_model.sample_rate,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "BVdLJJd9gnjo" + }, + "source": [ + "# Search" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "7ig3L5dsy3mr" + }, + "outputs": [], + "source": [ + "#@title Load query audio. { vertical-output: true }\n", + "\n", + "#@markdown The `query_uri` can be a URL, filepath, or Xeno-Canto ID\n", + "#@markdown (like `xc777802`, containing an Eastern Whipbird (`easwhi1`)).\n", + "query_uri = 'xc777802' #@param {type:'string'}\n", + "query_label = 'easwhi1' #@param {type:'string'}\n", + "\n", + "query = embedding_display.QueryDisplay(\n", + " uri=query_uri, offset_s=0.0, window_size_s=5.0, sample_rate_hz=32000)\n", + "_ = query.display_interactive()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "iHUJ_NwQWZNB" + }, + "outputs": [], + "source": [ + "#@title Embed the Query and Search. 
{ vertical-output: true }\n", + "\n", + "#@markdown Number of results to find and display.\n", + "num_results = 50 #@param\n", + "query_embedding = embedding_model.embed(\n", + " query.get_audio_window()).embeddings[0, 0]\n", + "#@markdown Number of (randomly selected) database entries to search over.\n", + "sample_size = 1_000_000 #@param\n", + "\n", + "#@markdown If checked, search for examples\n", + "#@markdown near a particular target score.\n", + "target_sampling = False #@param {type: 'boolean'}\n", + "\n", + "#@markdown When target sampling, target this score.\n", + "target_score = -1.0 #@param\n", + "\n", + "if not target_sampling:\n", + " target_score = None\n", + "score_fn = score_functions.get_score_fn('dot', target_score=target_score)\n", + "results, all_scores = brutalism.threaded_brute_search(\n", + " db, query_embedding, num_results, score_fn=score_fn,\n", + " sample_size=sample_size)\n", + "\n", + "# TODO(tomdenton): Better histogram when target sampling.\n", + "_ = plt.hist(all_scores, bins=100)\n", + "hit_scores = [r.sort_score for r in results.search_results]\n", + "plt.scatter(hit_scores, np.zeros_like(hit_scores), marker='|',\n", + " color='r', alpha=0.5)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "Y21fWjEwXj68" + }, + "outputs": [], + "source": [ + "#@title Display Results. { vertical-output: true }\n", + "\n", + "display_results = embedding_display.EmbeddingDisplayGroup.from_search_results(\n", + " results, db, sample_rate_hz=32000, frame_rate=100,\n", + " audio_loader=audio_filepath_loader)\n", + "display_results.display(positive_labels=[query_label])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "G3sIkOqlXzKB" + }, + "outputs": [], + "source": [ + "#@title Save data labels. 
{ vertical-output: true }\n", + "\n", + "prev_lbls, new_lbls = 0, 0\n", + "for lbl in display_results.harvest_labels(annotator_id):\n", + " check = db.insert_label(lbl, skip_duplicates=True)\n", + " new_lbls += check\n", + " prev_lbls += (1 - check)\n", + "print('\\nnew_lbls: ', new_lbls)\n", + "print('\\nprev_lbls: ', prev_lbls)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "o65wpjvyYft-" + }, + "source": [ + "# Classify" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "qtsJkgcPYg6z" + }, + "outputs": [], + "source": [ + "#@title Classifier training. { vertical-output: true }\n", + "\n", + "#@markdown Set of labels to classify. If None, auto-populated from the DB.\n", + "target_labels = None #@param\n", + "\n", + "#@markdown Classifier training params. These should not require tuning.\n", + "learning_rate = 1e-3 #@param\n", + "weak_neg_weight = 0.05 #@param\n", + "l2_mu = 0.000 #@param\n", + "num_steps = 128 #@param\n", + "\n", + "train_ratio = 0.01 #@param\n", + "batch_size = 128 #@param\n", + "weak_negatives_batch_size = 128 #@param\n", + "loss_fn_name = 'bce' #@param ['hinge', 'bce']\n", + "\n", + "data_manager = classifier_data.AgileDataManager(\n", + " target_labels=target_labels,\n", + " db=db,\n", + " train_ratio=0.9,\n", + " min_eval_examples=1,\n", + " batch_size=batch_size,\n", + " weak_negatives_batch_size=weak_negatives_batch_size,\n", + " rng=np.random.default_rng(seed=5))\n", + "print('Training for target labels : ')\n", + "print(data_manager.get_target_labels())\n", + "params, eval_scores = classifier.train_linear_classifier(\n", + " data_manager=data_manager,\n", + " learning_rate=learning_rate,\n", + " weak_neg_weight=weak_neg_weight,\n", + " l2_mu=l2_mu,\n", + " num_train_steps=num_steps,\n", + " loss_name=loss_fn_name,\n", + ")\n", + "print('\\n' + '-' * 80)\n", + "top1 = eval_scores['top1_acc']\n", + "print(f'top-1 {top1:.3f}')\n", + "rocauc = eval_scores['roc_auc']\n", + "print(f'roc_auc 
{rocauc:.3f}')\n", + "cmap = eval_scores['cmap']\n", + "print(f'cmap {cmap:.3f}')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "a3N6dzhetkG1" + }, + "outputs": [], + "source": [ + "#@title Review Classifier Results. { vertical-output: true }\n", + "\n", + "#@markdown Number of results to find and display.\n", + "target_label = 'easwhi1' #@param {type:'string'}\n", + "num_results = 50 #@param\n", + "\n", + "target_label_idx = data_manager.get_target_labels().index(target_label)\n", + "class_query = params['beta'][:, target_label_idx]\n", + "bias = params['beta_bias'][target_label_idx]\n", + "\n", + "#@markdown Number of (randomly selected) database entries to search over.\n", + "sample_size = 1_000_000 #@param\n", + "\n", + "#@markdown Whether to use margin-sampling. If checked, search for examples\n", + "#@markdown with logits near a particular target score (usually 0).\n", + "margin_sampling = False #@param {type: 'boolean'}\n", + "\n", + "#@markdown When margin sampling, target this logit.\n", + "margin_target_score = -0.0 #@param\n", + "if not margin_sampling:\n", + " margin_target_score = None\n", + "score_fn = score_functions.get_score_fn(\n", + " 'dot', bias=bias, target_score=margin_target_score)\n", + "results, all_scores = brutalism.threaded_brute_search(\n", + " db, class_query, num_results, score_fn=score_fn,\n", + " sample_size=sample_size)\n", + "\n", + "# TODO(tomdenton): Better histogram when margin sampling.\n", + "_ = plt.hist(all_scores, bins=100)\n", + "hit_scores = [r.sort_score for r in results.search_results]\n", + "plt.scatter(hit_scores, np.zeros_like(hit_scores), marker='|',\n", + " color='r', alpha=0.5)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "EiNoGhyoDF2v" + }, + "outputs": [], + "source": [ + "#@title Display Results. 
{ vertical-output: true }\n", + "\n", + "display_results = embedding_display.EmbeddingDisplayGroup.from_search_results(\n", + " results, db, sample_rate_hz=32000, frame_rate=100,\n", + " audio_loader=audio_filepath_loader)\n", + "display_results.display(positive_labels=[target_label])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "kEk15jw_B8xL" + }, + "outputs": [], + "source": [ + "#@title Save data labels. { vertical-output: true }\n", + "\n", + "prev_lbls, new_lbls = 0, 0\n", + "for lbl in display_results.harvest_labels(annotator_id):\n", + " check = db.insert_label(lbl, skip_duplicates=True)\n", + " new_lbls += check\n", + " prev_lbls += (1 - check)\n", + "print('\\nnew_lbls: ', new_lbls)\n", + "print('\\nprev_lbls: ', prev_lbls)" + ] + } + ], + "metadata": { + "colab": { + "last_runtime": { + "build_target": "", + "kind": "local" + }, + "private_outputs": true, + "provenance": [], + "toc_visible": true + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/chirp/projects/agile2/audio_loader.py b/chirp/projects/agile2/audio_loader.py new file mode 100644 index 00000000..23c051e6 --- /dev/null +++ b/chirp/projects/agile2/audio_loader.py @@ -0,0 +1,69 @@ +# coding=utf-8 +# Copyright 2024 The Perch Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Audio loader function helpers.""" + +import os +from typing import Callable + +from chirp import audio_utils +from etils import epath +import numpy as np + + +def make_filepath_loader( + audio_globs: dict[str, tuple[str, str]], + sample_rate_hz: int = 32000, + window_size_s: float = 5.0, + dtype: str = 'float32', +) -> Callable[[str, float], np.ndarray]: + """Create a function for loading audio from a source ID and offset. + + Note that if multiple globs match a given source ID, the first match is used. + + Args: + audio_globs: Mapping from dataset name to pairs of `(root directory, file + glob)`. (See `embed.EmbedConfig` for details.) + sample_rate_hz: Sample rate of the audio. + window_size_s: Window size of the audio. + dtype: Data type of the audio. + + Returns: + Function for loading audio from a source ID and offset. + + Raises: + ValueError if no audio path is found for the given source ID. + """ + + def loader(source_id: str, offset_s: float) -> np.ndarray: + found_path = None + for base_path, _ in audio_globs.values(): + path = epath.Path(base_path) / source_id + if path.exists(): + found_path = path + break + if found_path is None: + raise ValueError('No audio path found for source_id: ', source_id) + return np.array( + audio_utils.load_audio_window( + found_path.as_posix(), + offset_s, + sample_rate=sample_rate_hz, + window_size_s=window_size_s, + ), + dtype=dtype, + ) + + return loader diff --git a/chirp/projects/agile2/classifier.py b/chirp/projects/agile2/classifier.py index 311d3f18..e6250cd8 100644 --- a/chirp/projects/agile2/classifier.py +++ b/chirp/projects/agile2/classifier.py @@ -124,7 +124,7 @@ def train_linear_classifier( """Train a linear classifier.""" optimizer = optax.adam(learning_rate=learning_rate) embedding_dim = data_manager.db.embedding_dimension() - num_classes = len(data_manager.target_labels) + num_classes = len(data_manager.get_target_labels()) params = { 'beta': jnp.zeros((embedding_dim, num_classes)), 'beta_bias': 
jnp.zeros((num_classes,)), diff --git a/chirp/projects/agile2/classifier_data.py b/chirp/projects/agile2/classifier_data.py index 0558f301..dd9f35cf 100644 --- a/chirp/projects/agile2/classifier_data.py +++ b/chirp/projects/agile2/classifier_data.py @@ -15,7 +15,6 @@ """Tools for processing data for the Agile2 classifier.""" -import abc import dataclasses import itertools from typing import Any, Iterator, Sequence @@ -65,7 +64,7 @@ def join_batches(self, other: 'LabeledExample') -> 'LabeledExample': class DataManager: """Base class for managing data for training and evaluation.""" - target_labels: tuple[str, ...] + target_labels: tuple[str, ...] | None db: interface.GraphSearchDBInterface batch_size: int rng: np.random.Generator @@ -78,12 +77,18 @@ def get_train_test_split(self) -> tuple[np.ndarray, np.ndarray]: """ raise NotImplementedError('get_train_test_split is not implemented.') + def get_target_labels(self) -> tuple[str, ...]: + if self.target_labels is None: + return tuple(self.db.get_classes()) + return self.target_labels + def get_multihot_labels(self, idx: int) -> tuple[np.ndarray, np.ndarray]: """Create the multihot label for one example.""" labels = self.db.get_labels(idx) - lbl_idxes = {label: i for i, label in enumerate(self.target_labels)} - pos = np.zeros(len(self.target_labels), dtype=np.float32) - neg = np.zeros(len(self.target_labels), dtype=np.float32) + target_labels = self.get_target_labels() + lbl_idxes = {label: i for i, label in enumerate(target_labels)} + pos = np.zeros(len(target_labels), dtype=np.float32) + neg = np.zeros(len(target_labels), dtype=np.float32) for label in labels: if label.type == interface.LabelType.POSITIVE: pos[lbl_idxes[label.label]] += 1.0 @@ -203,7 +208,7 @@ def get_train_test_split(self) -> tuple[np.ndarray, np.ndarray]: Two numpy arrays contianing train and eval embedding ids, respectively. 
""" train_ids, eval_ids = [], [] - for label in self.target_labels: + for label in self.get_target_labels(): lbl_train, lbl_eval = self.get_single_label_train_test_split(label) train_ids.append(lbl_train) eval_ids.append(lbl_eval) @@ -257,7 +262,7 @@ def get_train_test_split(self) -> tuple[np.ndarray, np.ndarray]: """Create a train/test split over the fully-annotated dataset.""" pos_id_sets = {} eval_id_sets = {} - for label in self.target_labels: + for label in self.get_target_labels(): pos_id_sets[label] = self.db.get_embeddings_by_label( label, interface.LabelType.POSITIVE, None ) @@ -268,7 +273,7 @@ def get_train_test_split(self) -> tuple[np.ndarray, np.ndarray]: # Now produce train sets of the desired size, # avoiding the selected eval examples. train_id_sets = {} - for label in self.target_labels: + for label in self.get_target_labels(): pos_set = np.setdiff1d(pos_id_sets[label], all_eval_ids) train_id_sets[label] = pos_set[: self.train_examples_per_class] if self.add_unlabeled_train_examples: diff --git a/chirp/projects/agile2/tests/classifier_data_test.py b/chirp/projects/agile2/tests/classifier_data_test.py index bdd6c9e2..78f59afb 100644 --- a/chirp/projects/agile2/tests/classifier_data_test.py +++ b/chirp/projects/agile2/tests/classifier_data_test.py @@ -156,6 +156,29 @@ def test_partial_classes(self): self.assertEqual(eval_count, 0) self.assertEqual(train_count, 0) + def test_auto_labels(self): + num_embeddings = 5_000 + db = self.make_db_with_labels( + num_embeddings=num_embeddings, + # Add a label to every embedding. + unlabeled_prob=0.0, + positive_label_prob=0.5, + rng=np.random.default_rng(42), + ) + # Leave target_labels unset so the label set is auto-populated from the DB. 
+ data_manager = classifier_data.AgileDataManager( + target_labels=None, + db=db, + train_ratio=0.8, + min_eval_examples=10, + batch_size=10, + weak_negatives_batch_size=10, + rng=np.random.default_rng(42), + ) + self.assertLen( + data_manager.get_target_labels(), len(test_utils.CLASS_LABELS) + ) + def test_multihot_labels(self): db = test_utils.make_db( self.tempdir, diff --git a/chirp/projects/agile2/tests/embedding_display_test.py b/chirp/projects/agile2/tests/embedding_display_test.py index a58f0ea0..4d7aa190 100644 --- a/chirp/projects/agile2/tests/embedding_display_test.py +++ b/chirp/projects/agile2/tests/embedding_display_test.py @@ -20,7 +20,7 @@ import tempfile from unittest import mock -from chirp import audio_utils +from chirp.projects.agile2 import audio_loader from chirp.projects.agile2 import embedding_display from chirp.projects.agile2.tests import test_utils from chirp.projects.hoplite import interface @@ -156,7 +156,7 @@ def test_embedding_display_group(self): member0 = embedding_display.EmbeddingDisplay( embedding_id=123, dataset_name='test_dataset', - uri=os.path.join(self.tempdir, 'pos', 'foo_pos.wav'), + uri=os.path.join('pos', 'foo_pos.wav'), offset_s=1.0, score=0.5, sample_rate_hz=sample_rate_hz, @@ -164,26 +164,30 @@ def test_embedding_display_group(self): member1 = embedding_display.EmbeddingDisplay( embedding_id=456, dataset_name='test_dataset', - uri=os.path.join(self.tempdir, 'neg', 'bar_neg.wav'), + uri=os.path.join('neg', 'bar_neg.wav'), offset_s=2.0, score=0.6, sample_rate_hz=sample_rate_hz, ) - audio_loader = lambda uri, offset_s: audio_utils.load_audio_window( - uri, offset_s=offset_s, sample_rate=sample_rate_hz, window_size_s=3.0 + audio_globs = { + 'test_dataset': (self.tempdir, '*/*.wav'), + } + filepath_audio_loader = audio_loader.make_filepath_loader( + audio_globs, sample_rate_hz=sample_rate_hz, window_size_s=1.0 ) group = embedding_display.EmbeddingDisplayGroup.create( members=[member0, member1], 
sample_rate_hz=sample_rate_hz, - audio_loader=audio_loader, + audio_loader=filepath_audio_loader, ) self.assertLen(group.members, 2) self.assertEqual(group.members[0].embedding_id, 123) self.assertEqual(group.members[1].embedding_id, 456) - self.assertEqual(group.members[0].dataset_name, 'test_dataset') for got_member in group.iterator_with_audio(): + self.assertEqual(got_member.dataset_name, 'test_dataset') self.assertIsNotNone(got_member.audio) + self.assertEqual(got_member.audio.shape[0], 16000) self.assertIsNotNone(got_member.spectrogram) diff --git a/chirp/projects/hoplite/score_functions.py b/chirp/projects/hoplite/score_functions.py index 0d1f562a..acd5f2e0 100644 --- a/chirp/projects/hoplite/score_functions.py +++ b/chirp/projects/hoplite/score_functions.py @@ -15,9 +15,36 @@ """Score functions for Hoplite.""" +from typing import Callable import numpy as np +def get_score_fn( + name: str, bias: float | None = None, target_score: float | None = None +) -> Callable[[np.ndarray, np.ndarray], np.ndarray]: + """Get a score function by name.""" + if name == 'dot': + score_fn = numpy_dot + elif name == 'cos': + score_fn = numpy_cos + elif name == 'euclidean': + score_fn = numpy_euclidean + else: + raise ValueError('Unknown score function: ', name) + + if bias is not None: + bias_fn = lambda x, y: score_fn(x, y) + bias + else: + bias_fn = score_fn + + if target_score is not None: + targeted_fn = lambda x, y: np.abs(bias_fn(x, y) - target_score) + else: + targeted_fn = bias_fn + + return targeted_fn + + def numpy_dot(data: np.ndarray, query: np.ndarray) -> np.ndarray: """Simple numpy dot product, which allows for multiple queries.""" if len(query.shape) > 1: @@ -27,7 +54,7 @@ def numpy_dot(data: np.ndarray, query: np.ndarray) -> np.ndarray: return np.dot(data, query) -def numpy_cos(data, query): +def numpy_cos(data: np.ndarray, query: np.ndarray) -> np.ndarray: """Simple numpy cosine similarity, allowing multiple queries.""" data_norms = np.linalg.norm(data, axis=1) 
query_norms = np.linalg.norm(query, axis=-1) @@ -40,7 +67,7 @@ def numpy_cos(data, query): return np.dot(unit_data, unit_query) -def numpy_euclidean(data, query): +def numpy_euclidean(data: np.ndarray, query: np.ndarray) -> np.ndarray: """Numpy L2 distance allowing multiple queries.""" data_norms = np.linalg.norm(data, axis=-1) if len(query.shape) > 1: