Skip to content

Commit

Permalink
When extracting a taxonomy_model_tf classifier head, check more direc…
Browse files Browse the repository at this point in the history
…tly for variables location.

PiperOrigin-RevId: 687014646
  • Loading branch information
sdenton4 authored and copybara-github committed Oct 17, 2024
1 parent d57369f commit 1c040df
Showing 1 changed file with 9 additions and 5 deletions.
14 changes: 9 additions & 5 deletions chirp/projects/zoo/taxonomy_model_tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,11 +107,14 @@ def from_tfhub(cls, config: config_dict.ConfigDict) -> 'TaxonomyModelTF':
model_path = hub.resolve(model_url)
class_lists_glob = (epath.Path(model_path) / 'assets').glob('*.csv')
class_lists = cls.load_class_lists(class_lists_glob)
mutable_config = config.copy_and_resolve_references()
del mutable_config.model_path
return cls(
model=model,
class_list=class_lists,
batchable=batchable,
**config,
model_path=model_path,
**mutable_config,
)

@classmethod
Expand Down Expand Up @@ -174,12 +177,13 @@ def from_config(cls, config: config_dict.ConfigDict) -> 'TaxonomyModelTF':

def get_classifier_head(self, classes: list[str]):
"""Extract a classifier head for the desired subset of classes."""
if self.tfhub_version is not None:
# This is a model loaded from TFHub.
# We need to extract the weights and biases from the saved model.
base_path = epath.Path(self.model_path)
if (base_path / 'variables').exists():
vars_filepath = f'{self.model_path}/variables/variables'
else:
elif (base_path / 'savedmodel' / 'variables').exists():
vars_filepath = f'{self.model_path}/savedmodel/variables/variables'
else:
raise ValueError(f'No variables found in {self.model_path}')

def _get_weights_and_bias(num_classes: int):
weights = None
Expand Down

0 comments on commit 1c040df

Please sign in to comment.