Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions .github/workflows/ci_pip.yml
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,7 @@ jobs:
sudo apt-get install libsndfile1 ffmpeg
pip install absl-py
pip install requests
pip install tensorflow-cpu
pip install tensorflow-hub
pip install git+https://github.com/google-research/perch-hoplite.git
pip install "perch_hoplite[tf] @ git+https://github.com/google-research/perch-hoplite.git"
- name: Test db with unittest
run: python -m unittest discover -s perch_hoplite/db/tests -p "*test.py"
- name: Test taxonomy with unittest
Expand Down
11 changes: 6 additions & 5 deletions perch_hoplite/taxonomy/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
from typing import Iterable

import numpy as np
import tensorflow as tf

UNKNOWN_LABEL = "unknown"

Expand Down Expand Up @@ -155,9 +154,7 @@ def to_csv(self) -> str:
writer.writerow([class_])
return buffer.getvalue()

def get_class_map_tf_lookup(
self, target_class_list: ClassList
) -> tuple[tf.lookup.StaticHashTable, tf.Tensor]:
def get_class_map_tf_lookup(self, target_class_list: ClassList):
"""Create a static hash map for class indices.

Create a lookup table for use in TF Datasets, for, eg, converting between
Expand All @@ -173,6 +170,8 @@ def get_class_map_tf_lookup(
A tensorflow StaticHashTable and an indicator vector for the image of
the classlist mapping.
"""
import tensorflow as tf

if self.namespace != target_class_list.namespace:
raise ValueError("namespaces must match when creating a class map.")
intersection = set(self.classes) & set(target_class_list.classes)
Expand All @@ -195,7 +194,7 @@ def get_namespace_map_tf_lookup(
mapping: Mapping,
keep_unknown: bool | None = None,
target_class_list: ClassList | None = None,
) -> tf.lookup.StaticHashTable:
):
"""Create a tf.lookup.StaticHasTable for namespace mappings.

Args:
Expand All @@ -218,6 +217,8 @@ class list are maintained as unknown in the mapped values. If false then
ValueError: If a target class list was passed and the namespace of this
does not match the mapping target namespace.
"""
import tensorflow as tf

if UNKNOWN_LABEL in self.classes and keep_unknown is None:
raise ValueError(
"'unknown' found in source classes. Explicitly set keep_unknown to"
Expand Down
Loading
Loading