Skip to content

Commit 00d27d4

Browse files
committed
import __future__.annotations to keep up with haiku==0.0.14
1 parent e338d84 commit 00d27d4

File tree

1 file changed

+5
-4
lines changed

1 file changed

+5
-4
lines changed

haiku/haiku_simple.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,12 @@
88
The example code is based on https://github.com/deepmind/dm-haiku/blob/master/examples/mnist.py
99
"""
1010

11+
from __future__ import annotations
12+
13+
from collections.abc import Generator
14+
from collections.abc import Mapping
1115
import os
1216
from typing import Any
13-
from typing import Generator
14-
from typing import Mapping
15-
from typing import Tuple
1617
import urllib
1718

1819
import jax
@@ -119,7 +120,7 @@ def update(
119120
params: hk.Params,
120121
opt_state: OptState,
121122
batch: Batch,
122-
) -> Tuple[hk.Params, OptState]:
123+
) -> tuple[hk.Params, OptState]:
123124
"""Learning rule (stochastic gradient descent)."""
124125
grads = jax.grad(loss)(params, batch)
125126
updates, opt_state = opt.update(grads, opt_state)

0 commit comments

Comments
 (0)