
Commit eab9026

Merge pull request jax-ml#25004 from jax-ml:linearize-trace
PiperOrigin-RevId: 698438212
2 parents: 8d84f28 + d0f17c0

3 files changed: +123 −3 lines

jax/_src/config.py
Lines changed: 10 additions & 0 deletions

@@ -219,6 +219,7 @@ def trace_context():
           threefry_partitionable.value,
           threefry_gpu_kernel_lowering.value,
           sharding_in_types.value,
+          use_direct_linearize.value,
           softmax_custom_jvp.value,
           enable_memories.value,
           disable_jit.value,
@@ -263,6 +264,7 @@ def trace_context():
           threefry_partitionable.value,
           threefry_gpu_kernel_lowering.value,
           sharding_in_types.value,
+          use_direct_linearize.value,
           softmax_custom_jvp.value,
           enable_memories.value,
           disable_jit.value,
@@ -983,6 +985,7 @@ class _GlobalExtraJitContext(NamedTuple):
   threefry_partitionable: bool = False
   threefry_gpu_kernel_lowering: bool = False
   sharding_in_types: bool = False
+  use_direct_linearize: bool = False
   softmax_custom_jvp: bool = False
   xla_profile_version: int = 0
   pgle_profiling_runs: int = 0
@@ -1025,6 +1028,7 @@ class _ThreadLocalExtraJitContext(NamedTuple):
   threefry_partitionable: bool | None = None
   threefry_gpu_kernel_lowering: bool | None = None
   sharding_in_types: bool | None = None
+  use_direct_linearize: bool | None = None
   softmax_custom_jvp: bool | None = None
   xla_profile_version: int | None = None
   pgle_profiling_runs: int | None = None
@@ -1318,6 +1322,12 @@ def _update_jax_memories_thread_local(val):
          'avals have sharding on them.'),
     include_in_jit_key=True)

+use_direct_linearize = bool_state(
+    name='jax_use_direct_linearize',
+    default=False,
+    help=('Use direct linearization instead of JVP followed by partial eval'),
+    include_in_jit_key=True)
+
 data_dependent_tracing_fallback = bool_state(
     name='jax_data_dependent_tracing_fallback',
     default=False,
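
The flag is an ordinary bool_state, so it can be flipped globally through the public config API or scoped with the generated context manager (the new test below uses the latter). A minimal sketch, assuming this commit is applied:

  import jax
  import jax.numpy as jnp
  from jax._src import config  # internal module, as imported in the test below

  # Global toggle via the public config API.
  jax.config.update('jax_use_direct_linearize', True)

  # Scoped toggle via the context manager that bool_state generates.
  with config.use_direct_linearize(True):
    y, f_lin = jax.linearize(jnp.sin, 1.0)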

jax/_src/interpreters/ad.py
Lines changed: 98 additions & 3 deletions

@@ -39,7 +39,6 @@
     as_hashable_function, weakref_lru_cache,
     partition_list)

-
 zip = safe_zip
 map = safe_map
 def identity(x): return x
@@ -106,7 +105,29 @@ def jvp_subtrace_aux(f, store, tag, primals, tangents):
   store.store(aux_primals)
   return out_primals, out_tangents

+def direct_linearize(traceable, *primals, **kwargs):
+  has_aux = kwargs.pop('has_aux', False)
+  assert not has_aux
+  with core.take_current_trace() as parent_trace:
+    frame = pe.JaxprStackFrame()
+    tangent_trace = pe.DynamicJaxprTrace(frame)
+    tangents = [tangent_trace.new_arg(get_aval(p).to_tangent_aval()) for p in primals]
+    tag = core.TraceTag()
+    linearize_trace = LinearizeTrace(parent_trace, tangent_trace, tag)
+    tracers = [LinearizeTracer(linearize_trace, p, t) for p, t in zip(primals, tangents)]
+    with core.set_current_trace(linearize_trace):
+      ans = traceable.call_wrapped(*tracers)
+
+  out_primals, out_tangents = unzip2(map(linearize_trace.to_primal_tangent_pair, ans))
+  out_tangents = map(tangent_trace.to_jaxpr_tracer, out_tangents)
+  jaxpr, consts, attrs_tracked = frame.to_jaxpr(tangent_trace, out_tangents)
+  out_tangents_pvals = [pe.PartialVal.unknown(core.get_aval(t)) for t in out_tangents]
+  del attrs_tracked  # TODO: attrs
+  return out_primals, out_tangents_pvals, jaxpr, consts
+
 def linearize(traceable, *primals, **kwargs):
+  if config.use_direct_linearize.value:
+    return direct_linearize(traceable, *primals, **kwargs)
   has_aux = kwargs.pop('has_aux', False)
   if not has_aux:
     jvpfun = jvp(traceable)
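
At the API level nothing changes: linearize dispatches to direct_linearize only when the flag is set, and the two paths must produce the same values. A quick sketch of that invariant (mirroring the new test in api_test.py below; the function f is illustrative):

  import jax
  import jax.numpy as jnp

  def f(x):
    return jnp.sin(jnp.sin(x))

  # Default path: JVP followed by partial evaluation.
  y_ref, f_lin_ref = jax.linearize(f, 1.0)

  # New path: a single trace that builds the tangent jaxpr directly.
  jax.config.update('jax_use_direct_linearize', True)
  y_new, f_lin_new = jax.linearize(f, 1.0)

  assert y_ref == y_new
  assert f_lin_ref(1.0) == f_lin_new(1.0)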
@@ -444,15 +465,89 @@ def _primal_tangent_shapes_match(primal, tangent):
 call_param_updaters: dict[core.Primitive, Callable] = {}
 call_transpose_param_updaters: dict[core.Primitive, Callable] = {}

+# -------------------- Linearize trace --------------------
+
+class LinearizeTrace(Trace):
+
+  def __init__(self, parent_trace, tangent_trace, tag):
+    self.tag = tag
+    self.parent_trace = parent_trace
+    self.tangent_trace = tangent_trace
+
+  def to_primal_tangent_pair(self, val):
+    if isinstance(val, LinearizeTracer) and val._trace.tag is self.tag:
+      return (val.primal, val.tangent)
+    else:
+      tangent_zero = Zero.from_primal_value(val)
+      return (val, tangent_zero)
+
+  def process_primitive(self, primitive, args, params):
+    primals_in, tangents_in = unzip2(map(self.to_primal_tangent_pair, args))
+    if all(type(t) is Zero for t in tangents_in):
+      return primitive.bind_with_trace(self.parent_trace, primals_in, params)
+    lin = primitive_linearizations.get(primitive)
+    if lin is None:
+      lin = partial(fallback_linearize_rule, primitive)
+    with core.set_current_trace(self.parent_trace):
+      primal_out, linearized = lin(*primals_in, **params)
+    with core.set_current_trace(self.tangent_trace):
+      tangent_out = linearized(*tangents_in)
+    if primitive.multiple_results:
+      return [maybe_linearize_tracer(self, x, t) for x, t in zip(primal_out, tangent_out)]
+    else:
+      return maybe_linearize_tracer(self, primal_out, tangent_out)
+
+def maybe_linearize_tracer(trace, primal, tangent):
+  if type(tangent) is Zero:
+    return primal
+  else:
+    return LinearizeTracer(trace, primal, tangent)
+
+def fallback_linearize_rule(prim, *args, **kwargs):
+  def call_prim(*args_):
+    return prim.bind(*args_, **kwargs)
+  with config.use_direct_linearize(False):
+    out_primals, out_tangents_pvals, jaxpr, consts, *_maybe_aux = linearize(
+        lu.wrap_init(call_prim), *args, **kwargs)
+  def linearized(*tangents):
+    tangents_out = iter(core.eval_jaxpr(jaxpr, consts, *tangents))
+    full_out = [pval.get_known() if pval.is_known() else next(tangents_out)
+                for pval in out_tangents_pvals]
+    assert next(tangents_out, None) is None
+    return full_out
+  return out_primals, linearized
+
+class LinearizeTracer(Tracer):
+  __slots__ = ['primal', 'tangent']
+
+  def __init__(self, trace, primal, tangent):
+    if config.enable_checks.value:
+      _primal_tangent_shapes_match(primal, tangent)
+    self._trace = trace
+    self.primal = primal
+    self.tangent = tangent
+
+  @property
+  def aval(self):
+    return get_aval(self.primal)
+
+  def full_lower(self):
+    if type(self.tangent) is Zero:
+      return core.full_lower(self.primal)
+    else:
+      return self
+
+  def to_concrete_value(self):
+    return core.to_concrete_value(self.primal)
+

 # -------------------- Primitives --------------------

 primitive_jvps : dict[core.Primitive, Callable] = {}
-
 primitive_transposes: dict[core.Primitive, Callable] = {}
 # transpose rules that internally perform reductions over the given named axes
 reducing_transposes: dict[core.Primitive, Callable] = {}
-
+primitive_linearizations: dict[core.Primitive, Callable] = {}

 def deflinear(primitive, transpose_rule):
   primitive_jvps[primitive] = partial(linear_jvp, primitive)
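
process_primitive consults the new primitive_linearizations registry and expects a rule with the signature rule(*primals, **params) -> (primal_out, linearized), with residuals captured in the rule's closure; any primitive without an entry goes through fallback_linearize_rule. This commit registers no rules, but a hypothetical hand-written rule for lax.sin_p could look like this sketch:

  import jax.numpy as jnp
  from jax import lax
  from jax._src.interpreters import ad

  def _sin_linearize_rule(x):  # hypothetical rule, not part of this commit
    # Runs under the parent trace: compute the primal output and the
    # residual needed for the tangent map.
    primal_out = jnp.sin(x)
    cos_x = jnp.cos(x)
    def linearized(t):
      # Runs under the tangent trace: apply the Jacobian to the tangent.
      return cos_x * t
    return primal_out, linearized

  ad.primitive_linearizations[lax.sin_p] = _sin_linearize_rule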

tests/api_test.py
Lines changed: 15 additions & 0 deletions

@@ -4807,6 +4807,21 @@ def add_one_and_dupe(x: int) -> tuple[int, int]:
     jit_add_one_dupe = jax.jit(add_one_and_dupe, inline=True)
     jax.eval_shape(jit_add_one_dupe, 0)  # don't crash

+  def test_use_direct_linearize(self):
+
+    def check_invariant_to_use_direct_linearize(f):
+      with config.use_direct_linearize(False):
+        ans1 = f()
+      with config.use_direct_linearize(True):
+        ans2 = f()
+
+      self.assertEqual(ans1, ans2)
+
+    def sin_of_sin(x):
+      return jnp.sin(jnp.sin(x))
+
+    check_invariant_to_use_direct_linearize(lambda: jax.grad(sin_of_sin)(1.0))
+

 class RematTest(jtu.JaxTestCase):
