Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Agile modeling 16bit #685

Open
wants to merge 13 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 17 additions & 9 deletions chirp/inference/classify/classify.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,10 @@ class ClassifierMetrics:


def get_two_layer_model(
num_hiddens: int, embedding_dim: int, num_classes: int, batch_norm: bool
num_hiddens: int, embedding_dim: int, num_classes: int, batch_norm: bool, dtype: str='float32'
) -> tf.keras.Model:
"""Create a simple two-layer Keras model."""
layers = [tf.keras.Input(shape=[embedding_dim])]
layers = [tf.keras.Input(shape=[embedding_dim], dtype=tf.dtypes.as_dtype(dtype))]
if batch_norm:
layers.append(tf.keras.layers.BatchNormalization())
layers += [
Expand All @@ -52,12 +52,20 @@ def get_two_layer_model(
return model


def get_linear_model(embedding_dim: int, num_classes: int) -> tf.keras.Model:
# def get_linear_model_old(embedding_dim: int, num_classes: int) -> tf.keras.Model:
# """Create a simple linear Keras model."""
# model = tf.keras.Sequential([
# tf.keras.Input(shape=[embedding_dim], dtype=tf.float16),
# tf.keras.layers.Dense(num_classes),
# ])
# return model


def get_linear_model(embedding_dim: int, num_classes: int, dtype: str = "float32") -> tf.keras.Model:
  """Create a simple linear Keras model.

  Built with the functional API (rather than Sequential) so the input
  layer's dtype is set explicitly from `dtype` and preserved in the
  model's serving signature.

  Args:
    embedding_dim: Size of the input embedding vectors.
    num_classes: Number of output logits.
    dtype: Name of the input/compute dtype (e.g. 'float32', 'float16').

  Returns:
    A Keras model mapping [batch, embedding_dim] -> [batch, num_classes].
  """
  inputs = tf.keras.layers.Input(
      shape=[embedding_dim], dtype=tf.dtypes.as_dtype(dtype)
  )
  outputs = tf.keras.layers.Dense(num_classes, dtype=dtype)(inputs)
  return tf.keras.Model(inputs=inputs, outputs=outputs)


Expand Down Expand Up @@ -166,15 +174,15 @@ def classify_batch(batch):
emb = batch[tf_examples.EMBEDDING]
emb_shape = tf.shape(emb)
flat_emb = tf.reshape(emb, [-1, emb_shape[-1]])
logits = model(flat_emb)
logits = model.logits_model(flat_emb)
logits = tf.reshape(
logits, [emb_shape[0], emb_shape[1], tf.shape(logits)[-1]]
)
# Take the maximum logit over channels.
logits = tf.reduce_max(logits, axis=-2)
batch['logits'] = logits
return batch

inference_ds = embeddings_ds.map(
classify_batch, num_parallel_calls=tf.data.AUTOTUNE
)
Expand Down
2 changes: 2 additions & 0 deletions chirp/inference/search/bootstrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ class BootstrapState:
embeddings_dataset: tf.data.Dataset | None = None
source_map: Callable[[str, float], str] | None = None
baw_auth_token: str = ''
baw_domain: str = 'api.acousticobservatory.org/'

def __post_init__(self):
if self.embedding_model is None:
Expand All @@ -62,6 +63,7 @@ def __post_init__(self):
self.source_map = functools.partial(
baw_utils.make_baw_audio_url_from_file_id,
window_size_s=window_size_s,
baw_domain=self.baw_domain,
)
else:
self.source_map = lambda file_id, offset: filesystem_source_map(
Expand Down
24 changes: 19 additions & 5 deletions chirp/inference/search/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,7 @@ def search_embeddings_parallel(
random_sample: bool = False,
invert_sort_score: bool = False,
filter_fn: Callable[[Any], bool] | None = None,
quit_after: int | None = None
):
"""Run a brute-force search.

Expand Down Expand Up @@ -296,7 +297,10 @@ def search_embeddings_parallel(
results = TopKSearchResults(top_k=top_k)
all_distances = []
try:
for ex in tqdm.tqdm(embeddings_dataset.as_numpy_iterator()):
for i, ex in enumerate(tqdm.tqdm(embeddings_dataset.as_numpy_iterator())):
if quit_after is not None and i >= quit_after:
print("quitting early because quit_after is set")
break
all_distances.append(ex['scores'].reshape([-1]))
if results.will_filter(ex['max_sort_score']):
continue
Expand All @@ -311,9 +315,10 @@ def search_embeddings_parallel(
)
results.update(result)
except KeyboardInterrupt:
pass
all_distances = np.concatenate(all_distances)
return results, all_distances
print("quitting search early because of keyboard interrupt")
finally:
all_distances = np.concatenate(all_distances)
return results, all_distances


def classifer_search_embeddings_parallel(
Expand All @@ -331,12 +336,21 @@ def classifer_search_embeddings_parallel(
Returns:
TopKSearchResults and all logits.
"""

# logits model behaves differently depending on whether it's been
# saved and loaded or not
if hasattr(embeddings_classifier.logits_model, 'signatures'):
signature = embeddings_classifier.logits_model.signatures["serving_default"]
input_specs = signature.structured_input_signature[1]
model_input_dtype = list(input_specs.values())[0].dtype
else:
model_input_dtype = embeddings_classifier.logits_model.input.dtype

def classify_batch(batch, query_embedding_batch):
del query_embedding_batch
emb = batch[tf_examples.EMBEDDING]
emb_shape = tf.shape(emb)
flat_emb = tf.cast(tf.reshape(emb, [-1, emb_shape[-1]]), tf.float32)
flat_emb = tf.cast(tf.reshape(emb, [-1, emb_shape[-1]]), model_input_dtype)
logits = embeddings_classifier(flat_emb)
logits = tf.reshape(
logits, [emb_shape[0], emb_shape[1], tf.shape(logits)[-1]]
Expand Down
36 changes: 32 additions & 4 deletions chirp/inference/tests/classify_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
"""Test small-model classification."""

import tempfile
import tensorflow as tf

from chirp.inference import interface
from chirp.inference.classify import classify
Expand All @@ -35,12 +36,13 @@ def make_merged_dataset(
rng: np.random.RandomState,
num_classes: int = 4,
embedding_dim: int = 16,
dtype: np.dtype = np.float32
):
"""Create a MergedDataset with random data."""
# Merged dataset's data dict contains keys:
# ['embeddings', 'filename', 'label', 'label_str', 'label_hot']
data = {}
data['embeddings'] = np.float32(
data['embeddings'] = dtype(
rng.normal(size=(num_points, embedding_dim))
)
data['label'] = rng.integers(0, num_classes, size=num_points)
Expand All @@ -55,18 +57,35 @@ def make_merged_dataset(
embedding_dim=embedding_dim,
labels=letters[:num_classes],
)


def test_train_linear_model(self):

@parameterized.product(
training_embedding_dtype=[np.float32, np.float16],
model_input_dtype=[np.float32, np.float16],
query_embedding_dtype=[np.float32, np.float16],
num_hiddens=[-1, 1]
)
def test_train_linear_model(self,
training_embedding_dtype,
model_input_dtype,
query_embedding_dtype,
num_hiddens):
embedding_dim = 16
num_classes = 4
num_points = 100
model = classify.get_linear_model(embedding_dim, num_classes)
if num_hiddens == -1:
model = classify.get_linear_model(embedding_dim, num_classes, dtype=model_input_dtype)
else:
model = classify.get_two_layer_model(num_hiddens, embedding_dim, num_classes, batch_norm=True, dtype=model_input_dtype)

rng = np.random.default_rng(42)
merged = self.make_merged_dataset(
num_points=num_points,
rng=rng,
num_classes=num_classes,
embedding_dim=embedding_dim,
dtype=training_embedding_dtype
)
unused_metrics = classify.train_embedding_model(
model,
Expand All @@ -78,7 +97,7 @@ def test_train_linear_model(self):
batch_size=16,
learning_rate=0.01,
)
query = rng.normal(size=(num_points, embedding_dim)).astype(np.float32)
query = rng.normal(size=(num_points, embedding_dim)).astype(query_embedding_dtype)

logits = model(query)

Expand All @@ -95,7 +114,16 @@ def test_train_linear_model(self):
restored_model = interface.LogitsOutputHead.from_config_file(
logits_model_dir
)

restored_logits = restored_model(query)
#debug:
print(f'training_embedding_dtype {training_embedding_dtype}')
print(f'model_input_dtype {model_input_dtype}')
print(f'query_embedding_dtype {query_embedding_dtype}')
print(f'original_model_input_signature: {model.inputs[0].dtype}')
print(f"restored_model_input_signature: {restored_model.logits_model.signatures['serving_default'].structured_input_signature}")
print(f"prediction dtype: {restored_logits.dtype}")

error = np.abs(restored_logits - logits).sum()
self.assertEqual(error, 0)

Expand Down
Loading