File tree Expand file tree Collapse file tree 1 file changed +4
-5
lines changed Expand file tree Collapse file tree 1 file changed +4
-5
lines changed Original file line number Diff line number Diff line change 88The example code is based on https://github.com/deepmind/dm-haiku/blob/master/examples/mnist.py
99"""
1010
11- from __future__ import annotations
12-
13- from collections .abc import Generator
14- from collections .abc import Mapping
1511import os
1612from typing import Any
13+ from typing import Generator
14+ from typing import Mapping
15+ from typing import Tuple
1716import urllib
1817
1918import jax
@@ -120,7 +119,7 @@ def update(
120119 params : hk .Params ,
121120 opt_state : OptState ,
122121 batch : Batch ,
123- ) -> tuple [hk .Params , OptState ]:
122+ ) -> Tuple [hk .Params , OptState ]:
124123 """Learning rule (stochastic gradient descent)."""
125124 grads = jax .grad (loss )(params , batch )
126125 updates , opt_state = opt .update (grads , opt_state )
You can’t perform that action at this time.
0 commit comments