Skip to content

Commit 69df162

Browse files
Haiku Contributorcopybara-github
authored andcommitted
Fix incorrect type specification for Haiku partition function.
PiperOrigin-RevId: 770315759
1 parent 8bef47e commit 69df162

2 files changed

Lines changed: 2 additions & 4 deletions

File tree

haiku/_src/BUILD

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -925,7 +925,6 @@ hk_py_library(
925925
deps = [
926926
":data_structures",
927927
":utils",
928-
# pip: jax
929928
],
930929
)
931930

haiku/_src/filtering.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020

2121
from haiku._src import data_structures
2222
from haiku._src import utils
23-
import jax
2423

2524
T = TypeVar("T")
2625
InT = TypeVar("InT")
@@ -48,7 +47,7 @@ def traverse(
4847

4948

5049
def partition(
51-
predicate: Callable[[str, str, jax.Array], bool],
50+
predicate: Callable[[str, str, T], bool],
5251
structure: Mapping[str, Mapping[str, T]],
5352
) -> tuple[Mapping[str, Mapping[str, T]], Mapping[str, Mapping[str, T]]]:
5453
"""Partitions the input structure in two according to a given predicate.
@@ -77,7 +76,7 @@ def partition(
7776
predicate. Entries matching the predicate will be in the first structure,
7877
and the rest will be in the second.
7978
"""
80-
f = lambda m, n, v: int(not predicate(m, n, v))
79+
f: Callable[[str, str, T], int] = lambda m, n, v: int(not predicate(m, n, v))
8180
return partition_n(f, structure, 2)
8281

8382

0 commit comments

Comments
 (0)