Skip to content

Commit 823bdee

Browse files
dougalmcopybara-github
authored andcommitted
Update libraries to use JAX's limited (and ill-advised) trace-state-querying APIs rather than depending on JAX's deeper internals, which are about to change.
PiperOrigin-RevId: 677843398
1 parent 4773949 commit 823bdee

1 file changed

Lines changed: 9 additions & 6 deletions

File tree

haiku/_src/base.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -70,12 +70,15 @@ class JaxTraceLevel(NamedTuple):
7070

7171
@classmethod
7272
def current(cls):
73-
# TODO(tomhennigan): Remove once a version of JAX is released incl PR#9423.
74-
trace_stack = jax_core.thread_local_state.trace_state.trace_stack.stack
75-
top_type = trace_stack[0].trace_type
76-
level = trace_stack[-1].level
77-
sublevel = jax_core.cur_sublevel()
78-
return JaxTraceLevel(opaque=(top_type, level, sublevel))
73+
if jax.__version_info__ <= (0, 4, 33):
74+
trace_stack = jax_core.thread_local_state.trace_state.trace_stack.stack
75+
top_type = trace_stack[0].trace_type
76+
level = trace_stack[-1].level
77+
sublevel = jax_core.cur_sublevel()
78+
return JaxTraceLevel(opaque=(top_type, level, sublevel))
79+
80+
ts = jax_core.get_opaque_trace_state(convention="haiku")
81+
return JaxTraceLevel(opaque=ts)
7982

8083
frame_ids = it.count()
8184

0 commit comments

Comments
 (0)