From 1c040df2c7df9a4a0db4ce097c101e1da21df313 Mon Sep 17 00:00:00 2001 From: Tom Denton Date: Thu, 17 Oct 2024 12:57:32 -0700 Subject: [PATCH] When extracting a taxonomy_model_tf classifier head, check more directly for variables location. PiperOrigin-RevId: 687014646 --- chirp/projects/zoo/taxonomy_model_tf.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/chirp/projects/zoo/taxonomy_model_tf.py b/chirp/projects/zoo/taxonomy_model_tf.py index bd3e260c..0685fdf5 100644 --- a/chirp/projects/zoo/taxonomy_model_tf.py +++ b/chirp/projects/zoo/taxonomy_model_tf.py @@ -107,11 +107,14 @@ def from_tfhub(cls, config: config_dict.ConfigDict) -> 'TaxonomyModelTF': model_path = hub.resolve(model_url) class_lists_glob = (epath.Path(model_path) / 'assets').glob('*.csv') class_lists = cls.load_class_lists(class_lists_glob) + mutable_config = config.copy_and_resolve_references() + del mutable_config.model_path return cls( model=model, class_list=class_lists, batchable=batchable, - **config, + model_path=model_path, + **mutable_config, ) @classmethod @@ -174,12 +177,13 @@ def from_config(cls, config: config_dict.ConfigDict) -> 'TaxonomyModelTF': def get_classifier_head(self, classes: list[str]): """Extract a classifier head for the desired subset of classes.""" - if self.tfhub_version is not None: - # This is a model loaded from TFHub. - # We need to extract the weights and biases from the saved model. + base_path = epath.Path(self.model_path) + if (base_path / 'variables').exists(): vars_filepath = f'{self.model_path}/variables/variables' - else: + elif (base_path / 'savedmodel' / 'variables').exists(): vars_filepath = f'{self.model_path}/savedmodel/variables/variables' + else: + raise ValueError(f'No variables found in {self.model_path}') def _get_weights_and_bias(num_classes: int): weights = None