2222import jax ._src .core as jcore
2323import jax .api_util as jau
2424import jax .extend .linear_util as lu
25- import jax .numpy as jnp
2625import numpy as np
2726from jax ._src import dtypes
27+
28+ # TODO: maybe use a custom definition of `tree_broadcast` and
29+ # improve the error message if it cannot be broadcast
30+ from jax ._src .custom_transpose import tree_broadcast
2831from jax .interpreters import ad
2932from jax .interpreters import partial_eval as pe
3033
@@ -128,6 +131,7 @@ class Concat:
128131 axis : int = 0
129132
def state(self, n: int, a: jax.ShapeDtypeStruct) -> jax.Array:
    """Allocate the zero-initialized accumulation buffer for ``n`` slices.

    A new axis of length ``n`` is inserted at position ``self.axis`` into the
    shape of ``a``, and a zeros array with ``a``'s dtype is returned.

    Args:
        n: Number of slices the buffer will hold; must be positive.
        a: Shape/dtype template (``jax.ShapeDtypeStruct``) for one slice.

    Returns:
        A ``jax.Array`` of zeros whose shape is ``a.shape`` with ``n``
        inserted at ``self.axis``.

    Raises:
        ValueError: If ``n`` is not positive.
    """
    # `assert` is stripped under `python -O`, so validate explicitly.
    if n <= 0:
        raise ValueError(f"expected a positive slice count, got n={n}")
    shape = a.shape[: self.axis] + (n,) + a.shape[self.axis :]
    return jax.numpy.zeros(shape, dtype=a.dtype)
133137
@@ -218,14 +222,20 @@ def treduce(fun, xs, operation=(Concat(), Add())):
218222
219223 @functools .wraps (fun )
220224 def wrap (i ):
221- e = jax .tree .map (lambda x : jnp .take (x , i , axis = axis ), xs )
225+ e = jax .tree .map (lambda x : jax . numpy .take (x , i , axis = axis ), xs )
222226 return fun (e )
223227
224228 return treduce_i (
225229 wrap , first_batch_shape [axis ], schedule = schedule , operation = operation
226230 )
227231
228232
def copy_if_scalar(x: jax.Array) -> jax.Array:
    """Return a fresh copy of ``x`` when it is 0-d; otherwise return ``x`` as-is.

    Non-scalar arrays pass through unchanged (same object); only rank-0
    arrays are materialized into a new array via an explicit copy.
    """
    # Guard clause: anything with at least one axis is returned untouched.
    if x.ndim != 0:
        return x
    return jax.numpy.array(x, copy=True)
237+
238+
229239def treduce_i (
230240 fun : Callable [[int ], Y ], length : int , schedule : BaseSchedule , operation = default_op
231241) -> Y :
@@ -264,25 +274,19 @@ def treduce_i(fun, length, operation):
264274 structure as ``Y``.
265275 """
266276 with log_elapsed_time ("jaxpr/first_loop_tracing" ), yield_scope ():
267- body_args = jcore .ShapedArray ((), dtype = jnp .int32 )
268- vmapped_jaxpr , loop_out_shapes = jax .make_jaxpr (fun , return_shape = True )(
269- body_args
270- )
271-
272- # TODO: maybe use custom definition of `tree_broadcast` and
273- # improve error message if it cannot be broadcasted
274- from jax ._src .custom_transpose import tree_broadcast
277+ body_args = jcore .ShapedArray ((), dtype = jax .numpy .int32 )
278+ body_jaxpr , loop_out_shapes = jax .make_jaxpr (fun , return_shape = True )(body_args )
275279
276280 operation = tree_broadcast (jax .tree_util .tree_structure (loop_out_shapes ), operation )
277281
278282 def state (op : Op , a ):
279- return op .state (length , a )
283+ return copy_if_scalar ( op .state (length , a ) )
280284
281285 loop_state = jax .tree_util .tree_map (state , operation , loop_out_shapes )
282286
283287 def _fun (mubatch_idx , loop_state ):
284288 def update (op : Op , state , update ):
285- return op .update (state , update , mubatch_idx )
289+ return copy_if_scalar ( op .update (state , update , mubatch_idx ) )
286290
287291 return (
288292 mubatch_idx + 1 ,
@@ -293,8 +297,8 @@ def update(op: Op, state, update):
293297 jax .tree .unflatten (
294298 jax .tree .structure (loop_out_shapes ),
295299 jcore .eval_jaxpr (
296- vmapped_jaxpr .jaxpr ,
297- vmapped_jaxpr .consts ,
300+ body_jaxpr .jaxpr ,
301+ body_jaxpr .consts ,
298302 mubatch_idx ,
299303 propagate_source_info = False ,
300304 ),
@@ -303,10 +307,10 @@ def update(op: Op, state, update):
303307 )
304308
305309 debug_info = jau .debug_info (treduce_i .__name__ , fun , (body_args ,), {})
306- wrapped_vmapped_fun = lu .wrap_init (_fun , debug_info = debug_info )
310+ wrapped_body_fun = lu .wrap_init (_fun , debug_info = debug_info )
307311 with log_elapsed_time ("jaxpr/second_loop_tracing" ), yield_scope ():
308312 loop_output = pscan_wrapped (
309- wrapped_vmapped_fun , loop_state , length = length , schedule = schedule
313+ wrapped_body_fun , loop_state , length = length , schedule = schedule
310314 )
311315
312316 return loop_output
0 commit comments