Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 0 additions & 105 deletions perch_hoplite/taxonomy/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
from typing import Iterable

import numpy as np
import tensorflow as tf

UNKNOWN_LABEL = "unknown"

Expand Down Expand Up @@ -155,110 +154,6 @@ def to_csv(self) -> str:
writer.writerow([class_])
return buffer.getvalue()

def get_class_map_tf_lookup(
self, target_class_list: ClassList
) -> tuple[tf.lookup.StaticHashTable, tf.Tensor]:
"""Create a static hash map for class indices.

Create a lookup table for use in TF Datasets, for, eg, converting between
ClassList defined for a dataset to a ClassList used as model outputs.
Classes in the source ClassList which do not appear in the target_class_list
will be mapped to -1. It is recommended to drop these labels subsequently
with: tf.gather(x, tf.where(x >= 0)[:, 0])

Args:
target_class_list: Class list to target.

Returns:
A tensorflow StaticHashTable and an indicator vector for the image of
the classlist mapping.
"""
if self.namespace != target_class_list.namespace:
raise ValueError("namespaces must match when creating a class map.")
intersection = set(self.classes) & set(target_class_list.classes)
intersection = sorted(tuple(intersection))
keys = tuple(self.classes.index(c) for c in intersection)
values = tuple(target_class_list.classes.index(c) for c in intersection)

table = tf.lookup.StaticHashTable(
tf.lookup.KeyValueTensorInitializer(keys, values, tf.int64, tf.int64),
default_value=-1,
)
image_mask = tf.constant(
[k in self.classes for k in target_class_list.classes],
tf.int64,
)
return table, image_mask

def get_namespace_map_tf_lookup(
self,
mapping: Mapping,
keep_unknown: bool | None = None,
target_class_list: ClassList | None = None,
) -> tf.lookup.StaticHashTable:
"""Create a tf.lookup.StaticHasTable for namespace mappings.

Args:
mapping: Mapping to apply.
keep_unknown: How to handle unknowns. If true, then unknown labels in the
class list are maintained as unknown in the mapped values. If false then
the unknown value is discarded. The default (`None`) will raise an error
if an unknown value is in the source classt list.
target_class_list: Optional class list for ordering of mapping output. If
not provided, a class list consisting of the alphabetized image set of
the mapping will be used.

Returns:
A Tensorflow StaticHashTable and the image ClassList in the mapping's
target namespace.

Raises:
ValueError: If 'unknown' label is in source classes and keep_unknown was
not specified.
ValueError: If a target class list was passed and the namespace of this
does not match the mapping target namespace.
"""
if UNKNOWN_LABEL in self.classes and keep_unknown is None:
raise ValueError(
"'unknown' found in source classes. Explicitly set keep_unknown to"
" True or False. Alternatively, remove 'unknown' from source classes"
)
# If no target_class_list is passed, default to apply_namespace_mapping
if target_class_list is None:
target_class_list = self.apply_namespace_mapping(
mapping, keep_unknown=keep_unknown
)
else:
if target_class_list.namespace != mapping.target_namespace:
raise ValueError(
f"target class list namespace ({target_class_list.namespace}) "
"does not match mapping target namespace "
f"({mapping.target_namespace})"
)
# Now check if 'unknown' label present in target_class_list.classes
keep_unknown = keep_unknown and UNKNOWN_LABEL in target_class_list.classes
# Dict which maps classes to an index
target_class_indices = {
k: i for i, k in enumerate(target_class_list.classes)
}
# Add unknown to mapped pairs
mapped_pairs = mapping.mapped_pairs | {UNKNOWN_LABEL: UNKNOWN_LABEL}
# If keep unknown==False, set unknown index to -1 to discard unknowns
if not keep_unknown:
target_class_indices[UNKNOWN_LABEL] = -1
# Get keys and values to be used in the lookup table
keys = list(range(len(self.classes)))
values = [
target_class_indices[mapped_pairs[k]]
for k in self.classes
]
# Create the static hash table. If a key doesnt exist, set as -1.
table = tf.lookup.StaticHashTable(
tf.lookup.KeyValueTensorInitializer(keys, values, tf.int64, tf.int64),
default_value=-1,
)
return table

def apply_namespace_mapping(
self, mapping: Mapping, keep_unknown: bool | None = None
) -> ClassList:
Expand Down
105 changes: 2 additions & 103 deletions perch_hoplite/taxonomy/namespace_db_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,12 @@
"""Tests for namespace_db."""

import io
import random
import tempfile

from absl import logging
import numpy as np
from etils import epath
from perch_hoplite.taxonomy import namespace
from perch_hoplite.taxonomy import namespace_db
import tensorflow as tf

from absl.testing import absltest
from absl.testing import parameterized
Expand Down Expand Up @@ -55,19 +53,6 @@ def test_load_namespace_db(self):
self.assertEqual(caples_orders.namespace, 'ebird2021_orders')
self.assertLen(caples_orders.classes, 11)

