Skip to content

Commit d0f17c0

Browse files
committed
Make a direct linearize trace.
This is an alternative to doing JVP followed by partial eval. The linearize trace has two parent traces, one for the primal computation and one for the tangent computation. If we make the tangent trace a DynamicJaxprTrace then we get staged linearization. If we make it the same as the primal trace then we get primal and tangent computations occurring in step (JVP). This is a neat trick enabled by stackless which now lives up to its name. With two parent traces we have a tree of traces not a linked list stack. Primitive ops can have their own linearization rules but as a fallback we can derive a linearization rule for a single op using jvp/partial-eval. For now this is all under a flag, `use_direct_linearize`, but I'm hoping we can make this the default for linearize/grad. It should help with remat and AD through state which are awkward to express via partial eval.
1 parent a582df0 commit d0f17c0

File tree

3 files changed

+123
-3
lines changed

3 files changed

+123
-3
lines changed

jax/_src/config.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,7 @@ def trace_context():
219219
threefry_partitionable.value,
220220
threefry_gpu_kernel_lowering.value,
221221
sharding_in_types.value,
222+
use_direct_linearize.value,
222223
softmax_custom_jvp.value,
223224
enable_memories.value,
224225
disable_jit.value,
@@ -263,6 +264,7 @@ def trace_context():
263264
threefry_partitionable.value,
264265
threefry_gpu_kernel_lowering.value,
265266
sharding_in_types.value,
267+
use_direct_linearize.value,
266268
softmax_custom_jvp.value,
267269
enable_memories.value,
268270
disable_jit.value,
@@ -983,6 +985,7 @@ class _GlobalExtraJitContext(NamedTuple):
983985
threefry_partitionable: bool = False
984986
threefry_gpu_kernel_lowering: bool = False
985987
sharding_in_types: bool = False
988+
use_direct_linearize: bool = False
986989
softmax_custom_jvp: bool = False
987990
xla_profile_version: int = 0
988991
pgle_profiling_runs: int = 0
@@ -1025,6 +1028,7 @@ class _ThreadLocalExtraJitContext(NamedTuple):
10251028
threefry_partitionable: bool | None = None
10261029
threefry_gpu_kernel_lowering: bool | None = None
10271030
sharding_in_types: bool | None = None
1031+
use_direct_linearize: bool | None = None
10281032
softmax_custom_jvp: bool | None = None
10291033
xla_profile_version: int | None = None
10301034
pgle_profiling_runs: int | None = None
@@ -1318,6 +1322,12 @@ def _update_jax_memories_thread_local(val):
13181322
'avals have sharding on them.'),
13191323
include_in_jit_key=True)
13201324

# Flag selecting the experimental single-trace linearization path in
# jax/_src/interpreters/ad.py (see `direct_linearize`). Included in the jit
# key because it changes how traced computations are built.
use_direct_linearize = bool_state(
    name='jax_use_direct_linearize',
    default=False,
    help=('Use direct linearization instead of JVP followed by partial eval'),
    include_in_jit_key=True)
13211331
data_dependent_tracing_fallback = bool_state(
13221332
name='jax_data_dependent_tracing_fallback',
13231333
default=False,

jax/_src/interpreters/ad.py

Lines changed: 98 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,6 @@
3939
as_hashable_function, weakref_lru_cache,
4040
partition_list)
4141

42-
4342
zip = safe_zip
4443
map = safe_map
4544
def identity(x): return x
@@ -106,7 +105,29 @@ def jvp_subtrace_aux(f, store, tag, primals, tangents):
106105
store.store(aux_primals)
107106
return out_primals, out_tangents
108107

