Skip to content

Commit

Permalink
Agile2 agile modeling notebook.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 680052229
  • Loading branch information
sdenton4 authored and copybara-github committed Oct 1, 2024
1 parent 6a836b8 commit b85cc65
Show file tree
Hide file tree
Showing 7 changed files with 460 additions and 18 deletions.
314 changes: 314 additions & 0 deletions chirp/projects/agile2/2_agile_modeling_v2.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,314 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "AXQAcreKedWU"
},
"outputs": [],
"source": [
"#@title Imports. { vertical-output: true }\n",
"\n",
"import numpy as np\n",
"from matplotlib import pyplot as plt\n",
"\n",
"from chirp.projects.agile2 import audio_loader\n",
"from chirp.projects.agile2 import classifier\n",
"from chirp.projects.agile2 import classifier_data\n",
"from chirp.projects.agile2 import embedding_display\n",
"from chirp.projects.hoplite import brutalism\n",
"from chirp.projects.hoplite import score_functions\n",
"from chirp.projects.hoplite import sqlite_impl\n",
"from chirp.projects.zoo import models\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "fz63sWKEedWU"
},
"outputs": [],
"source": [
"#@title Load model and connect to database. { vertical-output: true }\n",
"\n",
"#@markdown Location of database containing audio embeddings.\n",
"db_path = '' #@param {type:'string'}\n",
"#@markdown Identifier (eg, name) to attach to labels produced during validation.\n",
"annotator_id = 'linnaeus' #@param {type:'string'}\n",
"\n",
"db = sqlite_impl.SQLiteGraphSearchDB.create(db_path)\n",
"db_model_config = db.get_metadata('model_config')\n",
"embed_config = db.get_metadata('embed_config')\n",
"model_class = models.model_class_map()[db_model_config.model_key]\n",
"embedding_model = model_class.from_config(db_model_config.model_config)\n",
"\n",
"audio_filepath_loader = audio_loader.make_filepath_loader(\n",
" audio_globs=embed_config.audio_globs,\n",
" window_size_s=embedding_model.window_size_s,\n",
" sample_rate_hz=embedding_model.sample_rate,\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "BVdLJJd9gnjo"
},
"source": [
"# Search"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "7ig3L5dsy3mr"
},
"outputs": [],
"source": [
"#@title Load query audio. { vertical-output: true }\n",
"\n",
"#@markdown The `query_uri` can be a URL, filepath, or Xeno-Canto ID\n",
"#@markdown (like `xc777802`, containing an Eastern Whipbird (`easwhi1`)).\n",
"query_uri = 'xc777802' #@param {type:'string'}\n",
"query_label = 'easwhi1' #@param {type:'string'}\n",
"\n",
"query = embedding_display.QueryDisplay(\n",
" uri=query_uri, offset_s=0.0, window_size_s=5.0, sample_rate_hz=32000)\n",
"_ = query.display_interactive()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "iHUJ_NwQWZNB"
},
"outputs": [],
"source": [
"#@title Embed the Query and Search. { vertical-output: true }\n",
"\n",
"#@markdown Number of results to find and display.\n",
"num_results = 50 #@param\n",
"query_embedding = embedding_model.embed(\n",
" query.get_audio_window()).embeddings[0, 0]\n",
"#@markdown Number of (randomly selected) database entries to search over.\n",
"sample_size = 1_000_000 #@param\n",
"\n",
"#@markdown If checked, search for examples\n",
"#@markdown near a particular target score.\n",
"target_sampling = False #@param {type: 'boolean'}\n",
"\n",
"#@markdown When target sampling, target this score.\n",
"target_score = -1.0 #@param\n",
"\n",
"if not target_sampling:\n",
" target_score = None\n",
"score_fn = score_functions.get_score_fn('dot', target_score=target_score)\n",
"results, all_scores = brutalism.threaded_brute_search(\n",
" db, query_embedding, num_results, score_fn=score_fn,\n",
" sample_size=sample_size)\n",
"\n",
"# TODO(tomdenton): Better histogram when target sampling.\n",
"_ = plt.hist(all_scores, bins=100)\n",
"hit_scores = [r.sort_score for r in results.search_results]\n",
"plt.scatter(hit_scores, np.zeros_like(hit_scores), marker='|',\n",
" color='r', alpha=0.5)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Y21fWjEwXj68"
},
"outputs": [],
"source": [
"#@title Display Results. { vertical-output: true }\n",
"\n",
"display_results = embedding_display.EmbeddingDisplayGroup.from_search_results(\n",
" results, db, sample_rate_hz=32000, frame_rate=100,\n",
" audio_loader=audio_filepath_loader)\n",
"display_results.display(positive_labels=[query_label])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "G3sIkOqlXzKB"
},
"outputs": [],
"source": [
"#@title Save data labels. { vertical-output: true }\n",
"\n",
"prev_lbls, new_lbls = 0, 0\n",
"for lbl in display_results.harvest_labels(annotator_id):\n",
" check = db.insert_label(lbl, skip_duplicates=True)\n",
" new_lbls += check\n",
" prev_lbls += (1 - check)\n",
"print('\\nnew_lbls: ', new_lbls)\n",
"print('\\nprev_lbls: ', prev_lbls)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "o65wpjvyYft-"
},
"source": [
"# Classify"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "qtsJkgcPYg6z"
},
"outputs": [],
"source": [
"#@title Classifier training. { vertical-output: true }\n",
"\n",
"#@markdown Set of labels to classify. If None, auto-populated from the DB.\n",
"target_labels = None #@param\n",
"\n",
"#@markdown Classifier traning params. These should not require tuning.\n",
"learning_rate = 1e-3 #@param\n",
"weak_neg_weight = 0.05 #@param\n",
"l2_mu = 0.000 #@param\n",
"num_steps = 128 #@param\n",
"\n",
"train_ratio = 0.01 #@param\n",
"batch_size = 128 #@param\n",
"weak_negatives_batch_size = 128 #@param\n",
"loss_fn_name = 'bce' #@param ['hinge', 'bce']\n",
"\n",
"data_manager = classifier_data.AgileDataManager(\n",
" target_labels=target_labels,\n",
" db=db,\n",
" train_ratio=0.9,\n",
" min_eval_examples=1,\n",
" batch_size=batch_size,\n",
" weak_negatives_batch_size=weak_negatives_batch_size,\n",
" rng=np.random.default_rng(seed=5))\n",
"print('Training for target labels : ')\n",
"print(data_manager.get_target_labels())\n",
"params, eval_scores = classifier.train_linear_classifier(\n",
" data_manager=data_manager,\n",
" learning_rate=learning_rate,\n",
" weak_neg_weight=weak_neg_weight,\n",
" l2_mu=l2_mu,\n",
" num_train_steps=num_steps,\n",
" loss_name=loss_fn_name,\n",
")\n",
"print('\\n' + '-' * 80)\n",
"top1 = eval_scores['top1_acc']\n",
"print(f'top-1 {top1:.3f}')\n",
"rocauc = eval_scores['roc_auc']\n",
"print(f'roc_auc {rocauc:.3f}')\n",
"cmap = eval_scores['cmap']\n",
"print(f'cmap {cmap:.3f}')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "a3N6dzhetkG1"
},
"outputs": [],
"source": [
"#@title Review Classifier Results. { vertical-output: true }\n",
"\n",
"#@markdown Number of results to find and display.\n",
"target_label = 'easwhi1' #@param {type:'string'}\n",
"num_results = 50 #@param\n",
"\n",
"target_label_idx = data_manager.get_target_labels().index(target_label)\n",
"class_query = params['beta'][:, target_label_idx]\n",
"bias = params['beta_bias'][target_label_idx]\n",
"\n",
"#@markdown Number of (randomly selected) database entries to search over.\n",
"sample_size = 1_000_000 #@param\n",
"\n",
"#@markdown Whether to use margin-sampling. If checked, search for examples\n",
"#@markdown with logits near a particular target score (usually 0).\n",
"margin_sampling = False #@param {type: 'boolean'}\n",
"\n",
"#@markdown When margin sampling, target this logit.\n",
"margin_target_score = -0.0 #@param\n",
"if not margin_sampling:\n",
" margin_target_score = None\n",
"score_fn = score_functions.get_score_fn(\n",
" 'dot', bias=bias, target_score=margin_target_score)\n",
"results, all_scores = brutalism.threaded_brute_search(\n",
" db, class_query, num_results, score_fn=score_fn,\n",
" sample_size=sample_size)\n",
"\n",
"# TODO(tomdenton): Better histogram when margin sampling.\n",
"_ = plt.hist(all_scores, bins=100)\n",
"hit_scores = [r.sort_score for r in results.search_results]\n",
"plt.scatter(hit_scores, np.zeros_like(hit_scores), marker='|',\n",
" color='r', alpha=0.5)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "EiNoGhyoDF2v"
},
"outputs": [],
"source": [
"#@title Display Results. { vertical-output: true }\n",
"\n",
"display_results = embedding_display.EmbeddingDisplayGroup.from_search_results(\n",
" results, db, sample_rate_hz=32000, frame_rate=100,\n",
" audio_loader=audio_filepath_loader)\n",
"display_results.display(positive_labels=[target_label])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "kEk15jw_B8xL"
},
"outputs": [],
"source": [
"#@title Save data labels. { vertical-output: true }\n",
"\n",
"prev_lbls, new_lbls = 0, 0\n",
"for lbl in display_results.harvest_labels(annotator_id):\n",
" check = db.insert_label(lbl, skip_duplicates=True)\n",
" new_lbls += check\n",
" prev_lbls += (1 - check)\n",
"print('\\nnew_lbls: ', new_lbls)\n",
"print('\\nprev_lbls: ', prev_lbls)"
]
}
],
"metadata": {
"colab": {
"last_runtime": {
"build_target": "",
"kind": "local"
},
"private_outputs": true,
"provenance": [],
"toc_visible": true
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
69 changes: 69 additions & 0 deletions chirp/projects/agile2/audio_loader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
# coding=utf-8
# Copyright 2024 The Perch Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Audio loader function helpers."""

import os
from typing import Callable

from chirp import audio_utils
from etils import epath
import numpy as np


def make_filepath_loader(
audio_globs: dict[str, tuple[str, str]],
sample_rate_hz: int = 32000,
window_size_s: float = 5.0,
dtype: str = 'float32',
) -> Callable[[str, float], np.ndarray]:
"""Create a function for loading audio from a source ID and offset.
Note that if multiple globs match a given source ID, the first match is used.
Args:
audio_globs: Mapping from dataset name to pairs of `(root directory, file
glob)`. (See `embed.EmbedConfig` for details.)
sample_rate_hz: Sample rate of the audio.
window_size_s: Window size of the audio.
dtype: Data type of the audio.
Returns:
Function for loading audio from a source ID and offset.
Raises:
ValueError if no audio path is found for the given source ID.
"""

def loader(source_id: str, offset_s: float) -> np.ndarray:
found_path = None
for base_path, _ in audio_globs.values():
path = epath.Path(base_path) / source_id
if path.exists():
found_path = path
break
if found_path is None:
raise ValueError('No audio path found for source_id: ', source_id)
return np.array(
audio_utils.load_audio_window(
found_path.as_posix(),
offset_s,
sample_rate=sample_rate_hz,
window_size_s=window_size_s,
),
dtype=dtype,
)

return loader
2 changes: 1 addition & 1 deletion chirp/projects/agile2/classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def train_linear_classifier(
"""Train a linear classifier."""
optimizer = optax.adam(learning_rate=learning_rate)
embedding_dim = data_manager.db.embedding_dimension()
num_classes = len(data_manager.target_labels)
num_classes = len(data_manager.get_target_labels())
params = {
'beta': jnp.zeros((embedding_dim, num_classes)),
'beta_bias': jnp.zeros((num_classes,)),
Expand Down
Loading

0 comments on commit b85cc65

Please sign in to comment.