1+ module Data.Autodiff.Ops
2+
3+ import Data.Tensor
4+ import Data.Container.Additive
5+ import Data.Para
6+ import Control.Monad.Distribution
7+ import Control.Monad.Identity
8+ import Data.ComMonoid
9+
10+ import Misc
11+
12+ %hide Data . Container . Base . Morphism . Definition . DependentLenses . (=%> )
13+ %hide Syntax . WithProof . prefix. (@@)
14+
15+ -- This is here and not in Container.Additive because it uses `Tensor`
16+ public export
17+ Simplex : Nat -> AddCont
18+ Simplex n = MkAddCont $ (_ : Dist n) !> (Tensor [n] Double)
19+
20+ public export
21+ MulParametric : {a : Type } -> Num a => ParaAddDLens (Const a) (Const a)
22+ MulParametric = binaryOpToPara {p= Const a} Mul
23+
24+ public export
25+ AddParametric : {a : Type } -> Num a => ParaAddDLens (Const a) (Const a)
26+ AddParametric = binaryOpToPara {p= Const a} Sum
27+
28+ public export
29+ AffineParametric : {a : Type } -> Num a => ParaAddDLens (Const a) (Const a)
30+ AffineParametric = composePara MulParametric AddParametric
31+
32+ public export
33+ LeakyReLU : {a : Type } -> Num a => Ord a => FromDouble a =>
34+ {default 0.01 alpha : a} ->
35+ ParaAddDLens (Const a) (Const a)
36+ LeakyReLU = trivialParam (!%+ \ x =>
37+ (if x > 0 then x else alpha * x ** \ x' => if x > 0 then x' else alpha))
38+
39+ public export
40+ LeakyReLUTensor : {a : Type } -> {n : Nat } -> Num a => Ord a => FromDouble a =>
41+ {default 0.01 alpha : a} ->
42+ ParaAddDLens (Const (Tensor [n] a)) (Const (Tensor [n] a))
43+ LeakyReLUTensor = trivialParam (!%+ \ x =>
44+ (x <&> (\ xx => if xx > 0 then xx else alpha * xx) ** \ dy =>
45+ (\ (d, xx) => if xx > 0 then d else alpha * d) <$> liftA2Tensor dy x))
46+
47+ public export
48+ parallelTensor2 : {a, b: Type } -> Num a => Num b =>
49+ ParaAddDLens (Const a) (Const b) ->
50+ ParaAddDLens (Const (Tensor [2] a)) (Const (Tensor [2] b))
51+ parallelTensor2 (MkPara pCont f) = MkPara
52+ (pCont >< pCont)
53+ (!%+ \ (x, (p, q)) =>
54+ let (b1 ** kf) = (%! ) f (x @@ [0 ], p)
55+ (b2 ** kg) = (%! ) f (x @@ [1 ], q)
56+ in (> # [b1, b2] ** \ bs' =>
57+ let (x1', p') = kf (bs' @@ [0 ])
58+ (x2', q') = kg (bs' @@ [1 ])
59+ in (> # [x1', x2'], (p', q'))))
60+
61+ public export
62+ parallelTensor3 : {a, b : Type } -> Num a => Num b =>
63+ ParaAddDLens (Const a) (Const b) ->
64+ ParaAddDLens (Const (Tensor [3] a)) (Const (Tensor [3] b))
65+ parallelTensor3 (MkPara pCont f) = MkPara
66+ (pCont >< pCont >< pCont)
67+ (!%+ \ (x, (p, q, r)) =>
68+ let (b1 ** kf) = (%! ) f (x @@ [0 ], p)
69+ (b2 ** kg) = (%! ) f (x @@ [1 ], q)
70+ (b3 ** kh) = (%! ) f (x @@ [2 ], r)
71+ in (> # [b1, b2, b3] ** \ bs' =>
72+ let (x1', p') = kf (bs' @@ [0 ])
73+ (x2', q') = kg (bs' @@ [1 ])
74+ (x3', r') = kh (bs' @@ [2 ])
75+ in (> # [x1', x2', x3'], (p', (q', r')))))
76+
77+ ||| Produces a parametric map that produces `n` copies of the output, instead
78+ ||| of one, by using `n` different parameters
79+ public export
80+ sameFromTensor2 : {a, b : Type } -> Num a => Num b =>
81+ ParaAddDLens (Const a) (Const b) ->
82+ ParaAddDLens (Const (Tensor [1] a)) (Const (Tensor [2] b))
83+ sameFromTensor2 (MkPara pCont f) = MkPara
84+ (pCont >< pCont)
85+ (!%+ \ (x, (p, q)) =>
86+ let val = x @@ [0 ]
87+ (b1 ** kf) = (%! ) f (val, p)
88+ (b2 ** kg) = (%! ) f (val, q)
89+ in (> # [b1, b2] ** \ bs' =>
90+ let (x1', p') = kf (bs' @@ [0 ])
91+ (x2', q') = kg (bs' @@ [1 ])
92+ in (> # [x1' + x2'], (p', q'))))
93+
94+ public export
95+ sameFromTensor3 : {a, b : Type } -> Num a => Num b =>
96+ ParaAddDLens (Const a) (Const b) ->
97+ ParaAddDLens (Const (Tensor [1] a)) (Const (Tensor [3] b))
98+ sameFromTensor3 (MkPara pCont f) = MkPara
99+ (pCont >< pCont >< pCont)
100+ (!%+ \ (x, (p, q, r)) =>
101+ let val = x @@ [0 ]
102+ (b1 ** kf) = (%! ) f (val, p)
103+ (b2 ** kg) = (%! ) f (val, q)
104+ (b3 ** kh) = (%! ) f (val, r)
105+ in (> # [b1, b2, b3] ** \ bs' =>
106+ let (x1', p') = kf (bs' @@ [0 ])
107+ (x2', q') = kg (bs' @@ [1 ])
108+ (x3', r') = kh (bs' @@ [2 ])
109+ in (> # [x1' + x2' + x3'], (p', q', r'))))
110+
111+ ||| Produces a parametric map that produces `n` copies of the output, instead
112+ ||| of one, by using `n` different parameters
113+ public export
114+ sameFromTensor : {a, b : Type } -> Num a => Num b => {n : Nat } ->
115+ ParaAddDLens (Const a) (Const b) ->
116+ ParaAddDLens (Const (Tensor [1] a)) (Const (Tensor [n] b))
117+ sameFromTensor (MkPara pCont f) = MkPara
118+ (VectAddCont $ replicate n pCont)
119+ (!%+ \ (x, psShapes) =>
120+ let val = x @@ [0 ]
121+ outAndBw = runIdentity $ dTraverse
122+ (\ p => Id $ (%! ) f (val, p))
123+ (allToVect psShapes)
124+ out = mapPropertyRelevant (\ _ , (y ** bw) => y) outAndBw
125+ bw = mapPropertyRelevant (\ _ , (y ** bw) => bw) outAndBw
126+ in (> # constantToVect out ** \ bs' =>
127+ let tt = bw
128+ in ? bww))
129+
130+ public export
131+ sameFrom : {a : AddCont} -> ParaAddDLens a b ->
132+ ParaAddDLens a c ->
133+ ParaAddDLens a (b >< c)
134+ sameFrom (MkPara p f) (MkPara q g) = MkPara
135+ (p >< q)
136+ (!%+ \ (x, (p, q)) =>
137+ let (b ** kf) = (%! ) f (x, p)
138+ (c ** kg) = (%! ) g (x, q)
139+ in ((b, c) ** \ (b', c') =>
140+ let (x'1 , p') = kf b'
141+ (x'2 , q') = kg c'
142+ in (a. Plus x x'1 x'2 , (p', q'))))
143+
144+ public export
145+ sameFromConst : {a, b, c : Type } -> Num a => Num b => Num c =>
146+ ParaAddDLens (Const a) (Const b) ->
147+ ParaAddDLens (Const a) (Const c) ->
148+ ParaAddDLens (Const a) (Const (b, c))
149+ sameFromConst (MkPara p f) (MkPara q g) = MkPara
150+ (p >< q)
151+ (!%+ \ (x, (p, q)) =>
152+ let (b ** kf) = (%! ) f (x, p)
153+ (c ** kg) = (%! ) g (x, q)
154+ in ((b, c) ** \ (b', c') =>
155+ let (x'1 , p') = kf b'
156+ (x'2 , q') = kg c'
157+ in (x'1 + x'2 , (p', q'))))
158+
159+ public export
160+ sameFrom3 : {a : AddCont} -> ParaAddDLens a b ->
161+ ParaAddDLens a c ->
162+ ParaAddDLens a d ->
163+ ParaAddDLens a (b >< c >< d)
164+ sameFrom3 (MkPara p f) (MkPara q g) (MkPara r h) = MkPara
165+ (p >< q >< r)
166+ (!%+ \ (x, (p, q, r)) =>
167+ let (b ** kf) = (%! ) f (x, p)
168+ (c ** kg) = (%! ) g (x, q)
169+ (d ** kh) = (%! ) h (x, r)
170+ in ((b, c, d) ** \ (b', c', d') =>
171+ let (x'1 , p') = kf b'
172+ (x'2 , q') = kg c'
173+ (x'3 , r') = kh d'
174+ in (a. Plus x (a. Plus x x'1 x'2 ) x'3 , (p', q', r'))))
175+
176+ public export
177+ sameFromConst3 : {a, b, c, d : Type } -> Num a => Num b => Num c => Num d =>
178+ ParaAddDLens (Const a) (Const b) ->
179+ ParaAddDLens (Const a) (Const c) ->
180+ ParaAddDLens (Const a) (Const d) ->
181+ ParaAddDLens (Const a) (Const (b, c, d))
182+ sameFromConst3 (MkPara p f) (MkPara q g) (MkPara r h) = MkPara
183+ (p >< q >< r)
184+ (!%+ \ (x, (p, q, r)) =>
185+ let (b ** kf) = (%! ) f (x, p)
186+ (c ** kg) = (%! ) g (x, q)
187+ (d ** kh) = (%! ) h (x, r)
188+ in ((b, c, d) ** \ (b', c', d') =>
189+ let (x'1 , p') = kf b'
190+ (x'2 , q') = kg c'
191+ (x'3 , r') = kh d'
192+ in (x'1 + x'2 + x'3 , (p', q', r'))))
193+
194+ -- ||| N-ary probability intro and elimination
195+ -- NProbIntro : {ef : EffectType} ->
196+ -- {i : Nat} -> IsSucc i =>
197+ -- {ts : Vect i Ty} ->
198+ -- All (\t => Term ef ctx t) ts -> -- for now all the components need to run with the same effect
199+ -- -- Treating probability as logits
200+ -- Vect i (Term ef ctx Number) ->
201+ -- Term Prob ctx (NProb ts)
202+ -- NProbElim : {ef : EffectType} ->
203+ -- {i : Nat} -> IsSucc i =>
204+ -- {ts : Vect i Ty} ->
205+ -- Term ef ctx (NProb ts) ->
206+ -- All (\e => Term ef (e :: ctx) c) ts -> Term Prob ctx c
0 commit comments