From 4a66bb5c96ac232320f9753038d9c7f8e5083832 Mon Sep 17 00:00:00 2001 From: mschulist Date: Tue, 6 Aug 2024 12:07:58 -0700 Subject: [PATCH] rename flush rows and remove old comments --- chirp/inference/classify/classify.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/chirp/inference/classify/classify.py b/chirp/inference/classify/classify.py index c7b917df..88edcd79 100644 --- a/chirp/inference/classify/classify.py +++ b/chirp/inference/classify/classify.py @@ -183,7 +183,7 @@ def classify_batch(batch): ) return inference_ds -def flush_rows( +def flush_inference_rows( output_path: epath.Path, shard_num: int, rows: list[dict[str, str]], @@ -239,7 +239,6 @@ def write_inference_file( detection_count = 0 nondetection_count = 0 headers = ['filename', 'timestamp_s', 'label', 'logit'] - # Write column headers if CSV format for ex in tqdm.tqdm(inference_ds.as_numpy_iterator()): for t in range(ex['logits'].shape[0]): for i, label in enumerate(labels): @@ -258,13 +257,13 @@ def write_inference_file( } rows.append(row) if len(rows) >= shard_size: - flush_rows(output_filepath, shard_num, rows, format, headers) + flush_inference_rows(output_filepath, shard_num, rows, format, headers) rows = [] shard_num += 1 detection_count += 1 else: nondetection_count += 1 # write remaining rows - flush_rows(output_filepath, shard_num, rows, format, headers) + flush_inference_rows(output_filepath, shard_num, rows, format, headers) print('\n\n\n Detection count: ', detection_count) print('NonDetection count: ', nondetection_count) \ No newline at end of file