write to parquet #679

Merged
11 commits merged on Aug 7, 2024
87 changes: 66 additions & 21 deletions chirp/inference/classify/classify.py
@@ -25,6 +25,8 @@
import numpy as np
import tensorflow as tf
import tqdm
import pandas as pd
import os


@dataclasses.dataclass
@@ -181,7 +183,7 @@ def classify_batch(batch):
return inference_ds


def write_inference_csv(
def write_inference_file(
embeddings_ds: tf.data.Dataset,
model: interface.LogitsOutputHead,
labels: Sequence[str],
@@ -190,35 +192,78 @@ def write_inference_csv(
threshold: dict[str, float] | None = None,
exclude_classes: Sequence[str] = ('unknown',),
include_classes: Sequence[str] = (),
row_size: int = 1_000_000,
format: str = 'parquet',
):
"""Write a CSV file of inference results."""
"""Write inference results."""

if format != 'parquet' and format != 'csv':
Collaborator comment:

A bit cleaner:

if format == 'parquet':
  ...
elif format == 'csv':
  ...
else:
  raise ValueError(...)

raise ValueError('Format must be either "parquet" or "csv"')

if format == 'parquet':
if output_filepath.endswith('.csv'):
output_filepath = output_filepath[:-4]
if not output_filepath.endswith('.parquet'):
Collaborator comment:

This second-guessing of the user's intention from the extension and format args is a bit cumbersome.

Maybe we should get the extension from the output file and use that instead of an arg? (and complain if it's not one of our accepted types.)

Then we would have:

if output_filepath.endswith('.parquet'):
  format = 'parquet'
elif output_filepath.endswith('.csv'):
  format = 'csv'
else:
  raise ValueError(...)

which saves an argument and ~12 lines of code.

output_filepath += '.parquet'
os.mkdir(output_filepath)
if format == 'csv':
if output_filepath.endswith('.parquet'):
output_filepath = output_filepath[:-8]
if not output_filepath.endswith('.csv'):
output_filepath += '.csv'

parquet_count = 0
rows = []

inference_ds = get_inference_dataset(embeddings_ds, model)

detection_count = 0
nondetection_count = 0
with open(output_filepath, 'w') as f:
# Write column headers.
headers = ['filename', 'timestamp_s', 'label', 'logit']
f.write(', '.join(headers) + '\n')
for ex in tqdm.tqdm(inference_ds.as_numpy_iterator()):
for t in range(ex['logits'].shape[0]):
for i, label in enumerate(labels):
if label in exclude_classes:
continue
if include_classes and label not in include_classes:
continue
if threshold is None or ex['logits'][t, i] > threshold[label]:
offset = ex['timestamp_s'] + t * embedding_hop_size_s
logit = '{:.2f}'.format(ex['logits'][t, i])
if format == 'csv':
f = open(output_filepath, 'w')
Collaborator comment:

It's good to use the with open(...) as f because it ensures that the file will be properly flushed and closed if an exception arises, or if we return early for some reason.
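
As a small, hypothetical illustration of that point (not the merged code), the CSV path could be wrapped in a context manager so the handle is released even on error:

def write_csv_rows(output_filepath, headers, detection_rows):
  # Hypothetical helper: `with open(...)` guarantees the file is flushed and
  # closed even if an exception is raised while writing the detections.
  with open(output_filepath, 'w') as f:
    f.write(','.join(headers) + '\n')
    for row in detection_rows:
      f.write(','.join(str(v) for v in row) + '\n')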

headers = ['filename', 'timestamp_s', 'label', 'logit']
# Write column headers if CSV format
Collaborator comment:

Seems like this comment can be deleted now

if format == 'csv':
f.write(','.join(headers) + '\n')
for ex in tqdm.tqdm(inference_ds.as_numpy_iterator()):
for t in range(ex['logits'].shape[0]):
for i, label in enumerate(labels):
if label in exclude_classes:
continue
if include_classes and label not in include_classes:
continue
if threshold is None or ex['logits'][t, i] > threshold[label]:
offset = ex['timestamp_s'] + t * embedding_hop_size_s
logit = ex['logits'][t, i]
if format == 'parquet':
Collaborator comment:

Maybe simpler:

Write a helper function flush_rows(output_path, shard_num, rows, format, headers) which writes everything in rows to a file. Then all of the writing logic is centralized; you can call the function here and below when you deal with the remainder rows.

This also helps with the csv file handling; you just open the file and write to it when you're flushing the data to disk.

Contributor Author reply:

Thanks! That makes it SO much cleaner
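
For illustration, here is a rough sketch of the helper the reviewer describes. The name and signature come from the suggestion above; the body is one plausible way to fill it in, not necessarily what was merged:

import os

import pandas as pd


def flush_rows(output_path, shard_num, rows, format, headers):
  """Writes the accumulated rows as one parquet shard, or appends them to a CSV."""
  if format == 'parquet':
    # Each flush becomes its own shard inside the output directory.
    pd.DataFrame(rows).to_parquet(
        os.path.join(output_path, f'part.{shard_num}.parquet')
    )
  elif format == 'csv':
    # Append so repeated flushes keep extending the same file; the header
    # row is assumed to have been written once by the caller.
    with open(output_path, 'a') as f:
      for row in rows:
        f.write(','.join(str(row[h]) for h in headers) + '\n')
  else:
    raise ValueError('Format must be either "parquet" or "csv"')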

row = {
headers[0]: ex["filename"].decode("utf-8"),
headers[1]: offset,
headers[2]: label,
headers[3]: logit,
}
rows.append(row)
if len(rows) >= row_size:
tmp_df = pd.DataFrame(rows)
tmp_df.to_parquet(f'{output_filepath}/part.{parquet_count}.parquet')
parquet_count += 1
rows = []
elif format == 'csv':
row = [
ex['filename'].decode('utf-8'),
'{:.2f}'.format(offset),
label,
logit,
'{:.2f}'.format(logit),
]
f.write(','.join(row) + '\n')
detection_count += 1
else:
nondetection_count += 1
detection_count += 1
else:
nondetection_count += 1
# write remaining rows if parquet format
if format == 'parquet' and rows:
tmp_df = pd.DataFrame(rows)
tmp_df.to_parquet(f'{output_filepath}/part.{parquet_count}.parquet')
if format == 'csv':
f.close()
print('\n\n\n Detection count: ', detection_count)
print('NonDetection count: ', nondetection_count)
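
Since the parquet output is a directory of part.N.parquet shards rather than a single file, it can still be read back as one DataFrame; a small usage sketch, with a made-up path:

import pandas as pd

# pandas (via pyarrow) treats a directory of part.*.parquet files as a single
# dataset, so downstream analysis does not need to glob the shards itself.
df = pd.read_parquet('inference_output.parquet')
print(df.columns.tolist())  # ['filename', 'timestamp_s', 'label', 'logit']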
89 changes: 88 additions & 1 deletion chirp/inference/tests/classify_test.py
@@ -15,16 +15,20 @@

"""Test small-model classification."""

import os
import tempfile

from chirp.inference import interface
import pandas as pd

from chirp.inference import interface, tf_examples
from chirp.inference.classify import classify
from chirp.inference.classify import data_lib
from chirp.taxonomy import namespace
import numpy as np

from absl.testing import absltest
from absl.testing import parameterized
import shutil


class ClassifyTest(parameterized.TestCase):
@@ -98,7 +102,90 @@ def test_train_linear_model(self):
restored_logits = restored_model(query)
error = np.abs(restored_logits - logits).sum()
self.assertEqual(error, 0)

def write_random_embeddings(self, embedding_dim, filenames, tempdir):
"""Write random embeddings to a temporary directory."""
rng = np.random.default_rng(42)
with tf_examples.EmbeddingsTFRecordMultiWriter(
output_dir=tempdir, num_files=1
) as file_writer:
for filename in filenames:
embedding = rng.normal(size=(1, 1, embedding_dim)).astype(np.float32)
model_outputs = interface.InferenceOutputs(embedding)
example = tf_examples.model_outputs_to_tf_example(
model_outputs=model_outputs,
file_id=filename,
audio=np.array([]),
timestamp_offset_s=0,
write_raw_audio=False,
write_separated_audio=False,
write_embeddings=True,
write_logits=False,
)
file_writer.write(example.SerializeToString())
file_writer.flush()

def test_write_inference_file(self):
"""Test writing inference files."""
tempdir = tempfile.mkdtemp()

# copy from test_train_linear_model to get the model
embedding_dim = 128
num_classes = 4
model = classify.get_linear_model(embedding_dim, num_classes)

classes = ['a', 'b', 'c', 'd']
logits_model = interface.LogitsOutputHead(
model_path=os.path.join(tempdir, 'model'),
logits_key='some_model',
logits_model=model,
class_list=namespace.ClassList('classes', classes),
)

# make a fake embeddings dataset
filenames = [f'file_{i}' for i in range(100)]

self.write_random_embeddings(embedding_dim, filenames, tempdir)

embeddings_ds = tf_examples.create_embeddings_dataset(embeddings_dir=tempdir)

parquet_path = os.path.join(tempdir, 'output.parquet')
csv_path = os.path.join(tempdir, 'output.csv')

classify.write_inference_file(
embeddings_ds=embeddings_ds,
model=logits_model,
labels=classes,
output_filepath=parquet_path,
embedding_hop_size_s=5.0,
row_size=10,
format='parquet',
)

classify.write_inference_file(
embeddings_ds=embeddings_ds,
model=logits_model,
labels=classes,
output_filepath=csv_path,
embedding_hop_size_s=5.0,
format='csv',
)

parquet = pd.read_parquet(parquet_path)
parquet['filename_i'] = parquet['filename'].str.split('_').str[1].astype(int)
parquet = parquet.sort_values(by=['filename_i', 'timestamp_s']).reset_index(drop=True)

csv = pd.read_csv(csv_path)
csv['filename_i'] = csv['filename'].str.split('_').str[1].astype(int)
csv = csv.sort_values(by=['filename_i', 'timestamp_s']).reset_index(drop=True)

n_expected_rows = len(filenames) * len(classes)
self.assertTrue(np.allclose(parquet['logit'], csv['logit'], atol=1e-2))
self.assertEqual(len(parquet), n_expected_rows)
self.assertEqual(len(csv), n_expected_rows)

shutil.rmtree(tempdir)


if __name__ == '__main__':
absltest.main()