diff --git a/chirp/preprocessing/pipeline.py b/chirp/preprocessing/pipeline.py index 8819f98f..f6794773 100644 --- a/chirp/preprocessing/pipeline.py +++ b/chirp/preprocessing/pipeline.py @@ -751,10 +751,10 @@ def load_tables( self.source_namespace + '_to_' + key ] source_taxa_classes = source_class_list.apply_namespace_mapping( - namespace_mapping + namespace_mapping, keep_unknown=True ) target_taxa_classes = target_classes.apply_namespace_mapping( - namespace_mapping + namespace_mapping, keep_unknown=True ) namespace_table = source_class_list.get_namespace_map_tf_lookup( namespace_mapping diff --git a/chirp/taxonomy/namespace.py b/chirp/taxonomy/namespace.py index f61b64fd..9adcefe9 100644 --- a/chirp/taxonomy/namespace.py +++ b/chirp/taxonomy/namespace.py @@ -198,14 +198,38 @@ def get_namespace_map_tf_lookup( ) return table - def apply_namespace_mapping(self, mapping: Mapping) -> ClassList: + def apply_namespace_mapping( + self, mapping: Mapping, keep_unknown: bool | None = None + ) -> ClassList: + """Apply a namespace mapping to this class list. + + Args: + mapping: The mapping to apply. + keep_unknown: How to handle unknowns. If true, then unknown labels in the + class list are maintained as unknown in the mapped values. If false then + the unknown value is discarded. The default (`None`) will raise an error + if an unknown value is in the source classt list. + + Returns: + A class list which is the result of applying the given mapping to this + class list. + + Raises: + KeyError: If a class in not the mapping, or if the class list contains + an unknown token and `keep_unknown` was not specified. + """ if mapping.source_namespace != self.namespace: raise ValueError("mapping source namespace does not match class list's") + mapped_pairs = mapping.mapped_pairs + if keep_unknown: + mapped_pairs = mapped_pairs | {UNKNOWN_LABEL: UNKNOWN_LABEL} return ClassList( mapping.target_namespace, tuple( dict.fromkeys( - mapping.mapped_pairs[class_] for class_ in self.classes + mapped_pairs[class_] + for class_ in self.classes + if class_ != UNKNOWN_LABEL or keep_unknown in (True, None) ) ), )