Clean up the inference-writing code; add epath support
mschulist committed Aug 5, 2024
1 parent bbc8e61 commit a273206
Showing 2 changed files with 53 additions and 53 deletions.
99 changes: 50 additions & 49 deletions chirp/inference/classify/classify.py
@@ -27,6 +27,7 @@
 import tqdm
 import pandas as pd
 import os
+from etils import epath


 @dataclasses.dataclass
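
For readers unfamiliar with it, epath (from Google's etils package) provides a pathlib-style API that works identically for local files and cloud-bucket URIs, which is what the new import enables. A minimal sketch, assuming etils is installed; the bucket name is hypothetical:

from etils import epath

# pathlib-like API over local and remote storage; the GCS bucket below
# is a made-up example, not something this repo defines.
out = epath.Path('/tmp/preds.csv')                # local path
# out = epath.Path('gs://my-bucket/preds.csv')    # same API for cloud storage
with out.open('w') as f:
    f.write('filename,timestamp_s,label,logit\n')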
@@ -182,49 +183,63 @@ def classify_batch(batch):
   )
   return inference_ds

+def flush_rows(
+    output_path: epath.Path,
+    shard_num: int,
+    rows: list[dict[str, str]],
+    format: str,
+    headers: list[str],
+):
+  """Helper method to write rows to disk."""
+  if format == 'csv':
+    if shard_num == 0:
+      with output_path.open('w') as f:
+        f.write(','.join(headers) + '\n')
+    with output_path.open('a') as f:
+      for row in rows:
+        csv_row = [
+            '{:.2f}'.format(row[h]) if isinstance(row[h], np.float32)
+            else row[h]
+            for h in headers
+        ]
+        f.write(','.join(csv_row) + '\n')
+  elif format == 'parquet':
+    output_path.mkdir(parents=True, exist_ok=True)
+    parquet_path = output_path / f'part.{shard_num}.parquet'
+    pd.DataFrame(rows).to_parquet(parquet_path)
+  else:
+    raise ValueError('Output format must be either csv or parquet')


 def write_inference_file(
     embeddings_ds: tf.data.Dataset,
     model: interface.LogitsOutputHead,
     labels: Sequence[str],
-    output_filepath: str,
+    output_filepath: epath.PathLike,
     embedding_hop_size_s: float,
     threshold: dict[str, float] | None = None,
     exclude_classes: Sequence[str] = ('unknown',),
     include_classes: Sequence[str] = (),
-    row_size: int = 1_000_000,
-    format: str = 'parquet',
+    shard_size: int = 1_000_000,
 ):
   """Write inference results."""
+  output_filepath = epath.Path(output_filepath)

-  if format != 'parquet' and format != 'csv':
-    raise ValueError('Format must be either "parquet" or "csv"')
-
-  if format == 'parquet':
-    if output_filepath.endswith('.csv'):
-      output_filepath = output_filepath[:-4]
-    if not output_filepath.endswith('.parquet'):
-      output_filepath += '.parquet'
-    os.mkdir(output_filepath)
-  if format == 'csv':
-    if output_filepath.endswith('.parquet'):
-      output_filepath = output_filepath[:-8]
-    if not output_filepath.endswith('.csv'):
-      output_filepath += '.csv'
+  if str(output_filepath).endswith('.csv'):
+    format = 'csv'
+  elif str(output_filepath).endswith('.parquet'):
+    format = 'parquet'
+  else:
+    raise ValueError('Output file must end with either .csv or .parquet')

-  parquet_count = 0
+  shard_num = 0
   rows = []

   inference_ds = get_inference_dataset(embeddings_ds, model)

   detection_count = 0
   nondetection_count = 0
-  if format == 'csv':
-    f = open(output_filepath, 'w')
   headers = ['filename', 'timestamp_s', 'label', 'logit']
-  # Write column headers if CSV format
-  if format == 'csv':
-    f.write(','.join(headers) + '\n')
   for ex in tqdm.tqdm(inference_ds.as_numpy_iterator()):
     for t in range(ex['logits'].shape[0]):
       for i, label in enumerate(labels):
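
A rough usage sketch of the new flush_rows helper defined above, with fabricated rows; the path, label, and score are illustrative only, and parquet output assumes pyarrow is available:

import numpy as np
from etils import epath

headers = ['filename', 'timestamp_s', 'label', 'logit']
rows = [{
    'filename': 'file_0',
    'timestamp_s': np.float32(0.0),
    'label': 'robin',            # made-up class name
    'logit': np.float32(0.93),   # made-up score
}]
# Writes /tmp/preds.parquet/part.0.parquet
flush_rows(epath.Path('/tmp/preds.parquet'), shard_num=0, rows=rows,
           format='parquet', headers=headers)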
@@ -235,35 +250,21 @@ def write_inference_file
         if threshold is None or ex['logits'][t, i] > threshold[label]:
           offset = ex['timestamp_s'] + t * embedding_hop_size_s
           logit = ex['logits'][t, i]
-          if format == 'parquet':
-            row = {
-                headers[0]: ex['filename'].decode('utf-8'),
-                headers[1]: offset,
-                headers[2]: label,
-                headers[3]: logit,
-            }
-            rows.append(row)
-            if len(rows) >= row_size:
-              tmp_df = pd.DataFrame(rows)
-              tmp_df.to_parquet(
-                  f'{output_filepath}/part.{parquet_count}.parquet'
-              )
-              parquet_count += 1
-              rows = []
-          elif format == 'csv':
-            row = [
-                ex['filename'].decode('utf-8'),
-                '{:.2f}'.format(offset),
-                label,
-                '{:.2f}'.format(logit),
-            ]
-            f.write(','.join(row) + '\n')
+          row = {
+              headers[0]: ex['filename'].decode('utf-8'),
+              headers[1]: np.float32(offset),
+              headers[2]: label,
+              headers[3]: np.float32(logit),
+          }
+          rows.append(row)
+          if len(rows) >= shard_size:
+            flush_rows(output_filepath, shard_num, rows, format, headers)
+            rows = []
+            shard_num += 1
           detection_count += 1
         else:
           nondetection_count += 1
-  # write remaining rows if parquet format
-  if format == 'parquet' and rows:
-    tmp_df = pd.DataFrame(rows)
-    tmp_df.to_parquet(f'{output_filepath}/part.{parquet_count}.parquet')
-  if format == 'csv':
-    f.close()
+  # write remaining rows
+  flush_rows(output_filepath, shard_num, rows, format, headers)
   print('\n\n\n Detection count: ', detection_count)
   print('NonDetection count: ', nondetection_count)
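
With this change the caller no longer passes a format argument: the extension of output_filepath selects csv or parquet, and shard_size replaces row_size. A sketch of the updated call, assuming embeddings_ds and loaded_model were built with the usual chirp tooling (not shown here):

from chirp.inference.classify import classify

classify.write_inference_file(
    embeddings_ds=embeddings_ds,      # assumed: a tf.data.Dataset of embeddings
    model=loaded_model,               # assumed: an interface.LogitsOutputHead
    labels=('robin', 'sparrow'),      # illustrative class names
    output_filepath='preds.parquet',  # extension now selects the format
    embedding_hop_size_s=5.0,
    shard_size=100_000,               # rows per output shard
)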
7 changes: 3 additions & 4 deletions chirp/inference/tests/classify_test.py
@@ -143,7 +143,7 @@ def test_write_inference_file(self):
     )

     # make a fake embeddings dataset
-    filenames = [f'file_{i}' for i in range(100)]
+    filenames = [f'file_{i}' for i in range(101)]

     self.write_random_embeddings(embedding_dim, filenames, tempdir)

@@ -158,8 +158,7 @@ def test_write_inference_file(self):
         labels=classes,
         output_filepath=parquet_path,
         embedding_hop_size_s=5.0,
-        row_size=10,
-        format='parquet',
+        shard_size=10,
     )

     classify.write_inference_file(
@@ -168,7 +167,7 @@ def test_write_inference_file(self):
         labels=classes,
         output_filepath=csv_path,
         embedding_hop_size_s=5.0,
-        format='csv',
+        shard_size=10,
     )

     parquet = pd.read_parquet(parquet_path)
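
The read-back above works because pandas (via pyarrow) accepts a directory of parquet shards and concatenates the part files into a single DataFrame, so the test needs only one call; for example:

preds = pd.read_parquet('preds.parquet')  # directory of part.N.parquet shards
print(len(preds), 'rows across all shards')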
