write to parquet #679

Merged
11 commits merged on Aug 7, 2024
102 changes: 74 additions & 28 deletions chirp/inference/classify/classify.py
@@ -25,6 +25,9 @@
import numpy as np
import tensorflow as tf
import tqdm
import pandas as pd
import os
from etils import epath


@dataclasses.dataclass
@@ -180,45 +183,88 @@ def classify_batch(batch):
  )
  return inference_ds

def flush_rows(
Collaborator: Let's use a slightly more descriptive name, like flush_inference_rows.

    output_path: epath.Path,
    shard_num: int,
    rows: list[dict[str, str]],
    format: str,
    headers: list[str],
):
  """Helper method to write rows to disk."""
  if format == 'csv':
    if shard_num == 0:
      with output_path.open('w') as f:
        f.write(','.join(headers) + '\n')
    with output_path.open('a') as f:
      for row in rows:
        csv_row = [
            '{:.2f}'.format(row.get(h, ''))
            if isinstance(row.get(h, ''), np.float32)
            else row.get(h, '')
            for h in headers
        ]
        f.write(','.join(csv_row) + '\n')
  elif format == 'parquet':
    output_path.mkdir(parents=True, exist_ok=True)
    parquet_path = output_path / f'part.{shard_num}.parquet'
    pd.DataFrame(rows).to_parquet(parquet_path)
  else:
    raise ValueError('Output format must be either csv or parquet')
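
For context, a minimal sketch (not part of this PR) of reading the sharded Parquet output back into one table; the path is a placeholder, and it assumes pandas with the pyarrow engine, which can read a directory of part.N.parquet files as a single dataset:

# Hypothetical read-back of the directory written by the parquet branch above.
import pandas as pd

df = pd.read_parquet('output.parquet')  # directory holding part.0.parquet, part.1.parquet, ...
print(df.columns.tolist())  # expected: ['filename', 'timestamp_s', 'label', 'logit']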

def write_inference_file(
    embeddings_ds: tf.data.Dataset,
    model: interface.LogitsOutputHead,
    labels: Sequence[str],
    output_filepath: epath.PathLike,
    embedding_hop_size_s: float,
    threshold: dict[str, float] | None = None,
    exclude_classes: Sequence[str] = ('unknown',),
    include_classes: Sequence[str] = (),
    shard_size: int = 1_000_000,
):
"""Write a CSV file of inference results."""
"""Write inference results."""
output_filepath = epath.Path(output_filepath)

if str(output_filepath).endswith('.csv'):
format = 'csv'
elif str(output_filepath).endswith('.parquet'):
format = 'parquet'
else:
raise ValueError('Output file must end with either .csv or .parquet')

shard_num = 0
rows = []

inference_ds = get_inference_dataset(embeddings_ds, model)

detection_count = 0
nondetection_count = 0
  headers = ['filename', 'timestamp_s', 'label', 'logit']
  # Write column headers if CSV format
Collaborator: Seems like this comment can be deleted now.

  for ex in tqdm.tqdm(inference_ds.as_numpy_iterator()):
    for t in range(ex['logits'].shape[0]):
      for i, label in enumerate(labels):
        if label in exclude_classes:
          continue
        if include_classes and label not in include_classes:
          continue
        if threshold is None or ex['logits'][t, i] > threshold[label]:
          offset = ex['timestamp_s'] + t * embedding_hop_size_s
          logit = ex['logits'][t, i]
          row = {
              headers[0]: ex['filename'].decode('utf-8'),
              headers[1]: np.float32(offset),
              headers[2]: label,
              headers[3]: np.float32(logit),
          }
          rows.append(row)
          if len(rows) >= shard_size:
            flush_rows(output_filepath, shard_num, rows, format, headers)
            rows = []
            shard_num += 1
          detection_count += 1
        else:
          nondetection_count += 1
  # Write remaining rows.
  flush_rows(output_filepath, shard_num, rows, format, headers)
  print('\n\n\n Detection count: ', detection_count)
  print('NonDetection count: ', nondetection_count)
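
For reference, a hedged usage sketch of the new entry point (mirroring the test below; embeddings_ds, logits_model, and classes are assumed to be built as in that test):

# Hypothetical call; a .parquet suffix selects sharded Parquet output,
# a .csv suffix selects single-file CSV output.
classify.write_inference_file(
    embeddings_ds=embeddings_ds,
    model=logits_model,
    labels=classes,
    output_filepath='/tmp/inference/output.parquet',
    embedding_hop_size_s=5.0,
    shard_size=1_000_000,
)
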
88 changes: 87 additions & 1 deletion chirp/inference/tests/classify_test.py
@@ -15,16 +15,20 @@

"""Test small-model classification."""

import os
import tempfile

import pandas as pd

from chirp.inference import interface, tf_examples
from chirp.inference.classify import classify
from chirp.inference.classify import data_lib
from chirp.taxonomy import namespace
import numpy as np

from absl.testing import absltest
from absl.testing import parameterized
import shutil


class ClassifyTest(parameterized.TestCase):
@@ -98,7 +102,89 @@ def test_train_linear_model(self):
    restored_logits = restored_model(query)
    error = np.abs(restored_logits - logits).sum()
    self.assertEqual(error, 0)

  def write_random_embeddings(self, embedding_dim, filenames, tempdir):
    """Write random embeddings to a temporary directory."""
    rng = np.random.default_rng(42)
    with tf_examples.EmbeddingsTFRecordMultiWriter(
        output_dir=tempdir, num_files=1
    ) as file_writer:
      for filename in filenames:
        embedding = rng.normal(size=(1, 1, embedding_dim)).astype(np.float32)
        model_outputs = interface.InferenceOutputs(embedding)
        example = tf_examples.model_outputs_to_tf_example(
            model_outputs=model_outputs,
            file_id=filename,
            audio=np.array([]),
            timestamp_offset_s=0,
            write_raw_audio=False,
            write_separated_audio=False,
            write_embeddings=True,
            write_logits=False,
        )
        file_writer.write(example.SerializeToString())
      file_writer.flush()

  def test_write_inference_file(self):
    """Test writing inference files."""
    tempdir = tempfile.mkdtemp()

    # copy from test_train_linear_model to get the model
    embedding_dim = 128
    num_classes = 4
    model = classify.get_linear_model(embedding_dim, num_classes)

    classes = ['a', 'b', 'c', 'd']
    logits_model = interface.LogitsOutputHead(
        model_path=os.path.join(tempdir, 'model'),
        logits_key='some_model',
        logits_model=model,
        class_list=namespace.ClassList('classes', classes),
    )

    # make a fake embeddings dataset
    filenames = [f'file_{i}' for i in range(101)]

    self.write_random_embeddings(embedding_dim, filenames, tempdir)

    embeddings_ds = tf_examples.create_embeddings_dataset(embeddings_dir=tempdir)

    parquet_path = os.path.join(tempdir, 'output.parquet')
    csv_path = os.path.join(tempdir, 'output.csv')

    classify.write_inference_file(
        embeddings_ds=embeddings_ds,
        model=logits_model,
        labels=classes,
        output_filepath=parquet_path,
        embedding_hop_size_s=5.0,
        shard_size=10,
    )

    classify.write_inference_file(
        embeddings_ds=embeddings_ds,
        model=logits_model,
        labels=classes,
        output_filepath=csv_path,
        embedding_hop_size_s=5.0,
        shard_size=10,
    )

    parquet = pd.read_parquet(parquet_path)
    parquet['filename_i'] = parquet['filename'].str.split('_').str[1].astype(int)
    parquet = parquet.sort_values(by=['filename_i', 'timestamp_s']).reset_index(
        drop=True
    )

    csv = pd.read_csv(csv_path)
    csv['filename_i'] = csv['filename'].str.split('_').str[1].astype(int)
    csv = csv.sort_values(by=['filename_i', 'timestamp_s']).reset_index(drop=True)

    n_expected_rows = len(filenames) * len(classes)
    # CSV logits are rounded to two decimals on write, so compare with atol=1e-2.
    self.assertTrue(np.allclose(parquet['logit'], csv['logit'], atol=1e-2))
    self.assertEqual(len(parquet), n_expected_rows)
    self.assertEqual(len(csv), n_expected_rows)

    shutil.rmtree(tempdir)


if __name__ == '__main__':
  absltest.main()