File tree Expand file tree Collapse file tree 1 file changed +5
-4
lines changed Expand file tree Collapse file tree 1 file changed +5
-4
lines changed Original file line number Diff line number Diff line change 88The example code is based on https://github.com/deepmind/dm-haiku/blob/master/examples/mnist.py
99"""
1010
11+ from __future__ import annotations
12+
13+ from collections .abc import Generator
14+ from collections .abc import Mapping
1115import os
1216from typing import Any
13- from typing import Generator
14- from typing import Mapping
15- from typing import Tuple
1617import urllib
1718
1819import jax
@@ -119,7 +120,7 @@ def update(
119120 params : hk .Params ,
120121 opt_state : OptState ,
121122 batch : Batch ,
122- ) -> Tuple [hk .Params , OptState ]:
123+ ) -> tuple [hk .Params , OptState ]:
123124 """Learning rule (stochastic gradient descent)."""
124125 grads = jax .grad (loss )(params , batch )
125126 updates , opt_state = opt .update (grads , opt_state )
You can’t perform that action at this time.
0 commit comments