From 10bf1568500296a83a31bce9dd501616aed921c7 Mon Sep 17 00:00:00 2001
From: Tom Denton
Date: Mon, 30 Sep 2024 09:51:14 -0700
Subject: [PATCH] Add presets for fully-annotated datasets and a convenience
 method for creating an annotated embeddings database.

PiperOrigin-RevId: 680614976
---
 chirp/projects/agile2/ingest_annotations.py   | 174 +++++++++++++++--
 chirp/projects/agile2/source_info.py          |   6 +-
 .../agile2/tests/classifier_data_test.py      |  63 -------
 .../agile2/tests/ingest_annotations_test.py   | 175 ++++++++++++++++++
 chirp/taxonomy/annotations.py                 |  29 ++-
 chirp/taxonomy/annotations_fns.py             |   6 +-
 6 files changed, 371 insertions(+), 82 deletions(-)
 create mode 100644 chirp/projects/agile2/tests/ingest_annotations_test.py

diff --git a/chirp/projects/agile2/ingest_annotations.py b/chirp/projects/agile2/ingest_annotations.py
index 95e1664d..6ab5ecfa 100644
--- a/chirp/projects/agile2/ingest_annotations.py
+++ b/chirp/projects/agile2/ingest_annotations.py
@@ -13,37 +13,63 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-"""Ingest fully-annotated dataset labels from CSV."""
+"""Ingest fully-annotated dataset audio and labels."""
 
+import collections
 import dataclasses
-from typing import Callable, Sequence
+from typing import Callable
 
+from chirp.projects.agile2 import embed
+from chirp.projects.hoplite import db_loader
 from chirp.projects.hoplite import interface
-from chirp.taxonomy import annotations
+from chirp.taxonomy import annotations_fns
 from etils import epath
+from ml_collections import config_dict
 import pandas as pd
 import tqdm
 
+BASE_PATH = epath.Path('gs://chirp-public-bucket/soundscapes/')
+
 
 @dataclasses.dataclass
 class AnnotatedDatasetIngestor:
   """Add annotations to embeddings DB from CSV annotations.
 
   Note that currently we only add positive labels.
+
+  Attributes:
+    base_path: Base path for the dataset.
+    audio_glob: Glob for the audio files.
+    dataset_name: Name of the dataset.
+    annotation_filename: Filename for the annotations CSV.
+    annotation_load_fn: Function to load the annotations CSV.
   """
 
-  base_path: str
+  base_path: epath.Path
+  audio_glob: str
   dataset_name: str
   annotation_filename: str
   annotation_load_fn: Callable[[str | epath.Path], pd.DataFrame]
-  window_size_s: float
-  provenance: str = 'dataset'
 
-  def ingest_dataset(self, db: interface.GraphSearchDBInterface) -> set[str]:
-    """Load annotations and insert labels into the DB."""
+  def ingest_dataset(
+      self,
+      db: interface.GraphSearchDBInterface,
+      window_size_s: float,
+      provenance: str = 'annotations_csv',
+  ) -> collections.defaultdict[str, int]:
+    """Load annotations and insert labels into the DB.
+
+    Args:
+      db: The DB to insert labels into.
+      window_size_s: The window size of the embeddings.
+      provenance: The provenance to use for the labels.
+
+    Returns:
+      A dictionary of ingested label counts.
+    """
     annos_path = epath.Path(self.base_path) / self.annotation_filename
     annos_df = self.annotation_load_fn(annos_path)
-    lbl_count = 0
+    lbl_counts = collections.defaultdict(int)
     file_ids = annos_df['filename'].unique()
     label_set = set()
     for file_id in tqdm.tqdm(file_ids):
@@ -52,7 +78,7 @@ def ingest_dataset(self, db: interface.GraphSearchDBInterface) -> set[str]:
       for idx in embedding_ids:
         source = db.get_embedding_source(idx)
         window_start = source.offsets[0]
-        window_end = window_start + self.window_size_s
+        window_end = window_start + window_size_s
         emb_annos = source_annos[source_annos['start_time_s'] < window_end]
         emb_annos = emb_annos[emb_annos['end_time_s'] > window_start]
         # All of the remaining annotations match the target embedding.
@@ -60,9 +86,131 @@ def ingest_dataset(self, db: interface.GraphSearchDBInterface) -> set[str]:
         for label in labels:
           label_set.add(label)
           lbl = interface.Label(
-              idx, label, interface.LabelType.POSITIVE, self.provenance
+              idx,
+              label,
+              interface.LabelType.POSITIVE,
+              provenance=provenance,
           )
           db.insert_label(lbl)
-          lbl_count += 1
+          lbl_counts[label] += 1
+    lbl_count = sum(lbl_counts.values())
     print(f'\nInserted {lbl_count} labels.')
-    return label_set
+    return lbl_counts
+
+
+CORNELL_LOADER = lambda x: annotations_fns.load_cornell_annotations(
+    x, file_id_prefix='audio/'
+)
+
+PRESETS: dict[str, AnnotatedDatasetIngestor] = {
+    'powdermill': AnnotatedDatasetIngestor(
+        base_path=BASE_PATH / 'powdermill',
+        audio_glob='*/*.wav',
+        dataset_name='powdermill',
+        annotation_filename='powdermill.csv',
+        annotation_load_fn=annotations_fns.load_powdermill_annotations,
+    ),
+    'hawaii': AnnotatedDatasetIngestor(
+        base_path=BASE_PATH / 'hawaii',
+        dataset_name='hawaii',
+        audio_glob='audio/*.flac',
+        annotation_filename='annotations.csv',
+        annotation_load_fn=CORNELL_LOADER,
+    ),
+    'high_sierras': AnnotatedDatasetIngestor(
+        base_path=BASE_PATH / 'high_sierras',
+        dataset_name='high_sierras',
+        audio_glob='audio/*.flac',
+        annotation_filename='annotations.csv',
+        annotation_load_fn=CORNELL_LOADER,
+    ),
+    'coffee_farms': AnnotatedDatasetIngestor(
+        base_path=BASE_PATH / 'coffee_farms',
+        dataset_name='coffee_farms',
+        audio_glob='audio/*.flac',
+        annotation_filename='annotations.csv',
+        annotation_load_fn=CORNELL_LOADER,
+    ),
+    'peru': AnnotatedDatasetIngestor(
+        base_path=BASE_PATH / 'peru',
+        dataset_name='peru',
+        audio_glob='audio/*.flac',
+        annotation_filename='annotations.csv',
+        annotation_load_fn=CORNELL_LOADER,
+    ),
+    'ssw': AnnotatedDatasetIngestor(
+        base_path=BASE_PATH / 'ssw',
+        dataset_name='ssw',
+        audio_glob='audio/*.flac',
+        annotation_filename='annotations.csv',
+        annotation_load_fn=CORNELL_LOADER,
+    ),
+    'sierras_kahl': AnnotatedDatasetIngestor(
+        base_path=BASE_PATH / 'sierras_kahl',
+        dataset_name='sierras_kahl',
+        audio_glob='audio/*.flac',
+        annotation_filename='annotations.csv',
+        annotation_load_fn=CORNELL_LOADER,
+    ),
+    'anuraset': AnnotatedDatasetIngestor(
+        base_path=BASE_PATH / 'anuraset',
+        dataset_name='anuraset',
+        audio_glob='raw_data/*/*.wav',
+        annotation_filename='annotations.csv',
+        annotation_load_fn=annotations_fns.load_anuraset_annotations,
+    ),
+}
+
+
+def embed_annotated_dataset(
+    ds_choice: str | AnnotatedDatasetIngestor,
+    db_path: str,
+    db_model_config: embed.ModelConfig,
+) -> tuple[interface.GraphSearchDBInterface, dict[str, int]]:
+  """Embed a fully-annotated dataset to a SQLite Hoplite DB.
+
+  Args:
+    ds_choice: The preset name of the dataset to embed. Alternatively, an
+      AnnotatedDatasetIngestor can be provided.
+    db_path: The path to the DB.
+    db_model_config: The model config for the DB.
+
+  Returns:
+    The DB and a dictionary of label counts.
+  """
+  if isinstance(ds_choice, str):
+    ingestor = PRESETS[ds_choice]
+  else:
+    ingestor = ds_choice
+
+  db_filepath = f'{db_path}/hoplite_db.sqlite'
+  epath.Path(db_filepath).parent.mkdir(parents=True, exist_ok=True)
+  db_config = config_dict.ConfigDict({
+      'db_path': db_filepath,
+      'embedding_dim': db_model_config.embedding_dim,
+  })
+  db_config = db_loader.DBConfig('sqlite', db_config)
+  print(ingestor)
+  audio_srcs_config = embed.EmbedConfig(
+      audio_globs={
+          ingestor.dataset_name: (
+              ingestor.base_path.as_posix(),
+              ingestor.audio_glob,
+          )
+      },
+      min_audio_len_s=1.0,
+  )
+  db = db_config.load_db()
+  db.setup()
+  print('Initialized DB located at ', db_filepath)
+  worker = embed.EmbedWorker(
+      embed_config=audio_srcs_config, db=db, model_config=db_model_config
+  )
+  worker.process_all()
+  print(f'DB contains {db.count_embeddings()} embeddings.')
+
+  class_counts = ingestor.ingest_dataset(
+      db, window_size_s=worker.embedding_model.window_size_s
+  )
+  db.commit()
+  return db, class_counts
diff --git a/chirp/projects/agile2/source_info.py b/chirp/projects/agile2/source_info.py
index ecdc876b..d4bb9caf 100644
--- a/chirp/projects/agile2/source_info.py
+++ b/chirp/projects/agile2/source_info.py
@@ -71,9 +71,11 @@ def iterate_all_sources(
       shard_len_s (ie, the final shard).
     """
     for dataset_name, (root_dir, file_glob) in self.audio_globs.items():
-      filepaths = tuple(epath.Path(root_dir).glob(file_glob))
+      # If root_dir is a URL, the posix path may not match the original string.
+      base_path = epath.Path(root_dir)
+      filepaths = tuple(base_path.glob(file_glob))
       for filepath in tqdm.tqdm(filepaths):
-        file_id = filepath.as_posix()[len(root_dir) + 1 :]
+        file_id = filepath.as_posix()[len(base_path.as_posix()) + 1 :]
         if shard_len_s < 0:
           yield SourceId(
               dataset_name=dataset_name,
diff --git a/chirp/projects/agile2/tests/classifier_data_test.py b/chirp/projects/agile2/tests/classifier_data_test.py
index 23b2a3bb..bdd6c9e2 100644
--- a/chirp/projects/agile2/tests/classifier_data_test.py
+++ b/chirp/projects/agile2/tests/classifier_data_test.py
@@ -18,12 +18,9 @@
 import shutil
 import tempfile
 
-from chirp import path_utils
 from chirp.projects.agile2 import classifier_data
-from chirp.projects.agile2 import ingest_annotations
 from chirp.projects.agile2.tests import test_utils
 from chirp.projects.hoplite import interface
-from chirp.taxonomy import annotations_fns
 import numpy as np
 
 from absl.testing import absltest
@@ -216,66 +213,6 @@ def test_multihot_labels(self):
     np.testing.assert_equal(multihot, (0, 1, 1, 0, 0, 1))
     np.testing.assert_equal(mask, (0, 1, 1, 1, 0, 1))
 
-  def test_ingest_annotations(self):
-    rng = np.random.default_rng(42)
-    db = test_utils.make_db(
-        self.tempdir,
-        'in_mem',
-        num_embeddings=0,
-        rng=rng,
-    )
-    emb_offsets = [175, 185, 275, 230, 235]
-    emb_idxes = []
-    for offset in emb_offsets:
-      emb_idx = db.insert_embedding(
-          embedding=rng.normal([db.embedding_dimension()]),
-          source=interface.EmbeddingSource(
-              dataset_name='hawaii',
-              source_id='UHH_001_S01_20161121_150000.flac',
-              offsets=np.array([offset]),
-          ),
-      )
-      emb_idxes.append(emb_idx)
-
-    hawaii_annos_path = path_utils.get_absolute_path(
-        'projects/agile2/tests/testdata/hawaii.csv'
-    )
-    ingestor = ingest_annotations.AnnotatedDatasetIngestor(
-        base_path=hawaii_annos_path.parent,
-        dataset_name='hawaii',
-        annotation_filename='hawaii.csv',
-        annotation_load_fn=annotations_fns.load_cornell_annotations,
-        window_size_s=5.0,
-        provenance='test_dataset',
-    )
-    inserted_labels = ingestor.ingest_dataset(db)
-    self.assertSetEqual(inserted_labels, {'jabwar', 'hawama', 'ercfra'})
-    # Check that individual labels are correctly applied.
-    # The Hawai'i test data CSV contains a total of five annotations.
-    # The window at offset 175 should have no labels.
-    self.assertEmpty(db.get_labels(emb_idxes[0]))  # offset 175
-
-    def _check_label(want_label_str, got_label):
-      self.assertEqual(got_label.label, want_label_str)
-      self.assertEqual(got_label.type, interface.LabelType.POSITIVE)
-      self.assertEqual(got_label.provenance, 'test_dataset')
-
-    # There are two jabwar annotations for the window at offset 185.
-    offset_185_labels = db.get_labels(emb_idxes[1])
-    self.assertLen(offset_185_labels, 2)
-    _check_label('jabwar', offset_185_labels[0])
-    _check_label('jabwar', offset_185_labels[1])
-
-    offset_275_labels = db.get_labels(emb_idxes[2])
-    self.assertLen(offset_275_labels, 1)
-    _check_label('hawama', offset_275_labels[0])
-
-    self.assertEmpty(db.get_labels(emb_idxes[3]))  # offset 230
-
-    offset_235_labels = db.get_labels(emb_idxes[4])
-    self.assertLen(offset_235_labels, 1)
-    _check_label('ercfra', offset_235_labels[0])
-
 
 if __name__ == '__main__':
   absltest.main()
diff --git a/chirp/projects/agile2/tests/ingest_annotations_test.py b/chirp/projects/agile2/tests/ingest_annotations_test.py
new file mode 100644
index 00000000..1cef6c3c
--- /dev/null
+++ b/chirp/projects/agile2/tests/ingest_annotations_test.py
@@ -0,0 +1,175 @@
+# coding=utf-8
+# Copyright 2024 The Perch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for annotation ingestion."""
+
+import os
+import shutil
+import tempfile
+
+from chirp import path_utils
+from chirp.projects.agile2 import embed
+from chirp.projects.agile2 import ingest_annotations
+from chirp.projects.agile2.tests import test_utils
+from chirp.projects.hoplite import db_loader
+from chirp.projects.hoplite import interface
+from chirp.taxonomy import annotations
+from chirp.taxonomy import annotations_fns
+from etils import epath
+from ml_collections import config_dict
+import numpy as np
+
+from absl.testing import absltest
+
+
+class IngestAnnotationsTest(absltest.TestCase):
+
+  def setUp(self):
+    super().setUp()
+    # `self.create_tempdir()` raises an UnparsedFlagAccessError, which is why
+    # we use `tempdir` directly.
+    self.tempdir = tempfile.mkdtemp()
+
+  def tearDown(self):
+    super().tearDown()
+    shutil.rmtree(self.tempdir)
+
+  def make_annotated_data(self):
+    sites = ['site_1', 'site_2', 'site_3']
+    filenames = ['foo', 'bar', 'baz']
+    test_utils.make_wav_files(self.tempdir, sites, filenames, file_len_s=60.0)
+    classes = ['x', 'y', 'z']
+
+    # Make a collection of random annotations.
+    annos = []
+    for i, s in enumerate(sites):
+      for j, f in enumerate(filenames):
+        for k, c in enumerate(classes):
+          annos.append(
+              annotations.TimeWindowAnnotation(
+                  filename=f'{s}/{f}_{s}.wav',
+                  start_time_s=9 * i + 3 * j + k,
+                  end_time_s=9 * i + 3 * j + k + 1,
+                  namespace='test',
+                  label=[c],
+              )
+          )
+
+    annos_path = os.path.join(self.tempdir, 'annos.csv')
+    annotations.write_annotations_csv(annos_path, annos)
+    return annos_path, annos
+
+  def test_embed_and_ingest_annotations(self):
+    rng = np.random.default_rng(42)
+    db = test_utils.make_db(
+        self.tempdir,
+        'in_mem',
+        num_embeddings=0,
+        rng=rng,
+    )
+    emb_offsets = [175, 185, 275, 230, 235]
+    emb_idxes = []
+    for offset in emb_offsets:
+      emb_idx = db.insert_embedding(
+          embedding=rng.normal([db.embedding_dimension()]),
+          source=interface.EmbeddingSource(
+              dataset_name='hawaii',
+              source_id='UHH_001_S01_20161121_150000.flac',
+              offsets=np.array([offset]),
+          ),
+      )
+      emb_idxes.append(emb_idx)
+
+    hawaii_annos_path = path_utils.get_absolute_path(
+        'projects/agile2/tests/testdata/hawaii.csv'
+    )
+    ingestor = ingest_annotations.AnnotatedDatasetIngestor(
+        base_path=hawaii_annos_path.parent,
+        audio_glob='*/*.flac',
+        dataset_name='hawaii',
+        annotation_filename='hawaii.csv',
+        annotation_load_fn=annotations_fns.load_cornell_annotations,
+    )
+    inserted_labels = ingestor.ingest_dataset(
+        db, window_size_s=5.0, provenance='test_dataset'
+    )
+    self.assertSetEqual(
+        set(inserted_labels.keys()), {'jabwar', 'hawama', 'ercfra'}
+    )
+    # Check that individual labels are correctly applied.
+    # The Hawai'i test data CSV contains a total of five annotations.
+    # The window at offset 175 should have no labels.
+    self.assertEmpty(db.get_labels(emb_idxes[0]))  # offset 175
+
+    def _check_label(want_label_str, got_label):
+      self.assertEqual(got_label.label, want_label_str)
+      self.assertEqual(got_label.type, interface.LabelType.POSITIVE)
+      self.assertEqual(got_label.provenance, 'test_dataset')
+
+    # There are two jabwar annotations for the window at offset 185.
+    offset_185_labels = db.get_labels(emb_idxes[1])
+    self.assertLen(offset_185_labels, 2)
+    _check_label('jabwar', offset_185_labels[0])
+    _check_label('jabwar', offset_185_labels[1])
+
+    offset_275_labels = db.get_labels(emb_idxes[2])
+    self.assertLen(offset_275_labels, 1)
+    _check_label('hawama', offset_275_labels[0])
+
+    self.assertEmpty(db.get_labels(emb_idxes[3]))  # offset 230
+
+    offset_235_labels = db.get_labels(emb_idxes[4])
+    self.assertLen(offset_235_labels, 1)
+    _check_label('ercfra', offset_235_labels[0])
+
+  def test_ingest_annotations(self):
+    annos_path, annos = self.make_annotated_data()
+    self.assertLen(annos, 27)
+
+    def _loader_fn(x):
+      annos = annotations.read_annotations_csv(x, namespace='somedata')
+      return annotations.annotations_to_dataframe(annos)
+
+    ingestor = ingest_annotations.AnnotatedDatasetIngestor(
+        base_path=epath.Path(self.tempdir),
+        audio_glob='*/*.wav',
+        dataset_name='test',
+        annotation_filename=annos_path,
+        annotation_load_fn=_loader_fn,
+    )
+    placeholder_model_config = config_dict.ConfigDict()
+    placeholder_model_config.embedding_size = 32
+    placeholder_model_config.sample_rate = 16000
+    model_config = embed.ModelConfig(
+        model_key='placeholder_model',
+        embedding_dim=32,
+        model_config=placeholder_model_config,
+    )
+    db, ingestor_class_counts = ingest_annotations.embed_annotated_dataset(
+        ds_choice=ingestor,
+        db_path=os.path.join(self.tempdir),
+        db_model_config=model_config,
+    )
+    self.assertEqual(db.count_embeddings(), 60 * 9)
+
+    ingestor.ingest_dataset(db, window_size_s=1.0)
+    class_counts = db.get_class_counts()
+    for lbl in ('x', 'y', 'z'):
+      self.assertEqual(class_counts[lbl], 9)
+      self.assertEqual(ingestor_class_counts[lbl], 9)
+
+
+if __name__ == '__main__':
+  absltest.main()
diff --git a/chirp/taxonomy/annotations.py b/chirp/taxonomy/annotations.py
index 39d98d26..d7b198f9 100644
--- a/chirp/taxonomy/annotations.py
+++ b/chirp/taxonomy/annotations.py
@@ -50,7 +50,12 @@ def annotations_to_dataframe(
   )
 
 
-def write_annotations_csv(filepath, annotations):
+def write_annotations_csv(
+    filepath: str | epath.Path,
+    annotations: Sequence[TimeWindowAnnotation],
+    label_separator: str = ' ',
+) -> None:
+  """Write annotations to a CSV file."""
   fieldnames = [f.name for f in dataclasses.fields(TimeWindowAnnotation)]
   fieldnames.remove('namespace')
   with epath.Path(filepath).open('w') as f:
@@ -58,10 +63,30 @@ def write_annotations_csv(filepath, annotations):
     dr.writeheader()
     for anno in annotations:
       anno_dict = {f: getattr(anno, f) for f in fieldnames}
-      anno_dict['label'] = ' '.join(anno_dict['label'])
+      anno_dict['label'] = label_separator.join(anno_dict['label'])
       dr.writerow(anno_dict)
 
 
+def read_annotations_csv(
+    annotations_filepath: epath.Path, namespace: str, label_separator: str = ' '
+) -> Sequence[TimeWindowAnnotation]:
+  """Read annotations as written by write_annotations_csv."""
+  got_annotations = []
+  with epath.Path(annotations_filepath).open('r') as f:
+    dr = csv.DictReader(f)
+    for row in dr:
+      got_annotations.append(
+          TimeWindowAnnotation(
+              filename=row['filename'],
+              namespace=namespace,
+              start_time_s=float(row['start_time_s']),
+              end_time_s=float(row['end_time_s']),
+              label=row['label'].split(label_separator),
+          )
+      )
+  return got_annotations
+
+
 def read_dataset_annotations_csvs(
     filepaths: Sequence[epath.Path],
     filename_fn: Callable[[epath.Path, dict[str, str]], str],
diff --git a/chirp/taxonomy/annotations_fns.py b/chirp/taxonomy/annotations_fns.py
index 9b51322d..e1350514 100644
--- a/chirp/taxonomy/annotations_fns.py
+++ b/chirp/taxonomy/annotations_fns.py
@@ -137,10 +137,12 @@ def load_weldy_annotations(annotations_path: epath.Path) -> pd.DataFrame:
   return segments
 
 
-def load_anuraset_annotations(annotations_path: epath.Path) -> pd.DataFrame:
+def load_anuraset_annotations(
+    annotations_path: epath.Path, prefix: str = 'raw_data'
+) -> pd.DataFrame:
   """Loads raw audio annotations from https://zenodo.org/records/8342596."""
   filename_fn = lambda _, row: os.path.join(  # pylint: disable=g-long-lambda
-      row['filename'].split('_')[0], row['filename'].strip()
+      prefix, row['filename'].split('_')[0], row['filename'].strip()
   )
   start_time_fn = lambda row: float(row['start_time_s'])
   end_time_fn = lambda row: float(row['end_time_s'])
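
A minimal usage sketch for the new convenience method, not part of the commit itself. It drives embed_annotated_dataset with the 'hawaii' preset and reuses the 'placeholder_model' configuration from ingest_annotations_test.py above; the output directory is illustrative, and a real run would substitute a production embedding model and needs access to the preset's GCS bucket.

    from chirp.projects.agile2 import embed
    from chirp.projects.agile2 import ingest_annotations
    from ml_collections import config_dict

    # Model settings copied from the test above; a real run would use a
    # production embedding model rather than 'placeholder_model'.
    placeholder_config = config_dict.ConfigDict()
    placeholder_config.embedding_size = 32
    placeholder_config.sample_rate = 16000
    model_config = embed.ModelConfig(
        model_key='placeholder_model',
        embedding_dim=32,
        model_config=placeholder_config,
    )

    # Embed the preset's audio into a SQLite Hoplite DB and ingest the
    # accompanying CSV annotations as positive labels.
    db, class_counts = ingest_annotations.embed_annotated_dataset(
        ds_choice='hawaii',  # Any key of ingest_annotations.PRESETS works.
        db_path='/tmp/hawaii_db',  # Illustrative output directory.
        db_model_config=model_config,
    )
    print(f'Ingested {sum(class_counts.values())} labels.')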
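
A second sketch, for the new CSV round-trip helpers in chirp/taxonomy/annotations.py; the filename and labels here are invented for illustration.

    from chirp.taxonomy import annotations

    annos = [
        annotations.TimeWindowAnnotation(
            filename='site_1/foo_site_1.wav',  # Hypothetical file id.
            namespace='test',
            start_time_s=0.0,
            end_time_s=1.0,
            label=['x', 'y'],  # Joined with label_separator on write.
        )
    ]
    annotations.write_annotations_csv('/tmp/annos.csv', annos)
    # The namespace is not stored in the CSV, so the reader takes it as an
    # argument and applies it to every annotation it returns.
    got = annotations.read_annotations_csv('/tmp/annos.csv', namespace='test')
    assert got[0].label == ['x', 'y']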