def direct_linearize(traceable, *primals, **kwargs):
  """Linearize `traceable` at `primals` with a single trace.

  Alternative to JVP followed by partial eval: the primal computation runs
  under the current (parent) trace while the tangent computation is staged
  out into a fresh DynamicJaxprTrace, producing a jaxpr for the linearized
  function.

  Returns a `(out_primals, out_tangents_pvals, jaxpr, consts)` tuple shaped
  like the no-aux result of `linearize`.
  """
  has_aux = kwargs.pop('has_aux', False)
  assert not has_aux  # aux outputs not supported on this path yet
  with core.take_current_trace() as parent_trace:
    frame = pe.JaxprStackFrame()
    tangent_trace = pe.DynamicJaxprTrace(frame)
    # One symbolic tangent input per primal, using the primal's tangent aval.
    tangents = [tangent_trace.new_arg(get_aval(p).to_tangent_aval()) for p in primals]
    tag = core.TraceTag()
    # The linearize trace has two parent traces: the primal (parent) trace
    # and the staged tangent trace.
    linearize_trace = LinearizeTrace(parent_trace, tangent_trace, tag)
    tracers = [LinearizeTracer(linearize_trace, p, t) for p, t in zip(primals, tangents)]
    with core.set_current_trace(linearize_trace):
      ans = traceable.call_wrapped(*tracers)

  out_primals, out_tangents = unzip2(map(linearize_trace.to_primal_tangent_pair, ans))
  out_tangents = map(tangent_trace.to_jaxpr_tracer, out_tangents)
  jaxpr, consts, attrs_tracked = frame.to_jaxpr(tangent_trace, out_tangents)
  # Every output tangent is "unknown": it is computed by the staged jaxpr.
  out_tangents_pvals = [pe.PartialVal.unknown(core.get_aval(t)) for t in out_tangents]
  del attrs_tracked # TODO: attrs
  return out_primals, out_tangents_pvals, jaxpr, consts
109128
def linearize(traceable, *primals, **kwargs):
129+
if config.use_direct_linearize.value:
130+
return direct_linearize(traceable, *primals, **kwargs)
110131
has_aux = kwargs.pop('has_aux', False)
111132
if not has_aux:
112133
jvpfun = jvp(traceable)
@@ -444,15 +465,89 @@ def _primal_tangent_shapes_match(primal, tangent):
444465
call_param_updaters: dict[core.Primitive, Callable] = {}
445466
call_transpose_param_updaters: dict[core.Primitive, Callable] = {}
446467

# -------------------- Linearize trace --------------------

class LinearizeTrace(Trace):
  """Trace with two parents: a primal trace and a tangent trace.

  Primal work is bound on `parent_trace`; tangent work is bound on
  `tangent_trace` (which may be a staging trace, giving staged
  linearization, or the primal trace itself, giving JVP-like behavior).
  """

  def __init__(self, parent_trace, tangent_trace, tag):
    # `tag` identifies tracers belonging to this particular linearization.
    self.tag = tag
    self.parent_trace = parent_trace
    self.tangent_trace = tangent_trace

  def to_primal_tangent_pair(self, val):
    # Split a value into (primal, tangent); values not traced by this
    # linearization are constants and get a symbolic Zero tangent.
    if isinstance(val, LinearizeTracer) and val._trace.tag is self.tag:
      return (val.primal, val.tangent)
    else:
      tangent_zero = Zero.from_primal_value(val)
      return (val, tangent_zero)

  def process_primitive(self, primitive, args, params):
    primals_in, tangents_in = unzip2(map(self.to_primal_tangent_pair, args))
    # All tangents symbolically zero: the op needs no linearization, just
    # evaluate it on the primal trace.
    if all(type(t) is Zero for t in tangents_in):
      return primitive.bind_with_trace(self.parent_trace, primals_in, params)
    lin = primitive_linearizations.get(primitive)
    if lin is None:
      # No registered rule: derive one for this single op via jvp/partial-eval.
      lin = partial(fallback_linearize_rule, primitive)
    # Primal part of the rule runs under the parent trace...
    with core.set_current_trace(self.parent_trace):
      primal_out, linearized = lin(*primals_in, **params)
    # ...while the linearized (tangent) part runs under the tangent trace.
    with core.set_current_trace(self.tangent_trace):
      tangent_out = linearized(*tangents_in)
    if primitive.multiple_results:
      return [maybe_linearize_tracer(self, x, t) for x, t in zip(primal_out, tangent_out)]
    else:
      return maybe_linearize_tracer(self, primal_out, tangent_out)
