Skip to content

Commit

Permalink
Write a test for write_inference_file and fix errors in its write paths
Browse files Browse the repository at this point in the history
  • Loading branch information
mschulist committed Aug 4, 2024
1 parent da13b22 commit bbc8e61
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 38 deletions.
15 changes: 10 additions & 5 deletions chirp/inference/classify/classify.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,10 +205,15 @@ def write_inference_file(
output_filepath = output_filepath[:-4]
if not output_filepath.endswith('.parquet'):
output_filepath += '.parquet'

parquet_count = 0
os.mkdir(output_filepath)
rows = []
if format == 'csv':
if output_filepath.endswith('.parquet'):
output_filepath = output_filepath[:-8]
if not output_filepath.endswith('.csv'):
output_filepath += '.csv'

parquet_count = 0
rows = []

inference_ds = get_inference_dataset(embeddings_ds, model)

Expand All @@ -229,7 +234,7 @@ def write_inference_file(
continue
if threshold is None or ex['logits'][t, i] > threshold[label]:
offset = ex['timestamp_s'] + t * embedding_hop_size_s
logit = '{:.2f}'.format(ex['logits'][t, i])
logit = ex['logits'][t, i]
if format == 'parquet':
row = {
headers[0]: ex["filename"].decode("utf-8"),
Expand All @@ -248,7 +253,7 @@ def write_inference_file(
ex['filename'].decode('utf-8'),
'{:.2f}'.format(offset),
label,
logit,
'{:.2f}'.format(logit),
]
f.write(','.join(row) + '\n')
detection_count += 1
Expand Down
97 changes: 64 additions & 33 deletions chirp/inference/tests/classify_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,18 @@

import os
import tempfile
from etils import epath

from chirp.inference import embed_lib, interface
import pandas as pd

from chirp.inference import interface, tf_examples
from chirp.inference.classify import classify
from chirp.inference.classify import data_lib
from chirp.inference.search import bootstrap
from chirp.taxonomy import namespace
import numpy as np

from absl.testing import absltest
from absl.testing import parameterized
from bootstrap_test import BootstrapTest
import shutil


class ClassifyTest(parameterized.TestCase):
Expand Down Expand Up @@ -102,59 +102,90 @@ def test_train_linear_model(self):
restored_logits = restored_model(query)
error = np.abs(restored_logits - logits).sum()
self.assertEqual(error, 0)

def write_random_embeddings(self, embedding_dim, filenames, tempdir):
  """Populate *tempdir* with one random embedding record per filename.

  A fixed-seed generator is used, so repeated calls write identical data.

  Args:
    embedding_dim: Length of the embedding vector in each record.
    filenames: Iterable of file ids; one TF example is written per id.
    tempdir: Output directory for the TFRecord shard(s).
  """
  generator = np.random.default_rng(42)
  multi_writer = tf_examples.EmbeddingsTFRecordMultiWriter(
      output_dir=tempdir, num_files=1
  )
  with multi_writer as writer:
    for file_id in filenames:
      # Shape (1, 1, embedding_dim): a single frame/channel of embeddings.
      values = generator.normal(size=(1, 1, embedding_dim))
      outputs = interface.InferenceOutputs(values.astype(np.float32))
      record = tf_examples.model_outputs_to_tf_example(
          model_outputs=outputs,
          file_id=file_id,
          audio=np.array([]),
          timestamp_offset_s=0,
          write_raw_audio=False,
          write_separated_audio=False,
          write_embeddings=True,
          write_logits=False,
      )
      writer.write(record.SerializeToString())
    writer.flush()

def test_write_inference_file(self):
  """Test writing inference files."""
  # NOTE(review): this span was scraped from a unified-diff view, so it
  # interleaves removed (pre-change) and added (post-change) lines of the
  # same test. Duplicated constructs below (two `model_path=` kwargs, two
  # `filenames =` assignments, two `output_filepath=` kwargs) cannot
  # coexist in the real file; only one of each pair is current.
  tempdir = tempfile.mkdtemp()

  # copy from test_train_linear_model to get the model
  embedding_dim = 128
  num_classes = 4
  model = classify.get_linear_model(embedding_dim, num_classes)

  classes = ['a', 'b', 'c', 'd']
  logits_model = interface.LogitsOutputHead(
      model_path='./test_model',  # pre-change line per the diff view
      model_path=os.path.join(tempdir, 'model'),  # post-change: keep artifacts in tempdir
      logits_key='some_model',
      logits_model=model,
      class_list=namespace.ClassList('classes', classes),
  )

  # make a fake embeddings dataset
  # NOTE(review): the following bootstrap/wav-file fixture lines appear to
  # be the removed (old) setup, superseded by write_random_embeddings below.
  filenames = ['file1', 'file2', 'file3']
  bt = BootstrapTest()
  bt.setUp()
  audio_glob = bt.make_wav_files(classes, filenames)
  source_infos = embed_lib.create_source_infos([audio_glob], shard_len_s=5.0)

  embed_dir = os.path.join(bt.tempdir, 'embeddings')
  labeled_dir = os.path.join(bt.tempdir, 'labeled')
  epath.Path(embed_dir).mkdir(parents=True, exist_ok=True)
  epath.Path(labeled_dir).mkdir(parents=True, exist_ok=True)
  filenames = [f'file_{i}' for i in range(100)]  # post-change fixture: 100 synthetic ids

  print(source_infos)
  print(bt.tempdir)

  bt.write_placeholder_embeddings(audio_glob, source_infos, embed_dir)

  bootstrap_config = bootstrap.BootstrapConfig.load_from_embedding_path(
      embeddings_path=embed_dir,
      annotated_path=labeled_dir,
  )
  print('config hash : ', bootstrap_config.embedding_config_hash())

  project_state = bootstrap.BootstrapState(
      config=bootstrap_config,
  )
  self.write_random_embeddings(embedding_dim, filenames, tempdir)  # post-change fixture writer

  embeddings_ds = project_state.create_embeddings_dataset()  # pre-change line
  embeddings_ds = tf_examples.create_embeddings_dataset(embeddings_dir=tempdir)  # post-change line

  parquet_path = os.path.join(tempdir, 'output.parquet')
  csv_path = os.path.join(tempdir, 'output.csv')

  # Write the same inference results in both supported formats.
  classify.write_inference_file(
      embeddings_ds=embeddings_ds,
      model=logits_model,
      labels=classes,
      output_filepath='./test_output',  # pre-change line
      output_filepath=parquet_path,  # post-change line
      embedding_hop_size_s=5.0,
      row_size=1,  # pre-change line
      format='csv'  # pre-change line
      row_size=10,  # post-change line
      format='parquet',  # post-change line
  )

  classify.write_inference_file(
      embeddings_ds=embeddings_ds,
      model=logits_model,
      labels=classes,
      output_filepath=csv_path,
      embedding_hop_size_s=5.0,
      format='csv',
  )

  # Read both outputs back; sort on the numeric suffix of 'file_{i}' so row
  # order is comparable between the two formats.
  parquet = pd.read_parquet(parquet_path)
  parquet['filename_i'] = parquet['filename'].str.split('_').str[1].astype(int)
  parquet = parquet.sort_values(by=['filename_i', 'timestamp_s']).reset_index(drop=True)

  csv = pd.read_csv(csv_path)
  csv['filename_i'] = csv['filename'].str.split('_').str[1].astype(int)
  csv = csv.sort_values(by=['filename_i', 'timestamp_s']).reset_index(drop=True)

  # One row per (file, class). Logits must agree across formats to ~1e-2
  # because the CSV writer rounds logits with '{:.2f}' (see classify.py).
  n_expected_rows = len(filenames) * len(classes)
  self.assertTrue(np.allclose(parquet['logit'], csv['logit'], atol=1e-2))
  self.assertEqual(len(parquet), n_expected_rows)
  self.assertEqual(len(csv), n_expected_rows)

  shutil.rmtree(tempdir)


# Script entry point: run all test cases in this module under absltest.
if __name__ == '__main__':
  absltest.main()

0 comments on commit bbc8e61

Please sign in to comment.