def test_class_maps(self):
db = namespace_db.load_db()
caples_list = db.class_lists['caples']
sierras_list = db.class_lists['sierra_nevadas']
table, image_mask = caples_list.get_class_map_tf_lookup(sierras_list)
# The Caples list is a strict subset of the Sierras list.
self.assertLen(caples_list.classes, np.sum(image_mask))
self.assertEqual(image_mask.shape, (len(sierras_list.classes),))
for i in range(len(caples_list.classes)):
self.assertGreaterEqual(
table.lookup(tf.constant([i], dtype=tf.int64)).numpy()[0], 0
)

def test_class_map_csv(self):
cl = namespace.ClassList(
'ebird2021', ('amecro', 'amegfi', 'amered', 'amerob')
Expand All @@ -81,7 +66,7 @@ def test_class_map_csv(self):
# Check that writing with tf.io.gfile behaves as expected, as newline
# behavior may be different than working with StringIO.
with tempfile.NamedTemporaryFile(suffix='.csv') as f:
with tf.io.gfile.GFile(f.name, 'w') as gf:
with epath.Path(f.name).open(mode='w') as gf:
gf.write(cl_csv)
with open(f.name, 'r') as f:
got_cl = namespace.ClassList.from_csv(f.readlines())
Expand Down Expand Up @@ -166,92 +151,6 @@ def test_taxonomic_mappings(self):
self.assertEmpty(missing_families)
self.assertEmpty(missing_orders)

def test_reef_label_converting(self):
"""Test operations used in ConvertReefLabels class.

Part 1: Get the index of a sample in source_classes and the corresponding
index from lookup table so that we can check the look up table returns
the right soundtype e.g 'bioph' for the label 'bioph_rattle_response'.
Part 2: Iterate over labels in a shuffled version of source_classes
and check if each label maps correctly to its expected sound type.
"""
# Set up
db = namespace_db.load_db()
mapping = db.mappings['reef_class_to_soundtype']
source_classes = db.class_lists['all_reefs']
target_classes = db.class_lists['all_reefs']
soundtype_table = source_classes.get_namespace_map_tf_lookup(
mapping, target_class_list=target_classes, keep_unknown=True
)
# Part 1
test_labels = ['geoph_waves', 'bioph_rattle_response', 'anthrop_bomb']
expected_results = ['geoph', 'bioph', 'anthrop']
for test_label, expected_result in zip(test_labels, expected_results):
classlist_index = source_classes.classes.index(test_label)
lookup_index = soundtype_table.lookup(
tf.constant(classlist_index, dtype=tf.int64)
).numpy()
lookup_label = target_classes.classes[lookup_index]
self.assertEqual(expected_result, lookup_label)
# Part 2
shuffled_classes = list(source_classes.classes)
np.random.seed(42)
random.shuffle(shuffled_classes)
for label in shuffled_classes:
# Every reef label is prefixed with either 'bioph', 'geoph', 'anthrop'
prefix = label.split('_')[0]
# Now mirror Part 1, by checking label against the prefix
classlist_index = source_classes.classes.index(label)
lookup_index = soundtype_table.lookup(
tf.constant(classlist_index, dtype=tf.int64)
).numpy()
lookup_label = target_classes.classes[lookup_index]
self.assertEqual(prefix, lookup_label)

@parameterized.parameters(True, False, None)
def test_namespace_map_tf_lookup(self, keep_unknown):
source = namespace.ClassList(
'ebird2021', ('amecro', 'amegfi', 'amered', 'amerob', 'unknown')
)
mapping = namespace.Mapping(
'ebird2021',
'ebird2021',
{
'amecro': 'amered',
'amegfi': 'amerob',
'amered': 'amerob',
'amerob': 'amerob',
},
)
if keep_unknown is None:
self.assertRaises(
ValueError,
source.get_namespace_map_tf_lookup,
mapping=mapping,
keep_unknown=keep_unknown,
)
return

output_class_list = source.apply_namespace_mapping(
mapping, keep_unknown=keep_unknown
)
if keep_unknown:
expect_classes = ('amered', 'amerob', 'unknown')
else:
expect_classes = ('amered', 'amerob')
self.assertSequenceEqual(output_class_list.classes, expect_classes)
lookup = source.get_namespace_map_tf_lookup(
mapping, keep_unknown=keep_unknown
)
got = lookup.lookup(
tf.constant(list(range(len(source.classes))), dtype=tf.int64)
).numpy()
if keep_unknown:
expect_idxs = (0, 1, 1, 1, 2)
else:
expect_idxs = (0, 1, 1, 1, -1)
self.assertSequenceEqual(tuple(got), expect_idxs)


if __name__ == '__main__':
absltest.main()