Skip to content

Commit

Permalink
Fix ClassList CSV write method and allow mapping 'unknown' to 'unknow…
Browse files Browse the repository at this point in the history
…n' programatically.

PiperOrigin-RevId: 551435006
  • Loading branch information
sdenton4 authored and copybara-github committed Jul 27, 2023
1 parent 5f9285c commit 6965f91
Show file tree
Hide file tree
Showing 4 changed files with 97 additions and 17 deletions.
6 changes: 3 additions & 3 deletions chirp/preprocessing/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -751,13 +751,13 @@ def load_tables(
self.source_namespace + '_to_' + key
]
source_taxa_classes = source_class_list.apply_namespace_mapping(
namespace_mapping
namespace_mapping, keep_unknown=True
)
target_taxa_classes = target_classes.apply_namespace_mapping(
namespace_mapping
namespace_mapping, keep_unknown=True
)
namespace_table = source_class_list.get_namespace_map_tf_lookup(
namespace_mapping
namespace_mapping, keep_unknown=True
)
class_table, label_mask = source_taxa_classes.get_class_map_tf_lookup(
target_taxa_classes
Expand Down
78 changes: 64 additions & 14 deletions chirp/taxonomy/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,10 @@ class Mapping:
The source and target namespace are referred to by their name. This name must
be resolved using the taxonomy database.
Note that labels cannot be mapped to unknown. Instead, these labels should be
simply excluded from the mapping. The end-user is responsible for deciding
whether to map missing keys to unknown or whether to raise an error, e.g.,
by using:
Note that labels (other than unknown) cannot be mapped to unknown. Instead,
these labels should be simply excluded from the mapping. The end-user is
responsible for deciding whether to map missing keys to unknown or whether to
raise an error, e.g., by using:
mapping.mapped_pairs.get(source_label, namespace.UNKNOWN_LABEL)
Expand All @@ -73,8 +73,18 @@ class Mapping:
mapped_pairs: dict[str, str]

def __post_init__(self):
if UNKNOWN_LABEL in self.mapped_pairs.values():
raise ValueError("unknown target class")
for k, v in self.mapped_pairs.items():
if v == UNKNOWN_LABEL and k != UNKNOWN_LABEL:
raise ValueError("unknown target class")

def with_unknown(self) -> "Mapping":
if UNKNOWN_LABEL in self.mapped_pairs:
return self
new_mapped_pairs = self.mapped_pairs.copy()
new_mapped_pairs[UNKNOWN_LABEL] = UNKNOWN_LABEL
return Mapping(
self.source_namespace, self.target_namespace, new_mapped_pairs
)


@dataclasses.dataclass
Expand Down Expand Up @@ -116,7 +126,7 @@ def from_csv(cls, csv_data: Iterable[str]) -> "ClassList":
"""
reader = csv.reader(csv_data)
namespace = next(reader)[0]
classes = tuple(row[0].strip() for row in reader)
classes = tuple(row[0].strip() for row in reader if row)
return ClassList(namespace, classes)

def to_csv(self) -> str:
Expand All @@ -133,8 +143,9 @@ def to_csv(self) -> str:
"""
buffer = io.StringIO(newline="")
writer = csv.writer(buffer)
writer.writerow(self.namespace)
writer.writerows(self.classes)
writer.writerow([self.namespace])
for class_ in self.classes:
writer.writerow([class_])
return buffer.getvalue()

def get_class_map_tf_lookup(
Expand Down Expand Up @@ -173,39 +184,78 @@ def get_class_map_tf_lookup(
return table, image_mask

def get_namespace_map_tf_lookup(
self, mapping: Mapping
self, mapping: Mapping, keep_unknown: bool | None = None
) -> tf.lookup.StaticHashTable:
"""Create a tf.lookup.StaticHasTable for namespace mappings.
Args:
mapping: Mapping to apply.
keep_unknown: How to handle unknowns. If true, then unknown labels in the
class list are maintained as unknown in the mapped values. If false then
the unknown value is discarded. The default (`None`) will raise an error
if an unknown value is in the source classt list.
Returns:
A Tensorflow StaticHashTable and the image ClassList in the mapping's
target namespace.
Raises:
KeyError: If a class in not the mapping, or if the class list contains
an unknown token and `keep_unknown` was not specified.
"""
target_class_list = self.apply_namespace_mapping(mapping)
target_class_list = self.apply_namespace_mapping(
mapping, keep_unknown=keep_unknown
)
target_class_indices = {
k: i for i, k in enumerate(target_class_list.classes)
}
mapped_pairs = mapping.mapped_pairs
if keep_unknown:
mapped_pairs = mapped_pairs | {UNKNOWN_LABEL: UNKNOWN_LABEL}
keys = list(range(len(self.classes)))
values = [
target_class_indices[mapping.mapped_pairs[k]] for k in self.classes
target_class_indices[mapped_pairs[k]]
for k in self.classes
if k != UNKNOWN_LABEL or keep_unknown in (True, None)
]
table = tf.lookup.StaticHashTable(
tf.lookup.KeyValueTensorInitializer(keys, values, tf.int64, tf.int64),
default_value=-1,
)
return table

def apply_namespace_mapping(self, mapping: Mapping) -> ClassList:
def apply_namespace_mapping(
self, mapping: Mapping, keep_unknown: bool | None = None
) -> ClassList:
"""Apply a namespace mapping to this class list.
Args:
mapping: The mapping to apply.
keep_unknown: How to handle unknowns. If true, then unknown labels in the
class list are maintained as unknown in the mapped values. If false then
the unknown value is discarded. The default (`None`) will raise an error
if an unknown value is in the source classt list.
Returns:
A class list which is the result of applying the given mapping to this
class list.
Raises:
KeyError: If a class in not the mapping, or if the class list contains
an unknown token and `keep_unknown` was not specified.
"""
if mapping.source_namespace != self.namespace:
raise ValueError("mapping source namespace does not match class list's")
mapped_pairs = mapping.mapped_pairs
if keep_unknown:
mapped_pairs = mapped_pairs | {UNKNOWN_LABEL: UNKNOWN_LABEL}
return ClassList(
mapping.target_namespace,
tuple(
dict.fromkeys(
mapping.mapped_pairs[class_] for class_ in self.classes
mapped_pairs[class_]
for class_ in self.classes
if class_ != UNKNOWN_LABEL or keep_unknown in (True, None)
)
),
)
Expand Down
24 changes: 24 additions & 0 deletions chirp/tests/namespace_db_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,11 @@

"""Tests for namespace_db."""

import io
import tempfile

from absl import logging
from chirp.taxonomy import namespace
from chirp.taxonomy import namespace_db
import numpy as np
import tensorflow as tf
Expand Down Expand Up @@ -62,6 +66,26 @@ def test_class_maps(self):
table.lookup(tf.constant([i], dtype=tf.int64)).numpy()[0], 0
)

def test_class_map_csv(self):
cl = namespace.ClassList(
'ebird2021', ('amecro', 'amegfi', 'amered', 'amerob')
)
cl_csv = cl.to_csv()
with io.StringIO(cl_csv) as f:
got_cl = namespace.ClassList.from_csv(f)
self.assertEqual(got_cl.namespace, 'ebird2021')
self.assertEqual(got_cl.classes, ('amecro', 'amegfi', 'amered', 'amerob'))

# Check that writing with tf.io.gfile behaves as expected, as newline
# behavior may be different than working with StringIO.
with tempfile.NamedTemporaryFile(suffix='.csv') as f:
with tf.io.gfile.GFile(f.name, 'w') as gf:
gf.write(cl_csv)
with open(f.name, 'r') as f:
got_cl = namespace.ClassList.from_csv(f.readlines())
self.assertEqual(got_cl.namespace, 'ebird2021')
self.assertEqual(got_cl.classes, ('amecro', 'amegfi', 'amered', 'amerob'))

def test_namespace_class_list_closure(self):
# Ensure that all classes in class lists appear in their namespace.
db = namespace_db.load_db()
Expand Down
6 changes: 6 additions & 0 deletions chirp/tests/train_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from chirp.models import efficientnet
from chirp.models import frontend
from chirp.preprocessing import pipeline
from chirp.taxonomy import namespace
from chirp.tests import fake_dataset
from chirp.train import classifier
from clu import checkpoint
Expand Down Expand Up @@ -183,6 +184,11 @@ def test_export_model(self):
self.assertTrue(
tf.io.gfile.exists(os.path.join(self.train_dir, "label.csv"))
)
with open(os.path.join(self.train_dir, "label.csv")) as f:
got_class_list = namespace.ClassList.from_csv(f.readlines())
# Check equality of the ClassList with the Model Bundle.
self.assertEqual(model_bundle.class_lists["label"], got_class_list)

self.assertTrue(
tf.io.gfile.exists(
os.path.join(self.train_dir, "savedmodel/saved_model.pb")
Expand Down

0 comments on commit 6965f91

Please sign in to comment.