Skip to content

Commit

Permalink
rename flush rows and remove old comments
Browse files Browse the repository at this point in the history
  • Loading branch information
mschulist committed Aug 6, 2024
1 parent a273206 commit 4a66bb5
Showing 1 changed file with 3 additions and 4 deletions.
7 changes: 3 additions & 4 deletions chirp/inference/classify/classify.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ def classify_batch(batch):
)
return inference_ds

def flush_rows(
def flush_inference_rows(
output_path: epath.Path,
shard_num: int,
rows: list[dict[str, str]],
Expand Down Expand Up @@ -239,7 +239,6 @@ def write_inference_file(
detection_count = 0
nondetection_count = 0
headers = ['filename', 'timestamp_s', 'label', 'logit']
# Write column headers if CSV format
for ex in tqdm.tqdm(inference_ds.as_numpy_iterator()):
for t in range(ex['logits'].shape[0]):
for i, label in enumerate(labels):
Expand All @@ -258,13 +257,13 @@ def write_inference_file(
}
rows.append(row)
if len(rows) >= shard_size:
flush_rows(output_filepath, shard_num, rows, format, headers)
flush_inference_rows(output_filepath, shard_num, rows, format, headers)
rows = []
shard_num += 1
detection_count += 1
else:
nondetection_count += 1
# write remaining rows
flush_rows(output_filepath, shard_num, rows, format, headers)
flush_inference_rows(output_filepath, shard_num, rows, format, headers)
print('\n\n\n Detection count: ', detection_count)
print('NonDetection count: ', nondetection_count)

0 comments on commit 4a66bb5

Please sign in to comment.