def maybe_linearize_tracer(trace, primal, tangent):
  """Pair `primal` with `tangent` in a LinearizeTracer, unless the tangent
  is a symbolic Zero, in which case the bare primal suffices."""
  needs_tracer = type(tangent) is not Zero
  return LinearizeTracer(trace, primal, tangent) if needs_tracer else primal
def fallback_linearize_rule(prim, *args, **kwargs):
  """Derive a linearization rule for `prim` via JVP followed by partial eval.

  Used by LinearizeTrace when `prim` has no entry in
  `primitive_linearizations`. Returns `(out_primals, linearized)` where
  `linearized` maps input tangents to output tangents.
  """
  def call_prim(*args_):
    return prim.bind(*args_, **kwargs)
  # Temporarily disable the flag so `linearize` takes the classic
  # jvp + partial-eval path instead of recursing into direct linearization.
  with config.use_direct_linearize(False):
    out_primals, out_tangents_pvals, jaxpr, consts, *_maybe_aux = linearize(
        lu.wrap_init(call_prim), *args, **kwargs)
  def linearized(*tangents):
    tangents_out = iter(core.eval_jaxpr(jaxpr, consts, *tangents))
    # Known pvals are constants of the linearization; unknown ones are
    # consumed, in order, from the jaxpr's outputs.
    full_out = [pval.get_known() if pval.is_known() else next(tangents_out)
                for pval in out_tangents_pvals]
    assert next(tangents_out, None) is None  # all jaxpr outputs consumed
    return full_out
  return out_primals, linearized
class LinearizeTracer(Tracer):
  """Tracer carrying a (primal, tangent) pair for a LinearizeTrace."""
  __slots__ = ['primal', 'tangent']

  def __init__(self, trace, primal, tangent):
    if config.enable_checks.value:
      _primal_tangent_shapes_match(primal, tangent)
    self._trace = trace
    self.primal = primal
    self.tangent = tangent

  @property
  def aval(self):
    # The tracer's abstract value is that of its primal.
    return get_aval(self.primal)

  def full_lower(self):
    # With a symbolically-zero tangent there is nothing being differentiated,
    # so expose the (possibly further lowerable) primal directly.
    if type(self.tangent) is Zero:
      return core.full_lower(self.primal)
    else:
      return self

  def to_concrete_value(self):
    return core.to_concrete_value(self.primal)
447543

448544
# -------------------- Primitives --------------------
449545

450546
primitive_jvps : dict[core.Primitive, Callable] = {}
451-
452547
primitive_transposes: dict[core.Primitive, Callable] = {}
453548
# transpose rules that internally perform reductions over the given named axes
454549
reducing_transposes: dict[core.Primitive, Callable] = {}
455-
550+
primitive_linearizations: dict[core.Primitive, Callable] = {}
456551

457552
def deflinear(primitive, transpose_rule):
458553
primitive_jvps[primitive] = partial(linear_jvp, primitive)

tests/api_test.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4807,6 +4807,21 @@ def add_one_and_dupe(x: int) -> tuple[int, int]:
48074807
jit_add_one_dupe = jax.jit(add_one_and_dupe, inline=True)
48084808
jax.eval_shape(jit_add_one_dupe, 0) # don't crash
48094809

def test_use_direct_linearize(self):
  """grad results must be identical with and without direct linearization."""

  def sin_of_sin(x):
    return jnp.sin(jnp.sin(x))

  def grad_with_flag(flag_value):
    with config.use_direct_linearize(flag_value):
      return jax.grad(sin_of_sin)(1.0)

  self.assertEqual(grad_with_flag(False), grad_with_flag(True))
48104825

48114826
class RematTest(jtu.JaxTestCase):
48124827

0 commit comments

Comments
 (0)