|
39 | 39 | as_hashable_function, weakref_lru_cache, |
40 | 40 | partition_list) |
41 | 41 |
|
42 | | - |
43 | 42 | zip = safe_zip |
44 | 43 | map = safe_map |
45 | 44 | def identity(x): return x |
@@ -106,7 +105,29 @@ def jvp_subtrace_aux(f, store, tag, primals, tangents): |
106 | 105 | store.store(aux_primals) |
107 | 106 | return out_primals, out_tangents |
108 | 107 |
|
| 108 | +def direct_linearize(traceable, *primals, **kwargs): |
| 109 | + has_aux = kwargs.pop('has_aux', False) |
| 110 | + assert not has_aux |
| 111 | + with core.take_current_trace() as parent_trace: |
| 112 | + frame = pe.JaxprStackFrame() |
| 113 | + tangent_trace = pe.DynamicJaxprTrace(frame) |
| 114 | + tangents = [tangent_trace.new_arg(get_aval(p).to_tangent_aval()) for p in primals] |
| 115 | + tag = core.TraceTag() |
| 116 | + linearize_trace = LinearizeTrace(parent_trace, tangent_trace, tag) |
| 117 | + tracers = [LinearizeTracer(linearize_trace, p, t) for p, t in zip(primals, tangents)] |
| 118 | + with core.set_current_trace(linearize_trace): |
| 119 | + ans = traceable.call_wrapped(*tracers) |
| 120 | + |
| 121 | + out_primals, out_tangents = unzip2(map(linearize_trace.to_primal_tangent_pair, ans)) |
| 122 | + out_tangents = map(tangent_trace.to_jaxpr_tracer, out_tangents) |
| 123 | + jaxpr, consts, attrs_tracked = frame.to_jaxpr(tangent_trace, out_tangents) |
| 124 | + out_tangents_pvals = [pe.PartialVal.unknown(core.get_aval(t)) for t in out_tangents] |
| 125 | + del attrs_tracked # TODO: attrs |
| 126 | + return out_primals, out_tangents_pvals, jaxpr, consts |
| 127 | + |
109 | 128 | def linearize(traceable, *primals, **kwargs): |
| 129 | + if config.use_direct_linearize.value: |
| 130 | + return direct_linearize(traceable, *primals, **kwargs) |
110 | 131 | has_aux = kwargs.pop('has_aux', False) |
111 | 132 | if not has_aux: |
112 | 133 | jvpfun = jvp(traceable) |
@@ -444,15 +465,89 @@ def _primal_tangent_shapes_match(primal, tangent): |
444 | 465 | call_param_updaters: dict[core.Primitive, Callable] = {} |
445 | 466 | call_transpose_param_updaters: dict[core.Primitive, Callable] = {} |
446 | 467 |
|
| 468 | +# -------------------- Linearize trace -------------------- |
| 469 | + |
| 470 | +class LinearizeTrace(Trace): |
| 471 | + |
| 472 | + def __init__(self, parent_trace, tangent_trace, tag): |
| 473 | + self.tag = tag |
| 474 | + self.parent_trace = parent_trace |
| 475 | + self.tangent_trace = tangent_trace |
| 476 | + |
| 477 | + def to_primal_tangent_pair(self, val): |
| 478 | + if isinstance(val, LinearizeTracer) and val._trace.tag is self.tag: |
| 479 | + return (val.primal, val.tangent) |
| 480 | + else: |
| 481 | + tangent_zero = Zero.from_primal_value(val) |
| 482 | + return (val, tangent_zero) |
| 483 | + |
| 484 | + def process_primitive(self, primitive, args, params): |
| 485 | + primals_in, tangents_in = unzip2(map(self.to_primal_tangent_pair, args)) |
| 486 | + if all(type(t) is Zero for t in tangents_in): |
| 487 | + return primitive.bind_with_trace(self.parent_trace, primals_in, params) |
| 488 | + lin = primitive_linearizations.get(primitive) |
| 489 | + if lin is None: |
| 490 | + lin = partial(fallback_linearize_rule, primitive) |
| 491 | + with core.set_current_trace(self.parent_trace): |
| 492 | + primal_out, linearized = lin(*primals_in, **params) |
| 493 | + with core.set_current_trace(self.tangent_trace): |
| 494 | + tangent_out = linearized(*tangents_in) |
| 495 | + if primitive.multiple_results: |
| 496 | + return [maybe_linearize_tracer(self, x, t) for x, t in zip(primal_out, tangent_out)] |
| 497 | + else: |
| 498 | + return maybe_linearize_tracer(self, primal_out, tangent_out) |
| 499 | + |
| 500 | +def maybe_linearize_tracer(trace, primal, tangent): |
| 501 | + if type(tangent) is Zero: |
| 502 | + return primal |
| 503 | + else: |
| 504 | + return LinearizeTracer(trace, primal, tangent) |
| 505 | + |
| 506 | +def fallback_linearize_rule(prim, *args, **kwargs): |
| 507 | + def call_prim(*args_): |
| 508 | + return prim.bind(*args_, **kwargs) |
| 509 | + with config.use_direct_linearize(False): |
| 510 | + out_primals, out_tangents_pvals, jaxpr, consts, *_maybe_aux = linearize( |
| 511 | + lu.wrap_init(call_prim), *args, **kwargs) |
| 512 | + def linearized(*tangents): |
| 513 | + tangents_out = iter(core.eval_jaxpr(jaxpr, consts, *tangents)) |
| 514 | + full_out = [pval.get_known() if pval.is_known() else next(tangents_out) |
| 515 | + for pval in out_tangents_pvals] |
| 516 | + assert next(tangents_out, None) is None |
| 517 | + return full_out |
| 518 | + return out_primals, linearized |
| 519 | + |
| 520 | +class LinearizeTracer(Tracer): |
| 521 | + __slots__ = ['primal', 'tangent'] |
| 522 | + |
| 523 | + def __init__(self, trace, primal, tangent): |
| 524 | + if config.enable_checks.value: |
| 525 | + _primal_tangent_shapes_match(primal, tangent) |
| 526 | + self._trace = trace |
| 527 | + self.primal = primal |
| 528 | + self.tangent = tangent |
| 529 | + |
| 530 | + @property |
| 531 | + def aval(self): |
| 532 | + return get_aval(self.primal) |
| 533 | + |
| 534 | + def full_lower(self): |
| 535 | + if type(self.tangent) is Zero: |
| 536 | + return core.full_lower(self.primal) |
| 537 | + else: |
| 538 | + return self |
| 539 | + |
| 540 | + def to_concrete_value(self): |
| 541 | + return core.to_concrete_value(self.primal) |
| 542 | + |
447 | 543 |
|
448 | 544 | # -------------------- Primitives -------------------- |
449 | 545 |
|
450 | 546 | primitive_jvps : dict[core.Primitive, Callable] = {} |
451 | | - |
452 | 547 | primitive_transposes: dict[core.Primitive, Callable] = {} |
453 | 548 | # transpose rules that internally perform reductions over the given named axes |
454 | 549 | reducing_transposes: dict[core.Primitive, Callable] = {} |
455 | | - |
| 550 | +primitive_linearizations: dict[core.Primitive, Callable] = {} |
456 | 551 |
|
457 | 552 | def deflinear(primitive, transpose_rule): |
458 | 553 | primitive_jvps[primitive] = partial(linear_jvp, primitive) |
|
0 commit comments