diff --git a/chirp/inference/classify/classify.py b/chirp/inference/classify/classify.py
index 86fc5aad..c7b917df 100644
--- a/chirp/inference/classify/classify.py
+++ b/chirp/inference/classify/classify.py
@@ -27,6 +27,7 @@
 import tqdm
 import pandas as pd
 import os
+from etils import epath
 
 
 @dataclasses.dataclass
@@ -182,49 +183,63 @@ def classify_batch(batch):
   )
   return inference_ds
 
 
+def flush_rows(
+    output_path: epath.Path,
+    shard_num: int,
+    rows: list[dict[str, str]],
+    format: str,
+    headers: list[str],
+):
+  """Helper method to write rows to disk."""
+  if format == 'csv':
+    if shard_num == 0:
+      with output_path.open('w') as f:
+        f.write(','.join(headers) + '\n')
+    with output_path.open('a') as f:
+      for row in rows:
+        csv_row = [
+            '{:.2f}'.format(row.get(h, '')) if isinstance(row.get(h, ''), np.float32) else row.get(h, '')
+            for h in row
+        ]
+        f.write(','.join(csv_row) + '\n')
+  elif format == 'parquet':
+    output_path.mkdir(parents=True, exist_ok=True)
+    parquet_path = output_path / f'part.{shard_num}.parquet'
+    pd.DataFrame(rows).to_parquet(parquet_path)
+  else:
+    raise ValueError('Output format must be either csv or parquet')
+
+
 def write_inference_file(
     embeddings_ds: tf.data.Dataset,
     model: interface.LogitsOutputHead,
     labels: Sequence[str],
-    output_filepath: str,
+    output_filepath: epath.PathLike,
     embedding_hop_size_s: float,
     threshold: dict[str, float] | None = None,
     exclude_classes: Sequence[str] = ('unknown',),
     include_classes: Sequence[str] = (),
-    row_size: int = 1_000_000,
-    format: str = 'parquet',
+    shard_size: int = 1_000_000,
 ):
   """Write inference results."""
+  output_filepath = epath.Path(output_filepath)
 
-  if format != 'parquet' and format != 'csv':
-    raise ValueError('Format must be either "parquet" or "csv"')
-
-  if format == 'parquet':
-    if output_filepath.endswith('.csv'):
-      output_filepath = output_filepath[:-4]
-    if not output_filepath.endswith('.parquet'):
-      output_filepath += '.parquet'
-    os.mkdir(output_filepath)
-  if format == 'csv':
-    if output_filepath.endswith('.parquet'):
-      output_filepath = output_filepath[:-8]
-    if not output_filepath.endswith('.csv'):
-      output_filepath += '.csv'
+  if str(output_filepath).endswith('.csv'):
+    format = 'csv'
+  elif str(output_filepath).endswith('.parquet'):
+    format = 'parquet'
+  else:
+    raise ValueError('Output file must end with either .csv or .parquet')
 
-  parquet_count = 0
+  shard_num = 0
   rows = []
 
   inference_ds = get_inference_dataset(embeddings_ds, model)
 
   detection_count = 0
   nondetection_count = 0
-  if format == 'csv':
-    f = open(output_filepath, 'w')
   headers = ['filename', 'timestamp_s', 'label', 'logit']
   # Write column headers if CSV format
-  if format == 'csv':
-    f.write(','.join(headers) + '\n')
   for ex in tqdm.tqdm(inference_ds.as_numpy_iterator()):
     for t in range(ex['logits'].shape[0]):
       for i, label in enumerate(labels):
@@ -235,35 +250,21 @@ def write_inference_file(
         if threshold is None or ex['logits'][t, i] > threshold[label]:
           offset = ex['timestamp_s'] + t * embedding_hop_size_s
           logit = ex['logits'][t, i]
-          if format == 'parquet':
-            row = {
-                headers[0]: ex["filename"].decode("utf-8"),
-                headers[1]: offset,
-                headers[2]: label,
-                headers[3]: logit,
-            }
-            rows.append(row)
-            if len(rows) >= row_size:
-              tmp_df = pd.DataFrame(rows)
-              tmp_df.to_parquet(f'{output_filepath}/part.{parquet_count}.parquet')
-              parquet_count += 1
-              rows = []
-          elif format == 'csv':
-            row = [
-                ex['filename'].decode('utf-8'),
-                '{:.2f}'.format(offset),
-                label,
-                '{:.2f}'.format(logit),
-            ]
-            f.write(','.join(row) + '\n')
+          row = {
+              headers[0]: ex["filename"].decode("utf-8"),
+              headers[1]: np.float32(offset),
+              headers[2]: label,
+              headers[3]: np.float32(logit),
+          }
+          rows.append(row)
+          if len(rows) >= shard_size:
+            flush_rows(output_filepath, shard_num, rows, format, headers)
+            rows = []
+            shard_num += 1
           detection_count += 1
         else:
           nondetection_count += 1
-  # write remaining rows if parquet format
-  if format == 'parquet' and rows:
-    tmp_df = pd.DataFrame(rows)
-    tmp_df.to_parquet(f'{output_filepath}/part.{parquet_count}.parquet')
-  if format == 'csv':
-    f.close()
+  # write remaining rows
+  flush_rows(output_filepath, shard_num, rows, format, headers)
   print('\n\n\n Detection count: ', detection_count)
   print('NonDetection count: ', nondetection_count)
\ No newline at end of file
diff --git a/chirp/inference/tests/classify_test.py b/chirp/inference/tests/classify_test.py
index d2d1105d..db5f5d1a 100644
--- a/chirp/inference/tests/classify_test.py
+++ b/chirp/inference/tests/classify_test.py
@@ -143,7 +143,7 @@ def test_write_inference_file(self):
     )
 
     # make a fake embeddings dataset
-    filenames = [f'file_{i}' for i in range(100)]
+    filenames = [f'file_{i}' for i in range(101)]
 
     self.write_random_embeddings(embedding_dim, filenames, tempdir)
 
@@ -158,8 +158,7 @@ def test_write_inference_file(self):
         labels=classes,
         output_filepath=parquet_path,
         embedding_hop_size_s=5.0,
-        row_size=10,
-        format='parquet',
+        shard_size=10,
     )
 
     classify.write_inference_file(
@@ -168,7 +167,7 @@ def test_write_inference_file(self):
         labels=classes,
         output_filepath=csv_path,
         embedding_hop_size_s=5.0,
-        format='csv',
+        shard_size=10,
     )
 
     parquet = pd.read_parquet(parquet_path)
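Usage sketch (not part of the patch): with this change the output format is inferred from the `output_filepath` suffix rather than a `format` argument, and both formats are flushed to disk every `shard_size` rows via the new `flush_rows` helper. The snippet below mirrors the calls in `classify_test.py`; `embeddings_ds`, `logits_model`, `labels`, and the output paths are placeholders for objects built elsewhere.

```python
# Sketch only: `embeddings_ds`, `logits_model`, and `labels` stand in for the
# dataset, classifier head, and class list constructed as in classify_test.py.
from chirp.inference.classify import classify

# Parquet output: the path becomes a directory holding part.<shard_num>.parquet
# files, one per `shard_size` rows.
classify.write_inference_file(
    embeddings_ds=embeddings_ds,
    model=logits_model,
    labels=labels,
    output_filepath='predictions.parquet',
    embedding_hop_size_s=5.0,
    shard_size=10,
)

# CSV output: a single file; the header row is written with the first shard and
# subsequent shards are appended.
classify.write_inference_file(
    embeddings_ds=embeddings_ds,
    model=logits_model,
    labels=labels,
    output_filepath='predictions.csv',
    embedding_hop_size_s=5.0,
    shard_size=10,
)
```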