Tool for importing TFRecord embedding datasets into new database format.
PiperOrigin-RevId: 661060569
sdenton4 authored and copybara-github committed Aug 9, 2024
1 parent 0402b78 commit 1d21072
Showing 4 changed files with 278 additions and 160 deletions.
chirp/inference/classify/classify.py (101 changes: 28 additions & 73 deletions)
@@ -25,9 +25,6 @@
import numpy as np
import tensorflow as tf
import tqdm
import pandas as pd
import os
from etils import epath


@dataclasses.dataclass
@@ -183,87 +180,45 @@ def classify_batch(batch):
)
return inference_ds

def flush_inference_rows(
output_path: epath.Path,
shard_num: int,
rows: list[dict[str, str]],
format: str,
headers: list[str],
):
"""Helper method to write rows to disk."""
if format == 'csv':
if shard_num == 0:
with output_path.open('w') as f:
f.write(','.join(headers) + '\n')
with output_path.open('a') as f:
for row in rows:
csv_row = [
'{:.2f}'.format(row.get(h, '')) if isinstance(row.get(h, ''), np.float32) else row.get(h, '')
for h in row
]
f.write(','.join(csv_row) + '\n')
elif format == 'parquet':
output_path.mkdir(parents=True, exist_ok=True)
parquet_path = output_path / f'part.{shard_num}.parquet'
pd.DataFrame(rows).to_parquet(parquet_path)
else:
raise ValueError('Output format must be either csv or parquet')


def write_inference_file(
def write_inference_csv(
embeddings_ds: tf.data.Dataset,
model: interface.LogitsOutputHead,
labels: Sequence[str],
output_filepath: epath.PathLike,
output_filepath: str,
embedding_hop_size_s: float,
threshold: dict[str, float] | None = None,
exclude_classes: Sequence[str] = ('unknown',),
include_classes: Sequence[str] = (),
shard_size: int = 1_000_000,
):
"""Write inference results."""
output_filepath = epath.Path(output_filepath)

if str(output_filepath).endswith('.csv'):
format = 'csv'
elif str(output_filepath).endswith('.parquet'):
format = 'parquet'
else:
raise ValueError('Output file must end with either .csv or .parquet')

shard_num = 0
rows = []

"""Write a CSV file of inference results."""
inference_ds = get_inference_dataset(embeddings_ds, model)

detection_count = 0
nondetection_count = 0
headers = ['filename', 'timestamp_s', 'label', 'logit']
for ex in tqdm.tqdm(inference_ds.as_numpy_iterator()):
for t in range(ex['logits'].shape[0]):
for i, label in enumerate(labels):
if label in exclude_classes:
continue
if include_classes and label not in include_classes:
continue
if threshold is None or ex['logits'][t, i] > threshold[label]:
offset = ex['timestamp_s'] + t * embedding_hop_size_s
logit = ex['logits'][t, i]
row = {
headers[0]: ex["filename"].decode("utf-8"),
headers[1]: np.float32(offset),
headers[2]: label,
headers[3]: np.float32(logit),
}
rows.append(row)
if len(rows) >= shard_size:
flush_inference_rows(output_filepath, shard_num, rows, format, headers)
rows = []
shard_num += 1
detection_count += 1
else:
nondetection_count += 1
# write remaining rows
flush_inference_rows(output_filepath, shard_num, rows, format, headers)
with open(output_filepath, 'w') as f:
# Write column headers.
headers = ['filename', 'timestamp_s', 'label', 'logit']
f.write(', '.join(headers) + '\n')
for ex in tqdm.tqdm(inference_ds.as_numpy_iterator()):
for t in range(ex['logits'].shape[0]):
for i, label in enumerate(labels):
if label in exclude_classes:
continue
if include_classes and label not in include_classes:
continue
if threshold is None or ex['logits'][t, i] > threshold[label]:
offset = ex['timestamp_s'] + t * embedding_hop_size_s
logit = '{:.2f}'.format(ex['logits'][t, i])
row = [
ex['filename'].decode('utf-8'),
'{:.2f}'.format(offset),
label,
logit,
]
f.write(','.join(row) + '\n')
detection_count += 1
else:
nondetection_count += 1
print('\n\n\n Detection count: ', detection_count)
print('NonDetection count: ', nondetection_count)
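
A minimal usage sketch for the restored `write_inference_csv` entry point, based only on the signature above; the `embeddings_ds`, `logits_model`, and `class_names` variables are assumed to exist elsewhere and are illustrative, not part of this commit.

# Hypothetical usage sketch; the inputs are assumed, not part of the commit.
from chirp.inference.classify import classify

classify.write_inference_csv(
    embeddings_ds=embeddings_ds,      # tf.data.Dataset of embedding examples
    model=logits_model,               # a trained interface.LogitsOutputHead
    labels=class_names,               # class names aligned with the model's outputs
    output_filepath='/tmp/inference.csv',
    embedding_hop_size_s=5.0,         # hop size used when the embeddings were computed
    threshold=None,                   # None writes a row for every retained (frame, label)
    exclude_classes=('unknown',),
)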
chirp/inference/tests/classify_test.py (88 changes: 1 addition & 87 deletions)
@@ -15,20 +15,16 @@

"""Test small-model classification."""

import os
import tempfile

import pandas as pd

from chirp.inference import interface, tf_examples
from chirp.inference import interface
from chirp.inference.classify import classify
from chirp.inference.classify import data_lib
from chirp.taxonomy import namespace
import numpy as np

from absl.testing import absltest
from absl.testing import parameterized
import shutil


class ClassifyTest(parameterized.TestCase):
@@ -102,89 +98,7 @@ def test_train_linear_model(self):
restored_logits = restored_model(query)
error = np.abs(restored_logits - logits).sum()
self.assertEqual(error, 0)

def write_random_embeddings(self, embedding_dim, filenames, tempdir):
"""Write random embeddings to a temporary directory."""
rng = np.random.default_rng(42)
with tf_examples.EmbeddingsTFRecordMultiWriter(
output_dir=tempdir, num_files=1
) as file_writer:
for filename in filenames:
embedding = rng.normal(size=(1, 1, embedding_dim)).astype(np.float32)
model_outputs = interface.InferenceOutputs(embedding)
example = tf_examples.model_outputs_to_tf_example(
model_outputs=model_outputs,
file_id=filename,
audio=np.array([]),
timestamp_offset_s=0,
write_raw_audio=False,
write_separated_audio=False,
write_embeddings=True,
write_logits=False,
)
file_writer.write(example.SerializeToString())
file_writer.flush()

def test_write_inference_file(self):
"""Test writing inference files."""
tempdir = tempfile.mkdtemp()

# copy from test_train_linear_model to get the model
embedding_dim = 128
num_classes = 4
model = classify.get_linear_model(embedding_dim, num_classes)

classes = ['a', 'b', 'c', 'd']
logits_model = interface.LogitsOutputHead(
model_path=os.path.join(tempdir, 'model'),
logits_key='some_model',
logits_model=model,
class_list=namespace.ClassList('classes', classes),
)

# make a fake embeddings dataset
filenames = [f'file_{i}' for i in range(101)]

self.write_random_embeddings(embedding_dim, filenames, tempdir)

embeddings_ds = tf_examples.create_embeddings_dataset(embeddings_dir=tempdir)

parquet_path = os.path.join(tempdir, 'output.parquet')
csv_path = os.path.join(tempdir, 'output.csv')

classify.write_inference_file(
embeddings_ds=embeddings_ds,
model=logits_model,
labels=classes,
output_filepath=parquet_path,
embedding_hop_size_s=5.0,
shard_size=10,
)

classify.write_inference_file(
embeddings_ds=embeddings_ds,
model=logits_model,
labels=classes,
output_filepath=csv_path,
embedding_hop_size_s=5.0,
shard_size=10,
)

parquet = pd.read_parquet(parquet_path)
parquet['filename_i'] = parquet['filename'].str.split('_').str[1].astype(int)
parquet = parquet.sort_values(by=['filename_i', 'timestamp_s']).reset_index(drop=True)

csv = pd.read_csv(csv_path)
csv['filename_i'] = csv['filename'].str.split('_').str[1].astype(int)
csv = csv.sort_values(by=['filename_i', 'timestamp_s']).reset_index(drop=True)

n_expected_rows = len(filenames) * len(classes)
self.assertTrue(np.allclose(parquet['logit'], csv['logit'], atol=1e-2))
self.assertEqual(len(parquet), n_expected_rows)
self.assertEqual(len(csv), n_expected_rows)

shutil.rmtree(tempdir)


if __name__ == '__main__':
absltest.main()
chirp/projects/agile2/convert_legacy.py (108 changes: 108 additions & 0 deletions)
@@ -0,0 +1,108 @@
# coding=utf-8
# Copyright 2024 The Perch Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Conversion for TFRecord embeddings to Hoplite DB."""

import os
from chirp.inference import embed_lib
from chirp.inference import tf_examples
from chirp.projects.agile2 import embed
from chirp.projects.hoplite import in_mem_impl
from chirp.projects.hoplite import interface
from chirp.projects.hoplite import sqlite_impl
from etils import epath
import numpy as np
import tqdm


def convert_tfrecords(
embeddings_path: str,
db_type: str,
dataset_name: str,
max_count: int = -1,
**kwargs,
):
"""Convert a TFRecord embeddings dataset to a Hoplite DB."""
ds = tf_examples.create_embeddings_dataset(
embeddings_path,
'embeddings-*',
)
# Peek at one embedding to get the embedding dimension.
for ex in ds.as_numpy_iterator():
emb_dim = ex['embedding'].shape[-1]
break
else:
raise ValueError('No embeddings found.')

if db_type == 'sqlite':
db_path = kwargs['db_path']
if epath.Path(db_path).exists():
raise ValueError(f'DB path {db_path} already exists.')
db = sqlite_impl.SQLiteGraphSearchDB.create(db_path, embedding_dim=emb_dim)
elif db_type == 'in_mem':
db = in_mem_impl.InMemoryGraphSearchDB.create(
embedding_dim=emb_dim,
max_size=kwargs['max_size'],
degree_bound=kwargs['degree_bound'],
)
else:
raise ValueError(f'Unknown db type: {db_type}')
db.setup()

# Convert embedding config to new format and insert into the DB.
legacy_config = embed_lib.load_embedding_config(embeddings_path)
model_config = embed.ModelConfig(
model_key=legacy_config.embed_fn_config.model_key,
model_config=legacy_config.embed_fn_config.model_config,
)
file_id_depth = legacy_config.embed_fn_config['file_id_depth']
audio_globs = []
for glob in legacy_config.source_file_patterns:
new_glob = glob.split('/')[-file_id_depth - 1 :]
audio_globs.append(new_glob)

embed_config = embed.EmbedConfig(
audio_globs={dataset_name: tuple(audio_globs)},
min_audio_len_s=legacy_config.embed_fn_config.min_audio_s,
target_sample_rate_hz=legacy_config.embed_fn_config.get(
'target_sample_rate_hz', -1
),
)
db.insert_metadata('legacy_config', legacy_config)
db.insert_metadata('embed_config', embed_config.to_config_dict())
db.insert_metadata('model_config', model_config.to_config_dict())
hop_size_s = model_config.model_config.hop_size_s

for ex in tqdm.tqdm(ds.as_numpy_iterator()):
embs = ex['embedding']
print(embs.shape)
flat_embeddings = np.reshape(embs, [-1, embs.shape[-1]])
file_id = str(ex['filename'], 'utf8')
offset_s = ex['timestamp_s']
if max_count > 0 and db.count_embeddings() >= max_count:
break
for i in range(flat_embeddings.shape[0]):
embedding = flat_embeddings[i]
offset = np.array(offset_s + hop_size_s * i)
source = interface.EmbeddingSource(dataset_name, file_id, offset)
db.insert_embedding(embedding, source)
if max_count > 0 and db.count_embeddings() >= max_count:
break
db.commit()
num_embeddings = db.count_embeddings()
print('\n\nTotal embeddings : ', num_embeddings)
hours_equiv = num_embeddings / 60 / 60 * hop_size_s
print(f'\n\nHours of audio equivalent : {hours_equiv:.2f}')
return db
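
A rough usage sketch for the new converter; the paths, dataset name, and size parameters below are placeholders, and the keyword arguments mirror the `kwargs` each backend consumes above.

# Hypothetical usage sketch; paths, names, and sizes are placeholders.
from chirp.projects.agile2 import convert_legacy

# SQLite-backed Hoplite DB: db_path must not already exist.
db = convert_legacy.convert_tfrecords(
    embeddings_path='/path/to/tfrecord_embeddings',  # directory with embeddings-* shards
    db_type='sqlite',
    dataset_name='my_dataset',
    db_path='/path/to/hoplite_db.sqlite',
)

# In-memory DB: requires an explicit capacity and graph degree bound.
db = convert_legacy.convert_tfrecords(
    embeddings_path='/path/to/tfrecord_embeddings',
    db_type='in_mem',
    dataset_name='my_dataset',
    max_count=10_000,   # stop after this many embeddings; -1 converts everything
    max_size=100_000,
    degree_bound=64,
)
print(db.count_embeddings())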