Skip to content

Commit

Permalink
Fix ClassList CSV write method.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 551435006
  • Loading branch information
sdenton4 authored and copybara-github committed Jul 27, 2023
1 parent 5f9285c commit d21c131
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 4 deletions.
2 changes: 1 addition & 1 deletion chirp/inference/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ def __post_init__(self):

self.model = tf.saved_model.load(model_path)
with label_csv_path.open('r') as f:
self.class_list = namespace.ClassList.from_csv(f)
self.class_list = namespace.ClassList.from_csv(f.readlines())

# Check whether the model support polymorphic batch shape.
sig = self.model.signatures['serving_default']
Expand Down
7 changes: 4 additions & 3 deletions chirp/taxonomy/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def from_csv(cls, csv_data: Iterable[str]) -> "ClassList":
"""
reader = csv.reader(csv_data)
namespace = next(reader)[0]
classes = tuple(row[0].strip() for row in reader)
classes = tuple(row[0].strip() for row in reader if row)
return ClassList(namespace, classes)

def to_csv(self) -> str:
Expand All @@ -133,8 +133,9 @@ def to_csv(self) -> str:
"""
buffer = io.StringIO(newline="")
writer = csv.writer(buffer)
writer.writerow(self.namespace)
writer.writerows(self.classes)
writer.writerow([self.namespace])
for class_ in self.classes:
writer.writerow([class_])
return buffer.getvalue()

def get_class_map_tf_lookup(
Expand Down
24 changes: 24 additions & 0 deletions chirp/tests/namespace_db_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,11 @@

"""Tests for namespace_db."""

import io
import tempfile

from absl import logging
from chirp.taxonomy import namespace
from chirp.taxonomy import namespace_db
import numpy as np
import tensorflow as tf
Expand Down Expand Up @@ -62,6 +66,26 @@ def test_class_maps(self):
table.lookup(tf.constant([i], dtype=tf.int64)).numpy()[0], 0
)

def test_class_map_csv(self):
cl = namespace.ClassList(
'ebird2021', ('amecro', 'amegfi', 'amered', 'amerob')
)
cl_csv = cl.to_csv()
with io.StringIO(cl_csv) as f:
got_cl = namespace.ClassList.from_csv(f)
self.assertEqual(got_cl.namespace, 'ebird2021')
self.assertEqual(got_cl.classes, ('amecro', 'amegfi', 'amered', 'amerob'))

# Check that writing with tf.io.gfile behaves as expected, as newline
# behavior may be different than working with StringIO.
with tempfile.NamedTemporaryFile(suffix='.csv') as f:
with tf.io.gfile.GFile(f.name, 'w') as gf:
gf.write(cl_csv)
with open(f.name, 'r') as f:
got_cl = namespace.ClassList.from_csv(f.readlines())
self.assertEqual(got_cl.namespace, 'ebird2021')
self.assertEqual(got_cl.classes, ('amecro', 'amegfi', 'amered', 'amerob'))

def test_namespace_class_list_closure(self):
# Ensure that all classes in class lists appear in their namespace.
db = namespace_db.load_db()
Expand Down
6 changes: 6 additions & 0 deletions chirp/tests/train_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from chirp.models import efficientnet
from chirp.models import frontend
from chirp.preprocessing import pipeline
from chirp.taxonomy import namespace
from chirp.tests import fake_dataset
from chirp.train import classifier
from clu import checkpoint
Expand Down Expand Up @@ -183,6 +184,11 @@ def test_export_model(self):
self.assertTrue(
tf.io.gfile.exists(os.path.join(self.train_dir, "label.csv"))
)
with open(os.path.join(self.train_dir, "label.csv")) as f:
got_class_list = namespace.ClassList.from_csv(f.readlines())
# Check equality of the ClassList with the Model Bundle.
self.assertEqual(model_bundle.class_lists["label"], got_class_list)

self.assertTrue(
tf.io.gfile.exists(
os.path.join(self.train_dir, "savedmodel/saved_model.pb")
Expand Down

0 comments on commit d21c131

Please sign in to comment.