diff --git a/perch_hoplite/taxonomy/namespace.py b/perch_hoplite/taxonomy/namespace.py index cdfa030..32af9fe 100644 --- a/perch_hoplite/taxonomy/namespace.py +++ b/perch_hoplite/taxonomy/namespace.py @@ -22,7 +22,6 @@ from typing import Iterable import numpy as np -import tensorflow as tf UNKNOWN_LABEL = "unknown" @@ -155,110 +154,6 @@ def to_csv(self) -> str: writer.writerow([class_]) return buffer.getvalue() - def get_class_map_tf_lookup( - self, target_class_list: ClassList - ) -> tuple[tf.lookup.StaticHashTable, tf.Tensor]: - """Create a static hash map for class indices. - - Create a lookup table for use in TF Datasets, for, eg, converting between - ClassList defined for a dataset to a ClassList used as model outputs. - Classes in the source ClassList which do not appear in the target_class_list - will be mapped to -1. It is recommended to drop these labels subsequently - with: tf.gather(x, tf.where(x >= 0)[:, 0]) - - Args: - target_class_list: Class list to target. - - Returns: - A tensorflow StaticHashTable and an indicator vector for the image of - the classlist mapping. - """ - if self.namespace != target_class_list.namespace: - raise ValueError("namespaces must match when creating a class map.") - intersection = set(self.classes) & set(target_class_list.classes) - intersection = sorted(tuple(intersection)) - keys = tuple(self.classes.index(c) for c in intersection) - values = tuple(target_class_list.classes.index(c) for c in intersection) - - table = tf.lookup.StaticHashTable( - tf.lookup.KeyValueTensorInitializer(keys, values, tf.int64, tf.int64), - default_value=-1, - ) - image_mask = tf.constant( - [k in self.classes for k in target_class_list.classes], - tf.int64, - ) - return table, image_mask - - def get_namespace_map_tf_lookup( - self, - mapping: Mapping, - keep_unknown: bool | None = None, - target_class_list: ClassList | None = None, - ) -> tf.lookup.StaticHashTable: - """Create a tf.lookup.StaticHasTable for namespace mappings. - - Args: - mapping: Mapping to apply. - keep_unknown: How to handle unknowns. If true, then unknown labels in the - class list are maintained as unknown in the mapped values. If false then - the unknown value is discarded. The default (`None`) will raise an error - if an unknown value is in the source classt list. - target_class_list: Optional class list for ordering of mapping output. If - not provided, a class list consisting of the alphabetized image set of - the mapping will be used. - - Returns: - A Tensorflow StaticHashTable and the image ClassList in the mapping's - target namespace. - - Raises: - ValueError: If 'unknown' label is in source classes and keep_unknown was - not specified. - ValueError: If a target class list was passed and the namespace of this - does not match the mapping target namespace. - """ - if UNKNOWN_LABEL in self.classes and keep_unknown is None: - raise ValueError( - "'unknown' found in source classes. Explicitly set keep_unknown to" - " True or False. Alternatively, remove 'unknown' from source classes" - ) - # If no target_class_list is passed, default to apply_namespace_mapping - if target_class_list is None: - target_class_list = self.apply_namespace_mapping( - mapping, keep_unknown=keep_unknown - ) - else: - if target_class_list.namespace != mapping.target_namespace: - raise ValueError( - f"target class list namespace ({target_class_list.namespace}) " - "does not match mapping target namespace " - f"({mapping.target_namespace})" - ) - # Now check if 'unknown' label present in target_class_list.classes - keep_unknown = keep_unknown and UNKNOWN_LABEL in target_class_list.classes - # Dict which maps classes to an index - target_class_indices = { - k: i for i, k in enumerate(target_class_list.classes) - } - # Add unknown to mapped pairs - mapped_pairs = mapping.mapped_pairs | {UNKNOWN_LABEL: UNKNOWN_LABEL} - # If keep unknown==False, set unknown index to -1 to discard unknowns - if not keep_unknown: - target_class_indices[UNKNOWN_LABEL] = -1 - # Get keys and values to be used in the lookup table - keys = list(range(len(self.classes))) - values = [ - target_class_indices[mapped_pairs[k]] - for k in self.classes - ] - # Create the static hash table. If a key doesnt exist, set as -1. - table = tf.lookup.StaticHashTable( - tf.lookup.KeyValueTensorInitializer(keys, values, tf.int64, tf.int64), - default_value=-1, - ) - return table - def apply_namespace_mapping( self, mapping: Mapping, keep_unknown: bool | None = None ) -> ClassList: diff --git a/perch_hoplite/taxonomy/namespace_db_test.py b/perch_hoplite/taxonomy/namespace_db_test.py index 14ce88e..8ec6795 100644 --- a/perch_hoplite/taxonomy/namespace_db_test.py +++ b/perch_hoplite/taxonomy/namespace_db_test.py @@ -16,14 +16,12 @@ """Tests for namespace_db.""" import io -import random import tempfile from absl import logging -import numpy as np +from etils import epath from perch_hoplite.taxonomy import namespace from perch_hoplite.taxonomy import namespace_db -import tensorflow as tf from absl.testing import absltest from absl.testing import parameterized @@ -55,19 +53,6 @@ def test_load_namespace_db(self): self.assertEqual(caples_orders.namespace, 'ebird2021_orders') self.assertLen(caples_orders.classes, 11) - def test_class_maps(self): - db = namespace_db.load_db() - caples_list = db.class_lists['caples'] - sierras_list = db.class_lists['sierra_nevadas'] - table, image_mask = caples_list.get_class_map_tf_lookup(sierras_list) - # The Caples list is a strict subset of the Sierras list. - self.assertLen(caples_list.classes, np.sum(image_mask)) - self.assertEqual(image_mask.shape, (len(sierras_list.classes),)) - for i in range(len(caples_list.classes)): - self.assertGreaterEqual( - table.lookup(tf.constant([i], dtype=tf.int64)).numpy()[0], 0 - ) - def test_class_map_csv(self): cl = namespace.ClassList( 'ebird2021', ('amecro', 'amegfi', 'amered', 'amerob') @@ -81,7 +66,7 @@ def test_class_map_csv(self): # Check that writing with tf.io.gfile behaves as expected, as newline # behavior may be different than working with StringIO. with tempfile.NamedTemporaryFile(suffix='.csv') as f: - with tf.io.gfile.GFile(f.name, 'w') as gf: + with epath.Path(f.name).open(mode='w') as gf: gf.write(cl_csv) with open(f.name, 'r') as f: got_cl = namespace.ClassList.from_csv(f.readlines()) @@ -166,92 +151,6 @@ def test_taxonomic_mappings(self): self.assertEmpty(missing_families) self.assertEmpty(missing_orders) - def test_reef_label_converting(self): - """Test operations used in ConvertReefLabels class. - - Part 1: Get the index of a sample in source_classes and the corresponding - index from lookup table so that we can check the look up table returns - the right soundtype e.g 'bioph' for the label 'bioph_rattle_response'. - Part 2: Iterate over labels in a shuffled version of source_classes - and check if each label maps correctly to its expected sound type. - """ - # Set up - db = namespace_db.load_db() - mapping = db.mappings['reef_class_to_soundtype'] - source_classes = db.class_lists['all_reefs'] - target_classes = db.class_lists['all_reefs'] - soundtype_table = source_classes.get_namespace_map_tf_lookup( - mapping, target_class_list=target_classes, keep_unknown=True - ) - # Part 1 - test_labels = ['geoph_waves', 'bioph_rattle_response', 'anthrop_bomb'] - expected_results = ['geoph', 'bioph', 'anthrop'] - for test_label, expected_result in zip(test_labels, expected_results): - classlist_index = source_classes.classes.index(test_label) - lookup_index = soundtype_table.lookup( - tf.constant(classlist_index, dtype=tf.int64) - ).numpy() - lookup_label = target_classes.classes[lookup_index] - self.assertEqual(expected_result, lookup_label) - # Part 2 - shuffled_classes = list(source_classes.classes) - np.random.seed(42) - random.shuffle(shuffled_classes) - for label in shuffled_classes: - # Every reef label is prefixed with either 'bioph', 'geoph', 'anthrop' - prefix = label.split('_')[0] - # Now mirror Part 1, by checking label against the prefix - classlist_index = source_classes.classes.index(label) - lookup_index = soundtype_table.lookup( - tf.constant(classlist_index, dtype=tf.int64) - ).numpy() - lookup_label = target_classes.classes[lookup_index] - self.assertEqual(prefix, lookup_label) - - @parameterized.parameters(True, False, None) - def test_namespace_map_tf_lookup(self, keep_unknown): - source = namespace.ClassList( - 'ebird2021', ('amecro', 'amegfi', 'amered', 'amerob', 'unknown') - ) - mapping = namespace.Mapping( - 'ebird2021', - 'ebird2021', - { - 'amecro': 'amered', - 'amegfi': 'amerob', - 'amered': 'amerob', - 'amerob': 'amerob', - }, - ) - if keep_unknown is None: - self.assertRaises( - ValueError, - source.get_namespace_map_tf_lookup, - mapping=mapping, - keep_unknown=keep_unknown, - ) - return - - output_class_list = source.apply_namespace_mapping( - mapping, keep_unknown=keep_unknown - ) - if keep_unknown: - expect_classes = ('amered', 'amerob', 'unknown') - else: - expect_classes = ('amered', 'amerob') - self.assertSequenceEqual(output_class_list.classes, expect_classes) - lookup = source.get_namespace_map_tf_lookup( - mapping, keep_unknown=keep_unknown - ) - got = lookup.lookup( - tf.constant(list(range(len(source.classes))), dtype=tf.int64) - ).numpy() - if keep_unknown: - expect_idxs = (0, 1, 1, 1, 2) - else: - expect_idxs = (0, 1, 1, 1, -1) - self.assertSequenceEqual(tuple(got), expect_idxs) - if __name__ == '__main__': absltest.main()