Skip to content

Commit 435b832

Browse files
Angelogebseanprime7
authored and committed
Fix handling of literals for loop
- d55f103e9518b6394ae4879c408dd055e44e3e04 by Anxhelo Xhebraj <axhebraj@nvidia.com> Signed-off-by: Anxhelo Xhebraj <axhebraj@nvidia.com> GitOrigin-RevId: d55f103e9518b6394ae4879c408dd055e44e3e04
1 parent cdadca7 commit 435b832

File tree

2 files changed

+60
-17
lines changed

2 files changed

+60
-17
lines changed

src/jaxpp/training.py

Lines changed: 20 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,12 @@
2222
import jax._src.core as jcore
2323
import jax.api_util as jau
2424
import jax.extend.linear_util as lu
25-
import jax.numpy as jnp
2625
import numpy as np
2726
from jax._src import dtypes
27+
28+
# TODO: maybe use custom definition of `tree_broadcast` and
29+
# improve error message if it cannot be broadcasted
30+
from jax._src.custom_transpose import tree_broadcast
2831
from jax.interpreters import ad
2932
from jax.interpreters import partial_eval as pe
3033

@@ -128,6 +131,7 @@ class Concat:
128131
axis: int = 0
129132

130133
def state(self, n: int, a: jax.ShapeDtypeStruct) -> jax.Array:
134+
assert n > 0
131135
shape = a.shape[: self.axis] + (n,) + a.shape[self.axis :]
132136
return jax.numpy.zeros(shape, dtype=a.dtype)
133137

@@ -218,14 +222,20 @@ def treduce(fun, xs, operation=(Concat(), Add())):
218222

219223
@functools.wraps(fun)
220224
def wrap(i):
221-
e = jax.tree.map(lambda x: jnp.take(x, i, axis=axis), xs)
225+
e = jax.tree.map(lambda x: jax.numpy.take(x, i, axis=axis), xs)
222226
return fun(e)
223227

224228
return treduce_i(
225229
wrap, first_batch_shape[axis], schedule=schedule, operation=operation
226230
)
227231

228232

233+
def copy_if_scalar(x: jax.Array) -> jax.Array:
234+
if x.ndim == 0:
235+
return jax.numpy.array(x, copy=True)
236+
return x
237+
238+
229239
def treduce_i(
230240
fun: Callable[[int], Y], length: int, schedule: BaseSchedule, operation=default_op
231241
) -> Y:
@@ -264,25 +274,19 @@ def treduce_i(fun, length, operation):
264274
structure as ``Y``.
265275
"""
266276
with log_elapsed_time("jaxpr/first_loop_tracing"), yield_scope():
267-
body_args = jcore.ShapedArray((), dtype=jnp.int32)
268-
vmapped_jaxpr, loop_out_shapes = jax.make_jaxpr(fun, return_shape=True)(
269-
body_args
270-
)
271-
272-
# TODO: maybe use custom definition of `tree_broadcast` and
273-
# improve error message if it cannot be broadcasted
274-
from jax._src.custom_transpose import tree_broadcast
277+
body_args = jcore.ShapedArray((), dtype=jax.numpy.int32)
278+
body_jaxpr, loop_out_shapes = jax.make_jaxpr(fun, return_shape=True)(body_args)
275279

276280
operation = tree_broadcast(jax.tree_util.tree_structure(loop_out_shapes), operation)
277281

278282
def state(op: Op, a):
279-
return op.state(length, a)
283+
return copy_if_scalar(op.state(length, a))
280284

281285
loop_state = jax.tree_util.tree_map(state, operation, loop_out_shapes)
282286

283287
def _fun(mubatch_idx, loop_state):
284288
def update(op: Op, state, update):
285-
return op.update(state, update, mubatch_idx)
289+
return copy_if_scalar(op.update(state, update, mubatch_idx))
286290

287291
return (
288292
mubatch_idx + 1,
@@ -293,8 +297,8 @@ def update(op: Op, state, update):
293297
jax.tree.unflatten(
294298
jax.tree.structure(loop_out_shapes),
295299
jcore.eval_jaxpr(
296-
vmapped_jaxpr.jaxpr,
297-
vmapped_jaxpr.consts,
300+
body_jaxpr.jaxpr,
301+
body_jaxpr.consts,
298302
mubatch_idx,
299303
propagate_source_info=False,
300304
),
@@ -303,10 +307,10 @@ def update(op: Op, state, update):
303307
)
304308

305309
debug_info = jau.debug_info(treduce_i.__name__, fun, (body_args,), {})
306-
wrapped_vmapped_fun = lu.wrap_init(_fun, debug_info=debug_info)
310+
wrapped_body_fun = lu.wrap_init(_fun, debug_info=debug_info)
307311
with log_elapsed_time("jaxpr/second_loop_tracing"), yield_scope():
308312
loop_output = pscan_wrapped(
309-
wrapped_vmapped_fun, loop_state, length=length, schedule=schedule
313+
wrapped_body_fun, loop_state, length=length, schedule=schedule
310314
)
311315

312316
return loop_output

tests/test_transformations.py

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,14 @@
1111
from jax.scipy.special import logsumexp
1212

1313
from jaxpp import env_vars
14-
from jaxpp.api import BaseSchedule, pipeline_enter_stage, treduce
14+
from jaxpp.api import Add, BaseSchedule, Concat, pipeline_enter_stage, treduce
1515
from jaxpp.core import (
1616
cluster_jaxpr,
17+
common_passes,
18+
fixup_multidefs,
1719
infer_shardings2,
1820
maybe_unroll_loop,
21+
outvar_normalization,
1922
strip_inspect_sharding_eqns,
2023
wrap_into_tasks,
2124
)
@@ -213,6 +216,42 @@ def test_transformations_dont_fail(
213216
)
214217

215218

219+
def test_literal_via_constant_folding():
220+
"""Test if constant folding or other optimizations might introduce literals."""
221+
num_stages = 1
222+
n_mubatches = 2
223+
schedule = Std1F1B(num_stages=1)
224+
(params, X, Y), stage_mesh, replicated_sharding = get_context(
225+
num_stages, n_mubatches
226+
)
227+
228+
def grads_with_constant_output(params, data):
229+
loss, grad = grads(params, data)
230+
identity = jnp.array(1.0)
231+
return (loss, grad, identity)
232+
233+
custom_op = (Concat(), Add, Add)
234+
235+
total_grads_fn = jax.jit(
236+
lambda params, X, Y: treduce(
237+
partial(grads_with_constant_output, params),
238+
(X, Y),
239+
schedule=schedule,
240+
operation=custom_op,
241+
)
242+
)
243+
244+
cjaxpr = total_grads_fn.trace(params, X, Y).jaxpr
245+
with env_vars.jaxpp_conservative_loop_clustering.set(False):
246+
scheduled_jaxpr = get_scheduled_jaxpr(
247+
cjaxpr,
248+
1,
249+
stage_mesh,
250+
replicated_sharding,
251+
skip_propagation=True,
252+
)
253+
254+
216255
def test_skip_propagation_false():
217256
test_equivalence_scheduled(
218257
*(3, 3, 5, Eager1F1B(num_stages=3)), skip_propagation=False

0 commit comments

Comments (0)