diff --git a/chirp/inference/classify/classify.py b/chirp/inference/classify/classify.py
index 0e1a1103..88edcd79 100644
--- a/chirp/inference/classify/classify.py
+++ b/chirp/inference/classify/classify.py
@@ -25,6 +25,9 @@
 import numpy as np
 import tensorflow as tf
 import tqdm
+import pandas as pd
+import os
+from etils import epath
 
 
 @dataclasses.dataclass
@@ -180,45 +183,87 @@ def classify_batch(batch):
   )
   return inference_ds
 
+def flush_inference_rows(
+    output_path: epath.Path,
+    shard_num: int,
+    rows: list[dict[str, str]],
+    format: str,
+    headers: list[str],
+):
+  """Helper method to write rows to disk."""
+  if format == 'csv':
+    if shard_num == 0:
+      with output_path.open('w') as f:
+        f.write(','.join(headers) + '\n')
+    with output_path.open('a') as f:
+      for row in rows:
+        csv_row = [
+            '{:.2f}'.format(row.get(h, '')) if isinstance(row.get(h, ''), np.float32) else row.get(h, '')
+            for h in row
+        ]
+        f.write(','.join(csv_row) + '\n')
+  elif format == 'parquet':
+    output_path.mkdir(parents=True, exist_ok=True)
+    parquet_path = output_path / f'part.{shard_num}.parquet'
+    pd.DataFrame(rows).to_parquet(parquet_path)
+  else:
+    raise ValueError('Output format must be either csv or parquet')
 
 
-def write_inference_csv(
+
+def write_inference_file(
     embeddings_ds: tf.data.Dataset,
     model: interface.LogitsOutputHead,
     labels: Sequence[str],
-    output_filepath: str,
+    output_filepath: epath.PathLike,
     embedding_hop_size_s: float,
     threshold: dict[str, float] | None = None,
     exclude_classes: Sequence[str] = ('unknown',),
     include_classes: Sequence[str] = (),
+    shard_size: int = 1_000_000,
 ):
-  """Write a CSV file of inference results."""
+  """Write inference results."""
+  output_filepath = epath.Path(output_filepath)
+
+  if str(output_filepath).endswith('.csv'):
+    format = 'csv'
+  elif str(output_filepath).endswith('.parquet'):
+    format = 'parquet'
+  else:
+    raise ValueError('Output file must end with either .csv or .parquet')
+
+  shard_num = 0
+  rows = []
+
   inference_ds = get_inference_dataset(embeddings_ds, model)
   detection_count = 0
   nondetection_count = 0
-  with open(output_filepath, 'w') as f:
-    # Write column headers.
-    headers = ['filename', 'timestamp_s', 'label', 'logit']
-    f.write(', '.join(headers) + '\n')
-    for ex in tqdm.tqdm(inference_ds.as_numpy_iterator()):
-      for t in range(ex['logits'].shape[0]):
-        for i, label in enumerate(labels):
-          if label in exclude_classes:
-            continue
-          if include_classes and label not in include_classes:
-            continue
-          if threshold is None or ex['logits'][t, i] > threshold[label]:
-            offset = ex['timestamp_s'] + t * embedding_hop_size_s
-            logit = '{:.2f}'.format(ex['logits'][t, i])
-            row = [
-                ex['filename'].decode('utf-8'),
-                '{:.2f}'.format(offset),
-                label,
-                logit,
-            ]
-            f.write(','.join(row) + '\n')
-            detection_count += 1
-          else:
-            nondetection_count += 1
+  headers = ['filename', 'timestamp_s', 'label', 'logit']
+  for ex in tqdm.tqdm(inference_ds.as_numpy_iterator()):
+    for t in range(ex['logits'].shape[0]):
+      for i, label in enumerate(labels):
+        if label in exclude_classes:
+          continue
+        if include_classes and label not in include_classes:
+          continue
+        if threshold is None or ex['logits'][t, i] > threshold[label]:
+          offset = ex['timestamp_s'] + t * embedding_hop_size_s
+          logit = ex['logits'][t, i]
+          row = {
+              headers[0]: ex["filename"].decode("utf-8"),
+              headers[1]: np.float32(offset),
+              headers[2]: label,
+              headers[3]: np.float32(logit),
+          }
+          rows.append(row)
+          if len(rows) >= shard_size:
+            flush_inference_rows(output_filepath, shard_num, rows, format, headers)
+            rows = []
+            shard_num += 1
+          detection_count += 1
+        else:
+          nondetection_count += 1
+  # write remaining rows
+  flush_inference_rows(output_filepath, shard_num, rows, format, headers)
   print('\n\n\n Detection count: ', detection_count)
-  print('NonDetection count: ', nondetection_count)
+  print('NonDetection count: ', nondetection_count)
\ No newline at end of file
diff --git a/chirp/inference/tests/classify_test.py b/chirp/inference/tests/classify_test.py
index 3475442c..db5f5d1a 100644
--- a/chirp/inference/tests/classify_test.py
+++ b/chirp/inference/tests/classify_test.py
@@ -15,9 +15,12 @@
 
 """Test small-model classification."""
 
+import os
 import tempfile
 
-from chirp.inference import interface
+import pandas as pd
+
+from chirp.inference import interface, tf_examples
 from chirp.inference.classify import classify
 from chirp.inference.classify import data_lib
 from chirp.taxonomy import namespace
@@ -25,6 +28,7 @@
 
 from absl.testing import absltest
 from absl.testing import parameterized
+import shutil
 
 
 class ClassifyTest(parameterized.TestCase):
@@ -98,7 +102,89 @@ def test_train_linear_model(self):
     restored_logits = restored_model(query)
     error = np.abs(restored_logits - logits).sum()
     self.assertEqual(error, 0)
+
+  def write_random_embeddings(self, embedding_dim, filenames, tempdir):
+    """Write random embeddings to a temporary directory."""
+    rng = np.random.default_rng(42)
+    with tf_examples.EmbeddingsTFRecordMultiWriter(
+        output_dir=tempdir, num_files=1
+    ) as file_writer:
+      for filename in filenames:
+        embedding = rng.normal(size=(1, 1, embedding_dim)).astype(np.float32)
+        model_outputs = interface.InferenceOutputs(embedding)
+        example = tf_examples.model_outputs_to_tf_example(
+            model_outputs=model_outputs,
+            file_id=filename,
+            audio=np.array([]),
+            timestamp_offset_s=0,
+            write_raw_audio=False,
+            write_separated_audio=False,
+            write_embeddings=True,
+            write_logits=False,
+        )
+        file_writer.write(example.SerializeToString())
+      file_writer.flush()
+  def test_write_inference_file(self):
+    """Test writing inference files."""
+    tempdir = tempfile.mkdtemp()
+
+    # copy from test_train_linear_model to get the model
+    embedding_dim = 128
+    num_classes = 4
+    model = classify.get_linear_model(embedding_dim, num_classes)
+
+    classes = ['a', 'b', 'c', 'd']
+    logits_model = interface.LogitsOutputHead(
+        model_path=os.path.join(tempdir, 'model'),
+        logits_key='some_model',
+        logits_model=model,
+        class_list=namespace.ClassList('classes', classes),
+    )
+
+    # make a fake embeddings dataset
+    filenames = [f'file_{i}' for i in range(101)]
+
+    self.write_random_embeddings(embedding_dim, filenames, tempdir)
+
+    embeddings_ds = tf_examples.create_embeddings_dataset(embeddings_dir=tempdir)
+
+    parquet_path = os.path.join(tempdir, 'output.parquet')
+    csv_path = os.path.join(tempdir, 'output.csv')
+
+    classify.write_inference_file(
+        embeddings_ds=embeddings_ds,
+        model=logits_model,
+        labels=classes,
+        output_filepath=parquet_path,
+        embedding_hop_size_s=5.0,
+        shard_size=10,
+    )
+
+    classify.write_inference_file(
+        embeddings_ds=embeddings_ds,
+        model=logits_model,
+        labels=classes,
+        output_filepath=csv_path,
+        embedding_hop_size_s=5.0,
+        shard_size=10,
+    )
+
+    parquet = pd.read_parquet(parquet_path)
+    parquet['filename_i'] = parquet['filename'].str.split('_').str[1].astype(int)
+    parquet = parquet.sort_values(by=['filename_i', 'timestamp_s']).reset_index(drop=True)
+
+    csv = pd.read_csv(csv_path)
+    csv['filename_i'] = csv['filename'].str.split('_').str[1].astype(int)
+    csv = csv.sort_values(by=['filename_i', 'timestamp_s']).reset_index(drop=True)
+
+    n_expected_rows = len(filenames) * len(classes)
+    self.assertTrue(np.allclose(parquet['logit'], csv['logit'], atol=1e-2))
+    self.assertEqual(len(parquet), n_expected_rows)
+    self.assertEqual(len(csv), n_expected_rows)
+
+    shutil.rmtree(tempdir)
+
 
 
 if __name__ == '__main__':
   absltest.main()
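
For reviewers, a minimal usage sketch of the new `write_inference_file` entry point, mirroring the setup in the new test rather than prescribing anything beyond this change. The scratch directory `workdir`, the label list, and the untrained linear classifier are illustrative placeholders; the calls themselves (`get_linear_model`, `LogitsOutputHead`, `create_embeddings_dataset`, `write_inference_file`) are taken from the diff above.

```python
import os

import pandas as pd

from chirp.inference import interface
from chirp.inference import tf_examples
from chirp.inference.classify import classify
from chirp.taxonomy import namespace

workdir = '/tmp/inference_demo'  # hypothetical dir already holding embedding TFRecords
embedding_dim = 128
classes = ['a', 'b', 'c', 'd']

# Wrap a linear classifier (untrained here, as in the new test) as a LogitsOutputHead.
logits_model = interface.LogitsOutputHead(
    model_path=os.path.join(workdir, 'model'),
    logits_key='some_model',
    logits_model=classify.get_linear_model(embedding_dim, len(classes)),
    class_list=namespace.ClassList('classes', classes),
)

embeddings_ds = tf_examples.create_embeddings_dataset(embeddings_dir=workdir)

# A '.parquet' suffix selects sharded parquet output (part.0.parquet, part.1.parquet, ...);
# a '.csv' suffix writes a single CSV, with the header emitted on the first shard only.
classify.write_inference_file(
    embeddings_ds=embeddings_ds,
    model=logits_model,
    labels=classes,
    output_filepath=os.path.join(workdir, 'output.parquet'),
    embedding_hop_size_s=5.0,
    shard_size=1_000_000,
)

# pandas reads the directory of parquet shards back as a single dataframe.
df = pd.read_parquet(os.path.join(workdir, 'output.parquet'))
```

Flushing every `shard_size` rows keeps memory bounded on long runs: CSV shards are appended to one file, while parquet output gets one `part.N.parquet` per flush.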