From 1d2107202965244c502e4cb79c28d036e4425a37 Mon Sep 17 00:00:00 2001
From: Tom Denton
Date: Thu, 8 Aug 2024 18:29:14 -0700
Subject: [PATCH] Tool for importing TFRecord embedding datasets into new
 database format.

PiperOrigin-RevId: 661060569
---
 chirp/inference/classify/classify.py          | 101 ++++---------
 chirp/inference/tests/classify_test.py        |  88 +----------
 chirp/projects/agile2/convert_legacy.py       | 108 ++++++++++++++
 .../agile2/tests/convert_legacy_test.py       | 141 ++++++++++++++++++
 4 files changed, 278 insertions(+), 160 deletions(-)
 create mode 100644 chirp/projects/agile2/convert_legacy.py
 create mode 100644 chirp/projects/agile2/tests/convert_legacy_test.py

diff --git a/chirp/inference/classify/classify.py b/chirp/inference/classify/classify.py
index 88edcd79..0e1a1103 100644
--- a/chirp/inference/classify/classify.py
+++ b/chirp/inference/classify/classify.py
@@ -25,9 +25,6 @@
 import numpy as np
 import tensorflow as tf
 import tqdm
-import pandas as pd
-import os
-from etils import epath
 
 
 @dataclasses.dataclass
@@ -183,87 +180,45 @@ def classify_batch(batch):
   )
   return inference_ds
 
-def flush_inference_rows(
-    output_path: epath.Path,
-    shard_num: int,
-    rows: list[dict[str, str]],
-    format: str,
-    headers: list[str],
-):
-  """Helper method to write rows to disk."""
-  if format == 'csv':
-    if shard_num == 0:
-      with output_path.open('w') as f:
-        f.write(','.join(headers) + '\n')
-    with output_path.open('a') as f:
-      for row in rows:
-        csv_row = [
-            '{:.2f}'.format(row.get(h, '')) if isinstance(row.get(h, ''), np.float32) else row.get(h, '')
-            for h in row
-        ]
-        f.write(','.join(csv_row) + '\n')
-  elif format == 'parquet':
-    output_path.mkdir(parents=True, exist_ok=True)
-    parquet_path = output_path / f'part.{shard_num}.parquet'
-    pd.DataFrame(rows).to_parquet(parquet_path)
-  else:
-    raise ValueError('Output format must be either csv or parquet')
-
-def write_inference_file(
+def write_inference_csv(
     embeddings_ds: tf.data.Dataset,
     model: interface.LogitsOutputHead,
     labels: Sequence[str],
-    output_filepath: epath.PathLike,
+    output_filepath: str,
     embedding_hop_size_s: float,
     threshold: dict[str, float] | None = None,
     exclude_classes: Sequence[str] = ('unknown',),
     include_classes: Sequence[str] = (),
-    shard_size: int = 1_000_000,
 ):
-  """Write inference results."""
-  output_filepath = epath.Path(output_filepath)
-
-  if str(output_filepath).endswith('.csv'):
-    format = 'csv'
-  elif str(output_filepath).endswith('.parquet'):
-    format = 'parquet'
-  else:
-    raise ValueError('Output file must end with either .csv or .parquet')
-
-  shard_num = 0
-  rows = []
-
+  """Write a CSV file of inference results."""
   inference_ds = get_inference_dataset(embeddings_ds, model)
 
   detection_count = 0
   nondetection_count = 0
-  headers = ['filename', 'timestamp_s', 'label', 'logit']
-  for ex in tqdm.tqdm(inference_ds.as_numpy_iterator()):
-    for t in range(ex['logits'].shape[0]):
-      for i, label in enumerate(labels):
-        if label in exclude_classes:
-          continue
-        if include_classes and label not in include_classes:
-          continue
-        if threshold is None or ex['logits'][t, i] > threshold[label]:
-          offset = ex['timestamp_s'] + t * embedding_hop_size_s
-          logit = ex['logits'][t, i]
-          row = {
-              headers[0]: ex["filename"].decode("utf-8"),
-              headers[1]: np.float32(offset),
-              headers[2]: label,
-              headers[3]: np.float32(logit),
-          }
-          rows.append(row)
-          if len(rows) >= shard_size:
-            flush_inference_rows(output_filepath, shard_num, rows, format, headers)
-            rows = []
-            shard_num += 1
-          detection_count += 1
-        else:
-          nondetection_count += 1
-  # write remaining rows
-  flush_inference_rows(output_filepath, shard_num, rows, format, headers)
+  with open(output_filepath, 'w') as f:
+    # Write column headers.
+    headers = ['filename', 'timestamp_s', 'label', 'logit']
+    f.write(', '.join(headers) + '\n')
+    for ex in tqdm.tqdm(inference_ds.as_numpy_iterator()):
+      for t in range(ex['logits'].shape[0]):
+        for i, label in enumerate(labels):
+          if label in exclude_classes:
+            continue
+          if include_classes and label not in include_classes:
+            continue
+          if threshold is None or ex['logits'][t, i] > threshold[label]:
+            offset = ex['timestamp_s'] + t * embedding_hop_size_s
+            logit = '{:.2f}'.format(ex['logits'][t, i])
+            row = [
+                ex['filename'].decode('utf-8'),
+                '{:.2f}'.format(offset),
+                label,
+                logit,
+            ]
+            f.write(','.join(row) + '\n')
+            detection_count += 1
+          else:
+            nondetection_count += 1
   print('\n\n\n Detection count: ', detection_count)
-  print('NonDetection count: ', nondetection_count)
\ No newline at end of file
+  print('NonDetection count: ', nondetection_count)
diff --git a/chirp/inference/tests/classify_test.py b/chirp/inference/tests/classify_test.py
index db5f5d1a..3475442c 100644
--- a/chirp/inference/tests/classify_test.py
+++ b/chirp/inference/tests/classify_test.py
@@ -15,12 +15,9 @@
 
 """Test small-model classification."""
 
-import os
 import tempfile
 
-import pandas as pd
-
-from chirp.inference import interface, tf_examples
+from chirp.inference import interface
 from chirp.inference.classify import classify
 from chirp.inference.classify import data_lib
 from chirp.taxonomy import namespace
@@ -28,7 +25,6 @@
 
 from absl.testing import absltest
 from absl.testing import parameterized
-import shutil
 
 
 class ClassifyTest(parameterized.TestCase):
@@ -102,89 +98,7 @@ def test_train_linear_model(self):
     restored_logits = restored_model(query)
     error = np.abs(restored_logits - logits).sum()
     self.assertEqual(error, 0)
-
-  def write_random_embeddings(self, embedding_dim, filenames, tempdir):
-    """Write random embeddings to a temporary directory."""
-    rng = np.random.default_rng(42)
-    with tf_examples.EmbeddingsTFRecordMultiWriter(
-        output_dir=tempdir, num_files=1
-    ) as file_writer:
-      for filename in filenames:
-        embedding = rng.normal(size=(1, 1, embedding_dim)).astype(np.float32)
-        model_outputs = interface.InferenceOutputs(embedding)
-        example = tf_examples.model_outputs_to_tf_example(
-            model_outputs=model_outputs,
-            file_id=filename,
-            audio=np.array([]),
-            timestamp_offset_s=0,
-            write_raw_audio=False,
-            write_separated_audio=False,
-            write_embeddings=True,
-            write_logits=False,
-        )
-        file_writer.write(example.SerializeToString())
-      file_writer.flush()
-
-  def test_write_inference_file(self):
-    """Test writing inference files."""
-    tempdir = tempfile.mkdtemp()
-
-    # copy from test_train_linear_model to get the model
-    embedding_dim = 128
-    num_classes = 4
-    model = classify.get_linear_model(embedding_dim, num_classes)
-
-    classes = ['a', 'b', 'c', 'd']
-    logits_model = interface.LogitsOutputHead(
-        model_path=os.path.join(tempdir, 'model'),
-        logits_key='some_model',
-        logits_model=model,
-        class_list=namespace.ClassList('classes', classes),
-    )
-
-    # make a fake embeddings dataset
-    filenames = [f'file_{i}' for i in range(101)]
-
-    self.write_random_embeddings(embedding_dim, filenames, tempdir)
-
-    embeddings_ds = tf_examples.create_embeddings_dataset(embeddings_dir=tempdir)
-
-    parquet_path = os.path.join(tempdir, 'output.parquet')
-    csv_path = os.path.join(tempdir, 'output.csv')
-
-    classify.write_inference_file(
-        embeddings_ds=embeddings_ds,
-        model=logits_model,
-        labels=classes,
-        output_filepath=parquet_path,
-        embedding_hop_size_s=5.0,
-        shard_size=10,
-    )
-
-    classify.write_inference_file(
-        embeddings_ds=embeddings_ds,
-        model=logits_model,
-        labels=classes,
-        output_filepath=csv_path,
-        embedding_hop_size_s=5.0,
-        shard_size=10,
-    )
-
-    parquet = pd.read_parquet(parquet_path)
-    parquet['filename_i'] = parquet['filename'].str.split('_').str[1].astype(int)
-    parquet = parquet.sort_values(by=['filename_i', 'timestamp_s']).reset_index(drop=True)
-
-    csv = pd.read_csv(csv_path)
-    csv['filename_i'] = csv['filename'].str.split('_').str[1].astype(int)
-    csv = csv.sort_values(by=['filename_i', 'timestamp_s']).reset_index(drop=True)
-
-    n_expected_rows = len(filenames) * len(classes)
-    self.assertTrue(np.allclose(parquet['logit'], csv['logit'], atol=1e-2))
-    self.assertEqual(len(parquet), n_expected_rows)
-    self.assertEqual(len(csv), n_expected_rows)
-
-    shutil.rmtree(tempdir)
-
 
 if __name__ == '__main__':
   absltest.main()
diff --git a/chirp/projects/agile2/convert_legacy.py b/chirp/projects/agile2/convert_legacy.py
new file mode 100644
index 00000000..95d1e086
--- /dev/null
+++ b/chirp/projects/agile2/convert_legacy.py
@@ -0,0 +1,108 @@
+# coding=utf-8
+# Copyright 2024 The Perch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Conversion for TFRecord embeddings to Hoplite DB."""
+
+import os
+from chirp.inference import embed_lib
+from chirp.inference import tf_examples
+from chirp.projects.agile2 import embed
+from chirp.projects.hoplite import in_mem_impl
+from chirp.projects.hoplite import interface
+from chirp.projects.hoplite import sqlite_impl
+from etils import epath
+import numpy as np
+import tqdm
+
+
+def convert_tfrecords(
+    embeddings_path: str,
+    db_type: str,
+    dataset_name: str,
+    max_count: int = -1,
+    **kwargs,
+):
+  """Convert a TFRecord embeddings dataset to a Hoplite DB."""
+  ds = tf_examples.create_embeddings_dataset(
+      embeddings_path,
+      'embeddings-*',
+  )
+  # Peek at one embedding to get the embedding dimension.
+  for ex in ds.as_numpy_iterator():
+    emb_dim = ex['embedding'].shape[-1]
+    break
+  else:
+    raise ValueError('No embeddings found.')
+
+  if db_type == 'sqlite':
+    db_path = kwargs['db_path']
+    if epath.Path(db_path).exists():
+      raise ValueError(f'DB path {db_path} already exists.')
+    db = sqlite_impl.SQLiteGraphSearchDB.create(db_path, embedding_dim=emb_dim)
+  elif db_type == 'in_mem':
+    db = in_mem_impl.InMemoryGraphSearchDB.create(
+        embedding_dim=emb_dim,
+        max_size=kwargs['max_size'],
+        degree_bound=kwargs['degree_bound'],
+    )
+  else:
+    raise ValueError(f'Unknown db type: {db_type}')
+  db.setup()
+
+  # Convert embedding config to new format and insert into the DB.
+  legacy_config = embed_lib.load_embedding_config(embeddings_path)
+  model_config = embed.ModelConfig(
+      model_key=legacy_config.embed_fn_config.model_key,
+      model_config=legacy_config.embed_fn_config.model_config,
+  )
+  file_id_depth = legacy_config.embed_fn_config['file_id_depth']
+  audio_globs = []
+  for glob in legacy_config.source_file_patterns:
+    new_glob = glob.split('/')[-file_id_depth - 1 :]
+    audio_globs.append(new_glob)
+
+  embed_config = embed.EmbedConfig(
+      audio_globs={dataset_name: tuple(audio_globs)},
+      min_audio_len_s=legacy_config.embed_fn_config.min_audio_s,
+      target_sample_rate_hz=legacy_config.embed_fn_config.get(
+          'target_sample_rate_hz', -1
+      ),
+  )
+  db.insert_metadata('legacy_config', legacy_config)
+  db.insert_metadata('embed_config', embed_config.to_config_dict())
+  db.insert_metadata('model_config', model_config.to_config_dict())
+  hop_size_s = model_config.model_config.hop_size_s
+
+  for ex in tqdm.tqdm(ds.as_numpy_iterator()):
+    embs = ex['embedding']
+    print(embs.shape)
+    flat_embeddings = np.reshape(embs, [-1, embs.shape[-1]])
+    file_id = str(ex['filename'], 'utf8')
+    offset_s = ex['timestamp_s']
+    if max_count > 0 and db.count_embeddings() >= max_count:
+      break
+    for i in range(flat_embeddings.shape[0]):
+      embedding = flat_embeddings[i]
+      offset = np.array(offset_s + hop_size_s * i)
+      source = interface.EmbeddingSource(dataset_name, file_id, offset)
+      db.insert_embedding(embedding, source)
+      if max_count > 0 and db.count_embeddings() >= max_count:
+        break
+  db.commit()
+  num_embeddings = db.count_embeddings()
+  print('\n\nTotal embeddings : ', num_embeddings)
+  hours_equiv = num_embeddings / 60 / 60 * hop_size_s
+  print(f'\n\nHours of audio equivalent : {hours_equiv:.2f}')
+  return db
diff --git a/chirp/projects/agile2/tests/convert_legacy_test.py b/chirp/projects/agile2/tests/convert_legacy_test.py
new file mode 100644
index 00000000..6eedf5fa
--- /dev/null
+++ b/chirp/projects/agile2/tests/convert_legacy_test.py
@@ -0,0 +1,141 @@
+# coding=utf-8
+# Copyright 2024 The Perch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for conversion from previous agile modeling format."""
+
+import os
+import shutil
+import tempfile
+
+from chirp import audio_utils
+from chirp.inference import embed_lib
+from chirp.inference import tf_examples
+from chirp.projects.agile2 import convert_legacy
+from chirp.projects.agile2.tests import test_utils
+from etils import epath
+from ml_collections import config_dict
+
+from absl.testing import absltest
+from absl.testing import parameterized
+
+
+class ConvertLegacyTest(parameterized.TestCase):
+
+  def setUp(self):
+    super().setUp()
+    # `self.create_tempdir()` raises an UnparsedFlagAccessError, which is why
+    # we use `tempdir` directly.
+    self.tempdir = tempfile.mkdtemp()
+
+  def tearDown(self):
+    super().tearDown()
+    shutil.rmtree(self.tempdir)
+
+  def write_placeholder_embeddings(self, audio_glob, source_infos, embed_dir):
+    """Utility method for writing embeddings with the placeholder model."""
+    # Set up embedding function.
+    config = config_dict.ConfigDict()
+    model_kwargs = {
+        'sample_rate': 16000,
+        'embedding_size': 128,
+        'make_embeddings': True,
+        'make_logits': False,
+        'make_separated_audio': False,
+        'window_size_s': 5.0,
+        'hop_size_s': 5.0,
+    }
+    embed_fn_config = config_dict.ConfigDict()
+    embed_fn_config.write_embeddings = True
+    embed_fn_config.write_logits = False
+    embed_fn_config.write_separated_audio = False
+    embed_fn_config.write_raw_audio = False
+    embed_fn_config.write_frontend = False
+    embed_fn_config.model_key = 'placeholder_model'
+    embed_fn_config.model_config = model_kwargs
+    embed_fn_config.min_audio_s = 0.1
+    embed_fn_config.file_id_depth = 1
+    config.embed_fn_config = embed_fn_config
+    config.source_file_patterns = [audio_glob]
+    config.num_shards_per_file = -1
+    config.shard_len_s = -1
+
+    epath.Path(embed_dir).mkdir(parents=True, exist_ok=True)
+    embed_lib.maybe_write_config(config, epath.Path(embed_dir))
+
+    embed_fn = embed_lib.EmbedFn(**embed_fn_config)
+    embed_fn.setup()
+
+    # Write embeddings.
+    audio_loader = lambda fp, offset: audio_utils.load_audio_window(
+        fp, offset, sample_rate=16000, window_size_s=120.0
+    )
+    audio_iterator = audio_utils.multi_load_audio_window(
+        filepaths=[s.filepath for s in source_infos],
+        offsets=[0 for s in source_infos],
+        audio_loader=audio_loader,
+    )
+    with tf_examples.EmbeddingsTFRecordMultiWriter(
+        output_dir=embed_dir, num_files=1
+    ) as file_writer:
+      for source_info, audio in zip(source_infos, audio_iterator):
+        file_id = source_info.file_id(1)
+        offset_s = source_info.shard_num * source_info.shard_len_s
+        example = embed_fn.audio_to_example(file_id, offset_s, audio)
+        file_writer.write(example.SerializeToString())
+      file_writer.flush()
+
+  @parameterized.product(
+      db_type=(
+          'in_mem',
+          'sqlite',
+      ),
+  )
+  def test_convert_legacy(self, db_type):
+    classes = ['pos', 'neg']
+    filenames = ['foo', 'bar', 'baz']
+    audio_glob = test_utils.make_wav_files(
+        self.tempdir, classes, filenames, file_len_s=60.0
+    )
+    source_infos = embed_lib.create_source_infos([audio_glob], shard_len_s=-1)
+    self.assertLen(source_infos, len(classes) * len(filenames))
+    embed_dir = os.path.join(self.tempdir, 'embeddings')
+    self.write_placeholder_embeddings(audio_glob, source_infos, embed_dir)
+
+    if db_type == 'sqlite':
+      kwargs = {'db_path': os.path.join(self.tempdir, 'db.sqlite')}
+    elif db_type == 'in_mem':
+      kwargs = {
+          'max_size': 100,
+          'degree_bound': 10,
+      }
+    else:
+      raise ValueError(f'Unknown db type: {db_type}')
+
+    db = convert_legacy.convert_tfrecords(
+        embeddings_path=embed_dir,
+        db_type=db_type,
+        dataset_name='test',
+        **kwargs,
+    )
+    # There are six one-minute test files, so we should get 72 embeddings.
+    self.assertEqual(db.count_embeddings(), 72)
+    metadata = db.get_metadata(key=None)
+    self.assertIn('legacy_config', metadata)
+    self.assertIn('embed_config', metadata)
+    self.assertIn('model_config', metadata)
+
+
+if __name__ == '__main__':
+  absltest.main()
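
Usage sketch (not part of the patch): the converter added above can be driven much as in convert_legacy_test.py. The paths and dataset name below are hypothetical placeholders; `embeddings_path` should point at a directory containing the legacy TFRecord embeddings and the config written by embed_lib.

    from chirp.projects.agile2 import convert_legacy

    # Hypothetical locations; for db_type='sqlite' the db_path must not already
    # exist, and for db_type='in_mem' pass max_size and degree_bound instead.
    db = convert_legacy.convert_tfrecords(
        embeddings_path='/path/to/legacy_embeddings',
        db_type='sqlite',
        dataset_name='my_dataset',
        db_path='/path/to/hoplite_db.sqlite',
    )
    print('Embeddings in DB:', db.count_embeddings())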