Description
This is probably hard and potentially even impossible without forgoing horde-ad's relatively simple semantics based on generalized dual numbers. E.g., maybe this can only be done in the far future, when horde-ad is somehow fused with (or into) CHAD. However, it's worth exploring whether the addition turns out to be easier than expected.
Below are snippets from a chat with @tomsmeding, who has ideas about how this could be solved, based on his experience with adding operations to CHAD.
Tom's brainstorming, first approach:
with a mix of these methods I now suspect it's possible to allow mapAccum to have an open combination function
the idea would be: use the same technique as this old build1 code, but cleverer: in the mapAccum reverse pass, run one function invocation on the main Delta structure (thus extending it with new nodes), then (important new bit!) backpropagate (i.e. evaluate) until the highest ID on the tape before this mapAccum-reverse-pass. Then run the next function invocation, then this partial backpropagation again. Etc. until you're done
instead of expanding the full tiny-Delta graph of the mapAccum at once and then evaluating it, you expand it bit-by-bit and evaluate it bit-by-bit
the old build code expanded the graph all at once
your current mapAccum code does the bit-by-bit thing, but on separate Delta graphs, meaning that no contributions can be recorded to "free variables" of the mapAccum operation
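To make the first approach more concrete, here is a minimal toy sketch of "expand one invocation, backpropagate down to the watermark, repeat" on a single shared tape. Nothing in it is horde-ad's actual Delta type or API: `Node`, `Tape`, `addNode`, `seed`, `backpropAbove`, `stepDelta`, `reverseStep` and `reversePassShared` are all hypothetical names, the tape holds only scalars, and the combination function is fixed to `acc' = w * acc + x`, `y = acc * x`, with `w` as the free variable whose node lives below the watermark.

```haskell
import qualified Data.IntMap.Strict as IM

-- Toy Delta nodes (hypothetical, scalar-only); children always have smaller IDs.
data Node = Leaf | Scale Double Int | Add Int Int

-- The shared tape: node definitions, accumulated cotangents, next fresh ID.
data Tape = Tape
  { nodes :: IM.IntMap Node, cots :: IM.IntMap Double, nextId :: Int }

addNode :: Node -> Tape -> (Int, Tape)
addNode n t = (i, t { nodes = IM.insert i n (nodes t), nextId = i + 1 })
  where i = nextId t

-- Add a cotangent contribution to node i.
seed :: Int -> Double -> Tape -> Tape
seed i c t = t { cots = IM.insertWith (+) i c (cots t) }

-- Partial backpropagation: evaluate and drop every node with ID > mark,
-- highest ID first. Contributions to nodes <= mark stay recorded on the tape.
backpropAbove :: Int -> Tape -> Tape
backpropAbove mark t = case IM.lookupMax (nodes t) of
  Just (i, n) | i > mark ->
    let c = IM.findWithDefault 0 i (cots t)
        rest = t { nodes = IM.delete i (nodes t) }
        done = rest { cots = IM.delete i (cots rest) }
    in backpropAbove mark $ case n of
         Leaf -> rest                       -- keep the cotangent for the caller
         Scale k j -> seed j (k * c) done
         Add j k -> seed k c (seed j c done)
  _ -> t

-- One invocation of the open combination function, extending the shared tape.
-- d acc' = acc * dw + w * dacc + dx;  dy = x * dacc + acc * dx.
stepDelta :: Double -> Int -> Double -> Int -> Double -> Int -> Tape
          -> ((Int, Int), Tape)
stepDelta w wId acc accId x xId t0 =
  let (a, t1) = addNode (Scale acc wId) t0
      (b, t2) = addNode (Scale w accId) t1
      (c, t3) = addNode (Add a b) t2
      (accOut, t4) = addNode (Add c xId) t3
      (d, t5) = addNode (Scale x accId) t4
      (e, t6) = addNode (Scale acc xId) t5
      (yOut, t7) = addNode (Add d e) t6
  in ((accOut, yOut), t7)

-- Expand one invocation, then immediately backpropagate down to the watermark.
reverseStep :: Double -> Int -> Int -> (Tape, Double, [Double])
            -> (Double, Double, Double) -> (Tape, Double, [Double])
reverseStep w wId mark (t0, daccOut, dxs) (acc, x, dy) =
  let (accId, t1) = addNode Leaf t0
      (xId, t2) = addNode Leaf t1
      ((accOut, yOut), t3) = stepDelta w wId acc accId x xId t2
      t4 = backpropAbove mark (seed yOut dy (seed accOut daccOut t3))
      dacc = IM.findWithDefault 0 accId (cots t4)
      dx = IM.findWithDefault 0 xId (cots t4)
      t5 = t4 { cots = IM.delete accId (IM.delete xId (cots t4)) }
  in (t5, dacc, dx : dxs)

-- Reverse pass over stored accumulators and inputs; returns (dw, dacc0, dxs).
reversePassShared :: Double -> [Double] -> [Double] -> Double -> [Double]
                  -> (Double, Double, [Double])
reversePassShared w accs xs daccN dys = (dw, dacc0, dxs)
  where
    -- Node 0 stands in for the free variable on the enclosing program's tape.
    (wId, tape0) = addNode Leaf (Tape IM.empty IM.empty 0)
    mark = nextId tape0 - 1
    (tapeN, dacc0, dxs) =
      foldr (\inp st -> reverseStep w wId mark st inp)
            (tape0, daccN, []) (zip3 accs xs dys)
    dw = IM.findWithDefault 0 wId (cots tapeN)
```

For example, `reversePassShared 2 [1, 5] [3, 4] 1 [0, 0]` runs two invocations right to left. After each one the nodes above the watermark are gone, while the contribution to `w` has accumulated on the node below it, which is the property the separate-graph variant lacks.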
Tom's brainstorming, second approach:
it's not obvious from the syntax maybe, but the fact that you can symbolically evaluate the vjp of the combination function to an AST with a separated-out Delta does encode the fact that the Deltas are equally structured
if the combination function is allowed to be open, then I think you do still need to keep its little Delta graph separate from the Delta graph of the full program
if you don't have a barrier between those two and the nodes just end up in the big nMap, you have no choice but to unroll
so you'd have to somehow symbolically differentiate the open combination function into an open artifact
and then embed that open artifact in the AST that comes out of differentiation
Delta evaluation would then receive the DMapAccumLDer with that open artifact inside, and would have to turn that into a mapAccumR that does roughly what the current evaluation of DMapAccumLDer does, but also collecting the contributions for the free variables of that open artifact
said collection needs to be done in a bulk way, perhaps the free-variable cotangent contributions need to go into the array output channel of the mapAccumR, and then be summed together afterwards with a bulk (cotangent) sum
those summed free variable cotangents would then, finally, be added to the dMap in the evalstate
if all of this works, then the result may well be something that restages properly
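The bulk-collection part of the second approach can be sketched on plain lists with `Data.List.mapAccumR`. This is only an illustration under strong assumptions: scalars instead of tensors, a hand-written vjp instead of a symbolically derived open artifact, and no `DMapAccumLDer` or evaluation state at all. The names `step`, `stepVjp`, `forward` and `reversePass` are hypothetical, and `w` plays the role of the free variable of the combination function.

```haskell
import Data.List (mapAccumL, mapAccumR)

-- Hypothetical open combination function; w is its free variable.
step :: Double -> Double -> Double -> (Double, Double)
step w acc x = (w * acc + x, acc * x)

-- Hand-written vjp of step, standing in for the "open artifact".
-- Given cotangents of (acc', y), returns cotangents (dw, dacc, dx).
stepVjp :: Double -> Double -> Double -> (Double, Double)
        -> (Double, Double, Double)
stepVjp w acc x (daccOut, dy) =
  (daccOut * acc, daccOut * w + dy * x, daccOut + dy * acc)

-- Forward pass; also stores the accumulator seen by each step.
forward :: Double -> Double -> [Double] -> (Double, [Double], [Double])
forward w acc0 xs = (accN, ys, accs)
  where
    (accN, pairs) = mapAccumL go acc0 xs
    go acc x = let (acc', y) = step w acc x in (acc', (y, acc))
    (ys, accs) = unzip pairs

-- Reverse pass as a mapAccumR: the accumulator carries dacc, and each step
-- emits (dw_i, dx_i) on the array output channel. The per-step dw_i are then
-- summed in bulk to give the cotangent of the free variable.
reversePass :: Double -> [Double] -> [Double] -> Double -> [Double]
            -> (Double, Double, [Double])
reversePass w accs xs daccN dys = (sum dws, dacc0, dxs)
  where
    (dacc0, outs) = mapAccumR go daccN (zip3 accs xs dys)
    go daccOut (acc, x, dy) =
      let (dw, dacc, dx) = stepVjp w acc x (daccOut, dy)
      in (dacc, (dw, dx))
    (dws, dxs) = unzip outs
```

With `w = 2`, `acc0 = 1` and `xs = [3, 4]`, the final accumulator is `w^2 * acc0 + w * x1 + x2`, so seeding only its cotangent with 1 should yield `dw = 2 * w * acc0 + x1 = 7`. In the actual proposal the summed value would then be added to the dMap entry of the free variable rather than returned.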