Skip to content

Commit

Permalink
Add presets for fully-annotated datasets and convenience method for c…
Browse files Browse the repository at this point in the history
…reating an annotated embeddings database.

PiperOrigin-RevId: 680614976
  • Loading branch information
sdenton4 authored and copybara-github committed Sep 30, 2024
1 parent acad3eb commit 10bf156
Show file tree
Hide file tree
Showing 6 changed files with 371 additions and 82 deletions.
174 changes: 161 additions & 13 deletions chirp/projects/agile2/ingest_annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,37 +13,63 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""Ingest fully-annotated dataset labels from CSV."""
"""Ingest fully-annotated dataset audio and labels."""

import collections
import dataclasses
from typing import Callable, Sequence
from typing import Callable

from chirp.projects.agile2 import embed
from chirp.projects.hoplite import db_loader
from chirp.projects.hoplite import interface
from chirp.taxonomy import annotations
from chirp.taxonomy import annotations_fns
from etils import epath
from ml_collections import config_dict
import pandas as pd
import tqdm

BASE_PATH = epath.Path('gs://chirp-public-bucket/soundscapes/')


@dataclasses.dataclass
class AnnotatedDatasetIngestor:
"""Add annotations to embeddings DB from CSV annotations.
Note that currently we only add positive labels.
Attributes:
base_path: Base path for the dataset.
audio_glob: Glob for the audio files.
dataset_name: Name of the dataset.
annotation_filename: Filename for the annotations CSV.
annotation_load_fn: Function to load the annotations CSV.
"""

base_path: str
base_path: epath.Path
audio_glob: str
dataset_name: str
annotation_filename: str
annotation_load_fn: Callable[[str | epath.Path], pd.DataFrame]
window_size_s: float
provenance: str = 'dataset'

def ingest_dataset(self, db: interface.GraphSearchDBInterface) -> set[str]:
"""Load annotations and insert labels into the DB."""
def ingest_dataset(
self,
db: interface.GraphSearchDBInterface,
window_size_s: float,
provenance: str = 'annotations_csv',
) -> collections.defaultdict[str, int]:
"""Load annotations and insert labels into the DB.
Args:
db: The DB to insert labels into.
window_size_s: The window size of the embeddings.
provenance: The provenance to use for the labels.
Returns:
A dictionary of ingested label counts.
"""
annos_path = epath.Path(self.base_path) / self.annotation_filename
annos_df = self.annotation_load_fn(annos_path)
lbl_count = 0
lbl_counts = collections.defaultdict(int)
file_ids = annos_df['filename'].unique()
label_set = set()
for file_id in tqdm.tqdm(file_ids):
Expand All @@ -52,17 +78,139 @@ def ingest_dataset(self, db: interface.GraphSearchDBInterface) -> set[str]:
for idx in embedding_ids:
source = db.get_embedding_source(idx)
window_start = source.offsets[0]
window_end = window_start + self.window_size_s
window_end = window_start + window_size_s
emb_annos = source_annos[source_annos['start_time_s'] < window_end]
emb_annos = emb_annos[emb_annos['end_time_s'] > window_start]
# All of the remianing annotations match the target embedding.
for labels in emb_annos['label']:
for label in labels:
label_set.add(label)
lbl = interface.Label(
idx, label, interface.LabelType.POSITIVE, self.provenance
idx,
label,
interface.LabelType.POSITIVE,
provenance=provenance,
)
db.insert_label(lbl)
lbl_count += 1
lbl_counts[label] += 1
lbl_count = sum(lbl_counts.values())
print(f'\nInserted {lbl_count} labels.')
return label_set
return lbl_counts


CORNELL_LOADER = lambda x: annotations_fns.load_cornell_annotations(
x, file_id_prefix='audio/'
)

PRESETS: dict[str, AnnotatedDatasetIngestor] = {
'powdermill': AnnotatedDatasetIngestor(
base_path=BASE_PATH / 'powdermill',
audio_glob='*/*.wav',
dataset_name='powdermill',
annotation_filename='powdermill.csv',
annotation_load_fn=annotations_fns.load_powdermill_annotations,
),
'hawaii': AnnotatedDatasetIngestor(
base_path=BASE_PATH / 'hawaii',
dataset_name='hawaii',
audio_glob='audio/*.flac',
annotation_filename='annotations.csv',
annotation_load_fn=CORNELL_LOADER,
),
'high_sierras': AnnotatedDatasetIngestor(
base_path=BASE_PATH / 'high_sierras',
dataset_name='high_sierras',
audio_glob='audio/*.flac',
annotation_filename='annotations.csv',
annotation_load_fn=CORNELL_LOADER,
),
'coffee_farms': AnnotatedDatasetIngestor(
base_path=BASE_PATH / 'coffee_farms',
dataset_name='coffee_farms',
audio_glob='audio/*.flac',
annotation_filename='annotations.csv',
annotation_load_fn=CORNELL_LOADER,
),
'peru': AnnotatedDatasetIngestor(
base_path=BASE_PATH / 'peru',
dataset_name='peru',
audio_glob='audio/*.flac',
annotation_filename='annotations.csv',
annotation_load_fn=CORNELL_LOADER,
),
'ssw': AnnotatedDatasetIngestor(
base_path=BASE_PATH / 'ssw',
dataset_name='ssw',
audio_glob='audio/*.flac',
annotation_filename='annotations.csv',
annotation_load_fn=CORNELL_LOADER,
),
'sierras_kahl': AnnotatedDatasetIngestor(
base_path=BASE_PATH / 'sierras_kahl',
dataset_name='sierras_kahl',
audio_glob='audio/*.flac',
annotation_filename='annotations.csv',
annotation_load_fn=CORNELL_LOADER,
),
'anuraset': AnnotatedDatasetIngestor(
base_path=BASE_PATH / 'anuraset',
dataset_name='anuraset',
audio_glob='raw_data/*/*.wav',
annotation_filename='annotations.csv',
annotation_load_fn=annotations_fns.load_anuraset_annotations,
),
}


def embed_annotated_dataset(
ds_choice: str | AnnotatedDatasetIngestor,
db_path: str,
db_model_config: embed.ModelConfig,
) -> tuple[interface.GraphSearchDBInterface, dict[str, int]]:
"""Embed a fully-annotated dataset to SQLite Hoplite DB.
Args:
ds_choice: The preset name of the dataset to embed. Alternatively, an
AnnotatedDatasetIngestor can be provided.
db_path: The path to the DB.
db_model_config: The model config for the DB.
Returns:
The DB and a dictionary of label counts.
"""
if isinstance(ds_choice, str):
ingestor = PRESETS[ds_choice]
else:
ingestor = ds_choice

db_filepath = f'{db_path}/hoplite_db.sqlite'
epath.Path(db_filepath).parent.mkdir(parents=True, exist_ok=True)
db_config = config_dict.ConfigDict({
'db_path': db_filepath,
'embedding_dim': db_model_config.embedding_dim,
})
db_config = db_loader.DBConfig('sqlite', db_config)
print(ingestor)
audio_srcs_config = embed.EmbedConfig(
audio_globs={
ingestor.dataset_name: (
ingestor.base_path.as_posix(),
ingestor.audio_glob,
)
},
min_audio_len_s=1.0,
)
db = db_config.load_db()
db.setup()
print('Initialized DB located at ', db_filepath)
worker = embed.EmbedWorker(
embed_config=audio_srcs_config, db=db, model_config=db_model_config
)
worker.process_all()
print(f'DB contains {db.count_embeddings()} embeddings.')

class_counts = ingestor.ingest_dataset(
db, window_size_s=worker.embedding_model.window_size_s
)
db.commit()
return db, class_counts
6 changes: 4 additions & 2 deletions chirp/projects/agile2/source_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,11 @@ def iterate_all_sources(
shard_len_s (ie, the final shard).
"""
for dataset_name, (root_dir, file_glob) in self.audio_globs.items():
filepaths = tuple(epath.Path(root_dir).glob(file_glob))
# If root_dir is a URL, the posix path may not match the original string.
base_path = epath.Path(root_dir)
filepaths = tuple(base_path.glob(file_glob))
for filepath in tqdm.tqdm(filepaths):
file_id = filepath.as_posix()[len(root_dir) + 1 :]
file_id = filepath.as_posix()[len(base_path.as_posix()) + 1 :]
if shard_len_s < 0:
yield SourceId(
dataset_name=dataset_name,
Expand Down
63 changes: 0 additions & 63 deletions chirp/projects/agile2/tests/classifier_data_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,9 @@
import shutil
import tempfile

from chirp import path_utils
from chirp.projects.agile2 import classifier_data
from chirp.projects.agile2 import ingest_annotations
from chirp.projects.agile2.tests import test_utils
from chirp.projects.hoplite import interface
from chirp.taxonomy import annotations_fns
import numpy as np

from absl.testing import absltest
Expand Down Expand Up @@ -216,66 +213,6 @@ def test_multihot_labels(self):
np.testing.assert_equal(multihot, (0, 1, 1, 0, 0, 1))
np.testing.assert_equal(mask, (0, 1, 1, 1, 0, 1))

def test_ingest_annotations(self):
rng = np.random.default_rng(42)
db = test_utils.make_db(
self.tempdir,
'in_mem',
num_embeddings=0,
rng=rng,
)
emb_offsets = [175, 185, 275, 230, 235]
emb_idxes = []
for offset in emb_offsets:
emb_idx = db.insert_embedding(
embedding=rng.normal([db.embedding_dimension()]),
source=interface.EmbeddingSource(
dataset_name='hawaii',
source_id='UHH_001_S01_20161121_150000.flac',
offsets=np.array([offset]),
),
)
emb_idxes.append(emb_idx)

hawaii_annos_path = path_utils.get_absolute_path(
'projects/agile2/tests/testdata/hawaii.csv'
)
ingestor = ingest_annotations.AnnotatedDatasetIngestor(
base_path=hawaii_annos_path.parent,
dataset_name='hawaii',
annotation_filename='hawaii.csv',
annotation_load_fn=annotations_fns.load_cornell_annotations,
window_size_s=5.0,
provenance='test_dataset',
)
inserted_labels = ingestor.ingest_dataset(db)
self.assertSetEqual(inserted_labels, {'jabwar', 'hawama', 'ercfra'})
# Check that individual labels are correctly applied.
# The Hawai'i test data CSV contains a total of five annotations.
# The window at offset 175 should have no labels.
self.assertEmpty(db.get_labels(emb_idxes[0])) # offset 175

def _check_label(want_label_str, got_label):
self.assertEqual(got_label.label, want_label_str)
self.assertEqual(got_label.type, interface.LabelType.POSITIVE)
self.assertEqual(got_label.provenance, 'test_dataset')

# There are two jabwar annotations for the window at offset 185.
offset_185_labels = db.get_labels(emb_idxes[1])
self.assertLen(offset_185_labels, 2)
_check_label('jabwar', offset_185_labels[0])
_check_label('jabwar', offset_185_labels[1])

offset_275_labels = db.get_labels(emb_idxes[2])
self.assertLen(offset_275_labels, 1)
_check_label('hawama', offset_275_labels[0])

self.assertEmpty(db.get_labels(emb_idxes[3])) # offset 230

offset_235_labels = db.get_labels(emb_idxes[4])
self.assertLen(offset_235_labels, 1)
_check_label('ercfra', offset_235_labels[0])


if __name__ == '__main__':
absltest.main()
Loading

0 comments on commit 10bf156

Please sign in to comment.