The quality of reverse mode AD depends on how the input product type is associated, i.e. differentiating w.r.t. `x : X × (Y × Z)` produces a better result than w.r.t. `x : (X × Y) × Z`.
For example, this produces the right derivative
```lean
HasRevFDerivUpdate Float (fun (x : Float^[10] × (Float^[10] × Float)) => x.2.1[i]) _
```
but
```lean
HasRevFDerivUpdate Float (fun (x : (Float^[10] × Float^[10]) × Float) => x.1.2[i]) _
```
will produce code containing `setElem 0 i dy True.intro`, which is bad.
Either improve domain projection or change how `comp_rule` works for `HasRevFDerivUpdate`. Right now it is
```lean
theorem comp_rule (g : X → Y) (f : Y → Z) {g' f'}
    (hg : HasRevFDerivUpdate K g g') (hf : HasRevFDeriv K f f') :
    HasRevFDerivUpdate K
      (fun x => f (g x))
      (fun x =>
        let' (y, dg') := g' x
        let' (z, df') := f' y
        (z, fun dz dx =>
          let dy := df' dz
          let dx := dg' dy dx
          dx)) := by ...
```
which is bad if `f` is a projection! Having some special rule for the case when `f` is a projection would be nice! This should generalize nicely to custom structures.
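For the sake of discussion, a projection-specific rule might look something like the sketch below. This is only a rough idea, not an actual SciLean rule: the name `fst_rule` is hypothetical, and I am assuming the pullback of `Prod.fst` can be expressed by feeding a cotangent padded with `0` into `dg'`.

```lean
-- Hypothetical sketch: a dedicated rule for `f = Prod.fst` that bypasses
-- `comp_rule`, so the composed pullback never materializes a full cotangent
-- for the discarded component (avoiding code like `setElem 0 i dy ...`).
theorem fst_rule (g : X → Y × Z) {g'}
    (hg : HasRevFDerivUpdate K g g') :
    HasRevFDerivUpdate K
      (fun x => (g x).1)
      (fun x =>
        let' (yz, dg') := g' x
        (yz.1, fun dy dx =>
          -- only the `Y` part of the incoming cotangent is nonzero
          dg' (dy, 0) dx)) := by sorry
```

A similar `snd_rule` would cover `Prod.snd`, and the same pattern could be generated per field for custom structures.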
Domain projection is still good to have in order to reduce the growing context introduced by `let_rule`.