Skip to content

Commit

Permalink
Add handling of unknown labels to apply_namespace_mapping.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 551591557
  • Loading branch information
Bart van Merriënboer authored and copybara-github committed Jul 27, 2023
1 parent 5f9285c commit 65444ad
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 4 deletions.
4 changes: 2 additions & 2 deletions chirp/preprocessing/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -751,10 +751,10 @@ def load_tables(
self.source_namespace + '_to_' + key
]
source_taxa_classes = source_class_list.apply_namespace_mapping(
namespace_mapping
namespace_mapping, keep_unknown=True
)
target_taxa_classes = target_classes.apply_namespace_mapping(
namespace_mapping
namespace_mapping, keep_unknown=True
)
namespace_table = source_class_list.get_namespace_map_tf_lookup(
namespace_mapping
Expand Down
28 changes: 26 additions & 2 deletions chirp/taxonomy/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,14 +198,38 @@ def get_namespace_map_tf_lookup(
)
return table

def apply_namespace_mapping(self, mapping: Mapping) -> ClassList:
def apply_namespace_mapping(
self, mapping: Mapping, keep_unknown: bool | None = None
) -> ClassList:
"""Apply a namespace mapping to this class list.
Args:
mapping: The mapping to apply.
keep_unknown: How to handle unknowns. If true, then unknown labels in the
class list are maintained as unknown in the mapped values. If false then
the unknown value is discarded. The default (`None`) will raise an error
if an unknown value is in the source class list.
Returns:
A class list which is the result of applying the given mapping to this
class list.
Raises:
KeyError: If a class is not in the mapping, or if the class list contains
an unknown token and `keep_unknown` was not specified.
"""
if mapping.source_namespace != self.namespace:
raise ValueError("mapping source namespace does not match class list's")
mapped_pairs = mapping.mapped_pairs
if keep_unknown:
mapped_pairs = mapped_pairs | {UNKNOWN_LABEL: UNKNOWN_LABEL}
return ClassList(
mapping.target_namespace,
tuple(
dict.fromkeys(
mapping.mapped_pairs[class_] for class_ in self.classes
mapped_pairs[class_]
for class_ in self.classes
if class_ != UNKNOWN_LABEL or keep_unknown in (True, None)
)
),
)
Expand Down

0 comments on commit 65444ad

Please sign in to comment.