Skip to content

Commit

Permalink
Fix ClassList CSV write method.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 551435006
  • Loading branch information
sdenton4 authored and copybara-github committed Jul 27, 2023
1 parent 5f9285c commit 7e9748c
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 21 deletions.
2 changes: 1 addition & 1 deletion chirp/inference/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ def __post_init__(self):

self.model = tf.saved_model.load(model_path)
with label_csv_path.open('r') as f:
self.class_list = namespace.ClassList.from_csv(f)
self.class_list = namespace.ClassList.from_csv(f.readlines())

# Check whether the model supports polymorphic batch shapes.
sig = self.model.signatures['serving_default']
Expand Down
40 changes: 20 additions & 20 deletions chirp/taxonomy/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,7 @@ class Namespace:
classes: frozenset[str]

def __post_init__(self):
if UNKNOWN_LABEL in self.classes:
raise ValueError("unknown class")
self.classes = frozenset([c for c in self.classes if c != UNKNOWN_LABEL])


@dataclasses.dataclass
Expand All @@ -54,10 +53,10 @@ class Mapping:
The source and target namespace are referred to by their name. This name must
be resolved using the taxonomy database.
Note that labels cannot be mapped to unknown. Instead, these labels should be
simply excluded from the mapping. The end-user is responsible for deciding
whether to map missing keys to unknown or whether to raise an error, e.g.,
by using:
Note that labels (other than unknown) cannot be mapped to unknown. Instead,
these labels should be simply excluded from the mapping. The end-user is
responsible for deciding whether to map missing keys to unknown or whether to
raise an error, e.g., by using:
mapping.mapped_pairs.get(source_label, namespace.UNKNOWN_LABEL)
Expand All @@ -73,8 +72,14 @@ class Mapping:
mapped_pairs: dict[str, str]

def __post_init__(self):
if UNKNOWN_LABEL in self.mapped_pairs.values():
raise ValueError("unknown target class")
for k, v in self.mapped_pairs.items():
if v == UNKNOWN_LABEL and k != UNKNOWN_LABEL:
raise ValueError("unknown target class")

def __getitem__(self, key):
if key == UNKNOWN_LABEL:
return UNKNOWN_LABEL
return self.mapped_pairs[key]


@dataclasses.dataclass
Expand Down Expand Up @@ -116,7 +121,7 @@ def from_csv(cls, csv_data: Iterable[str]) -> "ClassList":
"""
reader = csv.reader(csv_data)
namespace = next(reader)[0]
classes = tuple(row[0].strip() for row in reader)
classes = tuple(row[0].strip() for row in reader if row)
return ClassList(namespace, classes)

def to_csv(self) -> str:
Expand All @@ -133,8 +138,9 @@ def to_csv(self) -> str:
"""
buffer = io.StringIO(newline="")
writer = csv.writer(buffer)
writer.writerow(self.namespace)
writer.writerows(self.classes)
writer.writerow([self.namespace])
for class_ in self.classes:
writer.writerow([class_])
return buffer.getvalue()

def get_class_map_tf_lookup(
Expand Down Expand Up @@ -189,9 +195,7 @@ def get_namespace_map_tf_lookup(
k: i for i, k in enumerate(target_class_list.classes)
}
keys = list(range(len(self.classes)))
values = [
target_class_indices[mapping.mapped_pairs[k]] for k in self.classes
]
values = [target_class_indices[mapping[k]] for k in self.classes]
table = tf.lookup.StaticHashTable(
tf.lookup.KeyValueTensorInitializer(keys, values, tf.int64, tf.int64),
default_value=-1,
Expand All @@ -203,11 +207,7 @@ def apply_namespace_mapping(self, mapping: Mapping) -> ClassList:
raise ValueError("mapping source namespace does not match class list's")
return ClassList(
mapping.target_namespace,
tuple(
dict.fromkeys(
mapping.mapped_pairs[class_] for class_ in self.classes
)
),
tuple(dict.fromkeys(mapping[class_] for class_ in self.classes)),
)

def get_class_map_matrix(
Expand Down Expand Up @@ -241,7 +241,7 @@ def get_class_map_matrix(
target_idxs = {k: i for i, k in enumerate(target_class_list.classes)}
for i, class_ in enumerate(self.classes):
if mapping is not None:
class_ = mapping.mapped_pairs[class_]
class_ = mapping[class_]
if class_ in target_idxs:
j = target_idxs[class_]
matrix = matrix.at[i, j].set(1)
Expand Down
24 changes: 24 additions & 0 deletions chirp/tests/namespace_db_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,11 @@

"""Tests for namespace_db."""

import io
import tempfile

from absl import logging
from chirp.taxonomy import namespace
from chirp.taxonomy import namespace_db
import numpy as np
import tensorflow as tf
Expand Down Expand Up @@ -62,6 +66,26 @@ def test_class_maps(self):
table.lookup(tf.constant([i], dtype=tf.int64)).numpy()[0], 0
)

def test_class_map_csv(self):
cl = namespace.ClassList(
'ebird2021', ('amecro', 'amegfi', 'amered', 'amerob')
)
cl_csv = cl.to_csv()
with io.StringIO(cl_csv) as f:
got_cl = namespace.ClassList.from_csv(f)
self.assertEqual(got_cl.namespace, 'ebird2021')
self.assertEqual(got_cl.classes, ('amecro', 'amegfi', 'amered', 'amerob'))

# Check that writing with tf.io.gfile behaves as expected, since its newline
# handling may differ from that of StringIO.
with tempfile.NamedTemporaryFile(suffix='.csv') as f:
with tf.io.gfile.GFile(f.name, 'w') as gf:
gf.write(cl_csv)
with open(f.name, 'r') as f:
got_cl = namespace.ClassList.from_csv(f.readlines())
self.assertEqual(got_cl.namespace, 'ebird2021')
self.assertEqual(got_cl.classes, ('amecro', 'amegfi', 'amered', 'amerob'))

def test_namespace_class_list_closure(self):
# Ensure that all classes in class lists appear in their namespace.
db = namespace_db.load_db()
Expand Down
6 changes: 6 additions & 0 deletions chirp/tests/train_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from chirp.models import efficientnet
from chirp.models import frontend
from chirp.preprocessing import pipeline
from chirp.taxonomy import namespace
from chirp.tests import fake_dataset
from chirp.train import classifier
from clu import checkpoint
Expand Down Expand Up @@ -183,6 +184,11 @@ def test_export_model(self):
self.assertTrue(
tf.io.gfile.exists(os.path.join(self.train_dir, "label.csv"))
)
with open(os.path.join(self.train_dir, "label.csv")) as f:
got_class_list = namespace.ClassList.from_csv(f.readlines())
# Check equality of the ClassList with the Model Bundle.
self.assertEqual(model_bundle.class_lists["label"], got_class_list)

self.assertTrue(
tf.io.gfile.exists(
os.path.join(self.train_dir, "savedmodel/saved_model.pb")
Expand Down

0 comments on commit 7e9748c

Please sign in to comment.