-
Notifications
You must be signed in to change notification settings - Fork 40
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Showing
4 changed files
with
403 additions
and
8 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,326 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"id": "AXQAcreKedWU" | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"#@title Imports. { vertical-output: true }\n", | ||
"\n", | ||
"import numpy as np\n", | ||
"from matplotlib import pyplot as plt\n", | ||
"\n", | ||
"from chirp.projects.agile2 import audio_loader\n", | ||
"from chirp.projects.agile2 import classifier\n", | ||
"from chirp.projects.agile2 import classifier_data\n", | ||
"from chirp.projects.agile2 import embedding_display\n", | ||
"from chirp.projects.hoplite import brutalism\n", | ||
"from chirp.projects.hoplite import sqlite_impl\n", | ||
"from chirp.projects.zoo import models\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"id": "fz63sWKEedWU" | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"#@title Load model and connect to database. { vertical-output: true }\n", | ||
"\n", | ||
"#@markdown Location of database containing audio embeddings.\n", | ||
"db_path = '' #@param {type:'string'}\n", | ||
"\n", | ||
"db = sqlite_impl.SQLiteGraphSearchDB.create(db_path)\n", | ||
"db_model_config = db.get_metadata('model_config')\n", | ||
"embed_config = db.get_metadata('embed_config')\n", | ||
"model_class = models.model_class_map()[db_model_config.model_key]\n", | ||
"embedding_model = model_class.from_config(db_model_config.model_config)\n", | ||
"\n", | ||
"audio_filepath_loader = audio_loader.make_filepath_loader(\n", | ||
" audio_globs=embed_config.audio_globs,\n", | ||
" window_size_s=embedding_model.window_size_s,\n", | ||
" sample_rate_hz=embedding_model.sample_rate,\n", | ||
")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": { | ||
"id": "BVdLJJd9gnjo" | ||
}, | ||
"source": [ | ||
"# Search" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"id": "7ig3L5dsy3mr" | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"#@title Load query audio. { vertical-output: true }\n", | ||
"\n", | ||
"#@markdown The `query_uri` can be a URL, filepath, or Xeno-Canto ID\n", | ||
"#@markdown (like `xc777802`, containing an Eastern Whipbird (`easwhi1`)).\n", | ||
"query_uri = 'xc777802' #@param {type:'string'}\n", | ||
"query_label = 'easwhi1' #@param {type:'string'}\n", | ||
"\n", | ||
"query = embedding_display.QueryDisplay(\n", | ||
" uri=query_uri, offset_s=0.0, window_size_s=5.0, sample_rate_hz=32000)\n", | ||
"_ = query.display_interactive()" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"id": "iHUJ_NwQWZNB" | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"#@title Embed the Query and Search. { vertical-output: true }\n", | ||
"\n", | ||
"#@markdown Number of results to find and display.\n", | ||
"num_results = 50 #@param\n", | ||
"query_embedding = embedding_model.embed(\n", | ||
" query.get_audio_window()).embeddings[0, 0]\n", | ||
"#@markdown Number of (randomly selected) database entries to search over.\n", | ||
"sample_size = 1_000_000 #@param\n", | ||
"\n", | ||
"#@markdown If checked, search for examples\n", | ||
"#@markdown near a particular target score.\n", | ||
"target_sampling = False #@param {type: 'boolean'}\n", | ||
"\n", | ||
"#@markdown When target sampling, target this score.\n", | ||
"target_score = -1.0 #@param\n", | ||
"if target_sampling:\n", | ||
" score_fn = lambda x, y: -np.abs(np.dot(x, y) - target_score)\n", | ||
"else:\n", | ||
" score_fn = np.dot\n", | ||
"\n", | ||
"if target_score is None:\n", | ||
" score_fn = np.dot\n", | ||
"else:\n", | ||
" score_fn = lambda x, y: -np.abs(np.dot(x, y) - target_score)\n", | ||
"results, all_scores = brutalism.threaded_brute_search(\n", | ||
" db, query_embedding, num_results, score_fn=score_fn)\n", | ||
"\n", | ||
"# TODO(tomdenton): Better histogram when target sampling.\n", | ||
"_ = plt.hist(all_scores, bins=100)\n", | ||
"hit_scores = [r.sort_score for r in results.search_results]\n", | ||
"plt.scatter(hit_scores, np.zeros_like(hit_scores), marker='|',\n", | ||
" color='r', alpha=0.5)\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"id": "Y21fWjEwXj68" | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"#@title Display Results. { vertical-output: true }\n", | ||
"\n", | ||
"display_results = embedding_display.EmbeddingDisplayGroup.from_search_results(\n", | ||
" results, db, sample_rate_hz=32000, frame_rate=100,\n", | ||
" audio_loader=audio_filepath_loader)\n", | ||
"display_results.display(positive_labels=[query_label])" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"id": "G3sIkOqlXzKB" | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"#@title Save data labels. { vertical-output: true }\n", | ||
"\n", | ||
"prev_lbls, new_lbls = 0, 0\n", | ||
"for lbl in display_results.harvest_labels('tmd'):\n", | ||
" emb_lbls = db.get_labels(lbl.embedding_id)\n", | ||
" # Check for existince of matching label...\n", | ||
" for emb_lbl in emb_lbls:\n", | ||
" if emb_lbl.label == lbl.label and emb_lbl.provenance == lbl.provenance:\n", | ||
" prev_lbls += 1\n", | ||
" break\n", | ||
" else:\n", | ||
" db.insert_label(lbl)\n", | ||
" new_lbls += 1\n", | ||
"print('\\nnew_lbls: ', new_lbls)\n", | ||
"print('\\nprev_lbls: ', prev_lbls)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": { | ||
"id": "o65wpjvyYft-" | ||
}, | ||
"source": [ | ||
"# Classify" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"id": "qtsJkgcPYg6z" | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"#@title Classifier training. { vertical-output: true }\n", | ||
"\n", | ||
"#@markdown Set of labels to classify. If None, auto-populated from the DB.\n", | ||
"target_labels = None #@param\n", | ||
"\n", | ||
"#@markdown Classifier traning params. These should not require tuning.\n", | ||
"learning_rate = 1e-3 #@param\n", | ||
"weak_neg_weight = 1.0 #@param\n", | ||
"l2_mu = 0.000 #@param\n", | ||
"num_steps = 128 #@param\n", | ||
"\n", | ||
"train_ratio = 0.01 #@param\n", | ||
"batch_size = 128 #@param\n", | ||
"weak_negatives_batch_size = 128 #@param\n", | ||
"loss_fn_name = 'bce' #@param ['hinge', 'bce']\n", | ||
"\n", | ||
"data_manager = classifier_data.AgileDataManager(\n", | ||
" target_labels=target_labels,\n", | ||
" db=db,\n", | ||
" train_ratio=0.9,\n", | ||
" min_eval_examples=1,\n", | ||
" batch_size=batch_size,\n", | ||
" weak_negatives_batch_size=weak_negatives_batch_size,\n", | ||
" rng=np.random.default_rng(seed=5))\n", | ||
"print('Training for target labels : ')\n", | ||
"print(data_manager.get_target_labels())\n", | ||
"params, eval_scores = classifier.train_linear_classifier(\n", | ||
" data_manager=data_manager,\n", | ||
" learning_rate=learning_rate,\n", | ||
" weak_neg_weight=weak_neg_weight,\n", | ||
" l2_mu=l2_mu,\n", | ||
" num_train_steps=num_steps,\n", | ||
" loss_name=loss_fn_name,\n", | ||
")\n", | ||
"print('\\n' + '-' * 80)\n", | ||
"top1 = eval_scores['top1_acc']\n", | ||
"print(f'top-1 {top1:.3f}')\n", | ||
"rocauc = eval_scores['roc_auc']\n", | ||
"print(f'roc_auc {rocauc:.3f}')\n", | ||
"cmap = eval_scores['cmap']\n", | ||
"print(f'cmap {cmap:.3f}')" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"id": "a3N6dzhetkG1" | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"#@title Review Classifier Results. { vertical-output: true }\n", | ||
"\n", | ||
"#@markdown Number of results to find and display.\n", | ||
"target_label = 'LEPFUS' #@param {type:'string'}\n", | ||
"num_results = 50 #@param\n", | ||
"\n", | ||
"target_label_idx = data_manager.get_target_labels().index(target_label)\n", | ||
"class_query = params['beta'][:, target_label_idx]\n", | ||
"\n", | ||
"#@markdown Number of (randomly selected) database entries to search over.\n", | ||
"sample_size = 1_000_000 #@param\n", | ||
"\n", | ||
"#@markdown Whether to use margin-sampling. If checked, search for examples\n", | ||
"#@markdown with logits near a particular target score (usually 0).\n", | ||
"margin_sampling = False #@param {type: 'boolean'}\n", | ||
"\n", | ||
"#@markdown When margin sampling, target this logit.\n", | ||
"margin_target_score = -1.0 #@param\n", | ||
"if margin_sampling:\n", | ||
" score_fn = lambda x, y: -np.abs(np.dot(x, y) - margin_target_score)\n", | ||
"else:\n", | ||
" score_fn = np.dot\n", | ||
"results, all_scores = brutalism.threaded_brute_search(\n", | ||
" db, class_query, num_results, score_fn=score_fn,\n", | ||
" sample_size=sample_size)\n", | ||
"\n", | ||
"# TODO(tomdenton): Better histogram when margin sampling.\n", | ||
"_ = plt.hist(all_scores, bins=100)\n", | ||
"hit_scores = [r.sort_score for r in results.search_results]\n", | ||
"plt.scatter(hit_scores, np.zeros_like(hit_scores), marker='|',\n", | ||
" color='r', alpha=0.5)\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"id": "EiNoGhyoDF2v" | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"#@title Display Results. { vertical-output: true }\n", | ||
"\n", | ||
"display_results = embedding_display.EmbeddingDisplayGroup.from_search_results(\n", | ||
" results, db, sample_rate_hz=32000, frame_rate=100,\n", | ||
" audio_loader=audio_filepath_loader)\n", | ||
"display_results.display(positive_labels=[target_label])" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"id": "kEk15jw_B8xL" | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"#@title Save data labels. { vertical-output: true }\n", | ||
"\n", | ||
"prev_lbls, new_lbls = 0, 0\n", | ||
"for lbl in display_results.harvest_labels('tmd'):\n", | ||
" emb_lbls = db.get_labels(lbl.embedding_id)\n", | ||
" # Check for existince of matching label...\n", | ||
" for emb_lbl in emb_lbls:\n", | ||
" if emb_lbl.label == lbl.label and emb_lbl.provenance == lbl.provenance:\n", | ||
" prev_lbls += 1\n", | ||
" break\n", | ||
" else:\n", | ||
" db.insert_label(lbl)\n", | ||
" new_lbls += 1\n", | ||
"print('\\nnew_lbls: ', new_lbls)\n", | ||
"print('\\nprev_lbls: ', prev_lbls)" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"colab": { | ||
"last_runtime": { | ||
"build_target": "", | ||
"kind": "local" | ||
}, | ||
"private_outputs": true, | ||
"provenance": [], | ||
"toc_visible": true | ||
}, | ||
"kernelspec": { | ||
"display_name": "Python 3", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"name": "python" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 0 | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
# coding=utf-8 | ||
# Copyright 2024 The Perch Authors. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
"""Audio loader function helpers.""" | ||
|
||
import os | ||
from typing import Callable | ||
|
||
from chirp import audio_utils | ||
import numpy as np | ||
|
||
|
||
def make_filepath_loader( | ||
audio_globs: dict[str, tuple[str, str]], | ||
sample_rate_hz: int = 32000, | ||
window_size_s: float = 5.0, | ||
dtype: str = 'float32', | ||
) -> Callable[[str, float], np.ndarray]: | ||
"""Create a function for loading audio from a source ID and offset. | ||
Args: | ||
audio_globs: Mapping from dataset name to pairs of `(root directory, file | ||
glob)`. (See `embed.EmbedConfig` for details.) | ||
sample_rate_hz: Sample rate of the audio. | ||
window_size_s: Window size of the audio. | ||
dtype: Data type of the audio. | ||
Returns: | ||
Function for loading audio from a source ID and offset. | ||
""" | ||
|
||
def loader(source_id: str, offset_s: float) -> np.ndarray: | ||
found_path = None | ||
for _, (base_path, _) in audio_globs.items(): | ||
path = os.path.join(base_path, source_id) | ||
if os.path.exists(path): | ||
found_path = path | ||
break | ||
if found_path is None: | ||
raise ValueError('No audio path found for source_id: ', source_id) | ||
return np.array( | ||
audio_utils.load_audio_window( | ||
found_path, | ||
offset_s, | ||
sample_rate=sample_rate_hz, | ||
window_size_s=window_size_s, | ||
), | ||
dtype=dtype, | ||
) | ||
|
||
return loader |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.