Skip to content

Commit a4c6cbc

Browse files
committed
[new] sampling, monadic additive lenses, para
1 parent d016c6f commit a4c6cbc

File tree

34 files changed

+1731
-422
lines changed

34 files changed

+1731
-422
lines changed

src/Control/Monad/Distribution.idr

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,18 +3,34 @@ module Control.Monad.Distribution
33
import Data.Vect
44
import Data.Vect.Quantifiers
55
import Control.Monad.Identity
6-
import Misc
76
import System.Random
87

9-
||| Convex combination of a finite set of types
8+
import Data.Num
9+
import Misc
10+
11+
||| Convex combination of a finite set of types, a point in a simplex △^(i-1)
12+
||| i=2 -> △¹ -> line segment
13+
||| i=3 -> △² -> triangle
14+
||| ...
15+
||| Probabilities are here represented as logits
1016
||| Since this is used in `Data.Container.Products.ConvexComb`, we cannot use
1117
||| `Tensor` here
1218
public export
1319
data Dist : (i : Nat) -> Type where
1420
||| Probabilities are represented as logits
1521
MkDist : Vect i Double -> Dist i
1622

23+
public export
24+
toVect : Dist i -> Vect i Double
25+
toVect (MkDist xs) = xs
1726

27+
||| Logit representation of the uniform distribution
1828
public export
1929
uniform : {i : Nat} -> (isSucc : IsSucc i) => Dist i
20-
uniform = MkDist (replicate i 1)
30+
uniform = MkDist (replicate i 0)
31+
32+
||| Logit representation of dirac delta
33+
public export
34+
diracDelta : {i : Nat} -> IsSucc i =>
35+
(j : Fin i) -> Dist i
36+
diracDelta @{ItIsSucc {n}} j = MkDist $ insertAt j 0 (replicate n minusInfinity)
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
module Control.Monad.Sample.Definition
2+
3+
import Data.Fin
4+
5+
import Control.Monad.Distribution
6+
7+
||| Interface for sampling from a distribution
8+
||| We require that there is at least one element in the distribution
9+
||| TODO add temperature as a implicit parameter with a defualt value of 1.0
10+
public export
11+
interface Monad m => MonadSample m where
12+
sample : {i : Nat} -> (isSucc : IsSucc i) =>
13+
Dist i -> m (Fin i)
Lines changed: 13 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,13 @@
1-
module Control.Monad.Sample
1+
module Control.Monad.Sample.Instances
22

3+
import Control.Monad.Identity
34
import System.Random
45

5-
import Data.Tensor
6-
6+
import Data.Tensor
77
import Control.Monad.Distribution
8-
import Control.Monad.Identity
8+
import Control.Monad.Sample.Definition
99
import NN.Architectures.Softargmax
1010

11-
import Misc
12-
13-
||| Interface for sampling from a distribution
14-
||| We require that there is at least one element in the distribution
15-
||| TODO add temperature as a implicit parameter with a defualt value of 1.0
16-
public export
17-
interface Monad m => MonadSample m where
18-
sample : {i : Nat} -> (isSucc : IsSucc i) =>
19-
Dist i -> m (Fin i)
20-
21-
2211
||| Trivial sampler, always picks the first element
2312
public export
2413
[pickFirst] MonadSample Identity where
@@ -55,3 +44,12 @@ testIO = do
5544
-- printLn is
5645
printLn (count (== 0) is) -- should be ~100
5746
printLn (count (== 1) is) -- should be ~900
47+
48+
public export
49+
testDirac : IO ()
50+
testDirac = do
51+
let index = 4
52+
let logits = diracDelta {i=10} index
53+
inds <- sequence (replicate 1000 (sample logits))
54+
printLn (take 10 inds)
55+
printLn (count (== index) inds)

src/Data/Autodiff/MLens.idr

Lines changed: 0 additions & 19 deletions
This file was deleted.

src/Data/Autodiff/Ops.idr

Lines changed: 206 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,206 @@
1+
module Data.Autodiff.Ops
2+
3+
import Data.Tensor
4+
import Data.Container.Additive
5+
import Data.Para
6+
import Control.Monad.Distribution
7+
import Control.Monad.Identity
8+
import Data.ComMonoid
9+
10+
import Misc
11+
12+
%hide Data.Container.Base.Morphism.Definition.DependentLenses.(=%>)
13+
%hide Syntax.WithProof.prefix.(@@)
14+
15+
-- This is here and not in Container.Additive because it uses `Tensor`
16+
public export
17+
Simplex : Nat -> AddCont
18+
Simplex n = MkAddCont $ (_ : Dist n) !> (Tensor [n] Double)
19+
20+
public export
21+
MulParametric : {a : Type} -> Num a => ParaAddDLens (Const a) (Const a)
22+
MulParametric = binaryOpToPara {p=Const a} Mul
23+
24+
public export
25+
AddParametric : {a : Type} -> Num a => ParaAddDLens (Const a) (Const a)
26+
AddParametric = binaryOpToPara {p=Const a} Sum
27+
28+
public export
29+
AffineParametric : {a : Type} -> Num a => ParaAddDLens (Const a) (Const a)
30+
AffineParametric = composePara MulParametric AddParametric
31+
32+
public export
33+
LeakyReLU : {a : Type} -> Num a => Ord a => FromDouble a =>
34+
{default 0.01 alpha : a} ->
35+
ParaAddDLens (Const a) (Const a)
36+
LeakyReLU = trivialParam (!%+ \x =>
37+
(if x > 0 then x else alpha * x ** \x' => if x > 0 then x' else alpha))
38+
39+
public export
40+
LeakyReLUTensor : {a : Type} -> {n : Nat} -> Num a => Ord a => FromDouble a =>
41+
{default 0.01 alpha : a} ->
42+
ParaAddDLens (Const (Tensor [n] a)) (Const (Tensor [n] a))
43+
LeakyReLUTensor = trivialParam (!%+ \x =>
44+
(x <&> (\xx => if xx > 0 then xx else alpha * xx) ** \dy =>
45+
(\(d, xx) => if xx > 0 then d else alpha * d) <$> liftA2Tensor dy x))
46+
47+
public export
48+
parallelTensor2 : {a, b: Type} -> Num a => Num b =>
49+
ParaAddDLens (Const a) (Const b) ->
50+
ParaAddDLens (Const (Tensor [2] a)) (Const (Tensor [2] b))
51+
parallelTensor2 (MkPara pCont f) = MkPara
52+
(pCont >< pCont)
53+
(!%+ \(x, (p, q)) =>
54+
let (b1 ** kf) = (%!) f (x @@ [0], p)
55+
(b2 ** kg) = (%!) f (x @@ [1], q)
56+
in (># [b1, b2] ** \bs' =>
57+
let (x1', p') = kf (bs' @@ [0])
58+
(x2', q') = kg (bs' @@ [1])
59+
in (># [x1', x2'], (p', q'))))
60+
61+
public export
62+
parallelTensor3 : {a, b : Type} -> Num a => Num b =>
63+
ParaAddDLens (Const a) (Const b) ->
64+
ParaAddDLens (Const (Tensor [3] a)) (Const (Tensor [3] b))
65+
parallelTensor3 (MkPara pCont f) = MkPara
66+
(pCont >< pCont >< pCont)
67+
(!%+ \(x, (p, q, r)) =>
68+
let (b1 ** kf) = (%!) f (x @@ [0], p)
69+
(b2 ** kg) = (%!) f (x @@ [1], q)
70+
(b3 ** kh) = (%!) f (x @@ [2], r)
71+
in (># [b1, b2, b3] ** \bs' =>
72+
let (x1', p') = kf (bs' @@ [0])
73+
(x2', q') = kg (bs' @@ [1])
74+
(x3', r') = kh (bs' @@ [2])
75+
in (># [x1', x2', x3'], (p', (q', r')))))
76+
77+
||| Produces a parametric map that produces `n` copies of the output, instead
78+
||| of one, by using `n` different parameters
79+
public export
80+
sameFromTensor2 : {a, b : Type} -> Num a => Num b =>
81+
ParaAddDLens (Const a) (Const b) ->
82+
ParaAddDLens (Const (Tensor [1] a)) (Const (Tensor [2] b))
83+
sameFromTensor2 (MkPara pCont f) = MkPara
84+
(pCont >< pCont)
85+
(!%+ \(x, (p, q)) =>
86+
let val = x @@ [0]
87+
(b1 ** kf) = (%!) f (val, p)
88+
(b2 ** kg) = (%!) f (val, q)
89+
in (># [b1, b2] ** \bs' =>
90+
let (x1', p') = kf (bs' @@ [0])
91+
(x2', q') = kg (bs' @@ [1])
92+
in (># [x1' + x2'], (p', q'))))
93+
94+
public export
95+
sameFromTensor3 : {a, b : Type} -> Num a => Num b =>
96+
ParaAddDLens (Const a) (Const b) ->
97+
ParaAddDLens (Const (Tensor [1] a)) (Const (Tensor [3] b))
98+
sameFromTensor3 (MkPara pCont f) = MkPara
99+
(pCont >< pCont >< pCont)
100+
(!%+ \(x, (p, q, r)) =>
101+
let val = x @@ [0]
102+
(b1 ** kf) = (%!) f (val, p)
103+
(b2 ** kg) = (%!) f (val, q)
104+
(b3 ** kh) = (%!) f (val, r)
105+
in (># [b1, b2, b3] ** \bs' =>
106+
let (x1', p') = kf (bs' @@ [0])
107+
(x2', q') = kg (bs' @@ [1])
108+
(x3', r') = kh (bs' @@ [2])
109+
in (># [x1' + x2' + x3'], (p', q', r'))))
110+
111+
||| Produces a parametric map that produces `n` copies of the output, instead
112+
||| of one, by using `n` different parameters
113+
public export
114+
sameFromTensor : {a, b : Type} -> Num a => Num b => {n : Nat} ->
115+
ParaAddDLens (Const a) (Const b) ->
116+
ParaAddDLens (Const (Tensor [1] a)) (Const (Tensor [n] b))
117+
sameFromTensor (MkPara pCont f) = MkPara
118+
(VectAddCont $ replicate n pCont)
119+
(!%+ \(x, psShapes) =>
120+
let val = x @@ [0]
121+
outAndBw = runIdentity $ dTraverse
122+
(\p => Id $ (%!) f (val, p))
123+
(allToVect psShapes)
124+
out = mapPropertyRelevant (\_, (y ** bw) => y) outAndBw
125+
bw = mapPropertyRelevant (\_, (y ** bw) => bw) outAndBw
126+
in (># constantToVect out ** \bs' =>
127+
let tt = bw
128+
in ?bww))
129+
130+
public export
131+
sameFrom : {a : AddCont} -> ParaAddDLens a b ->
132+
ParaAddDLens a c ->
133+
ParaAddDLens a (b >< c)
134+
sameFrom (MkPara p f) (MkPara q g) = MkPara
135+
(p >< q)
136+
(!%+ \(x, (p, q)) =>
137+
let (b ** kf) = (%!) f (x, p)
138+
(c ** kg) = (%!) g (x, q)
139+
in ((b, c) ** \(b', c') =>
140+
let (x'1, p') = kf b'
141+
(x'2, q') = kg c'
142+
in (a.Plus x x'1 x'2, (p', q'))))
143+
144+
public export
145+
sameFromConst : {a, b, c : Type} -> Num a => Num b => Num c =>
146+
ParaAddDLens (Const a) (Const b) ->
147+
ParaAddDLens (Const a) (Const c) ->
148+
ParaAddDLens (Const a) (Const (b, c))
149+
sameFromConst (MkPara p f) (MkPara q g) = MkPara
150+
(p >< q)
151+
(!%+ \(x, (p, q)) =>
152+
let (b ** kf) = (%!) f (x, p)
153+
(c ** kg) = (%!) g (x, q)
154+
in ((b, c) ** \(b', c') =>
155+
let (x'1, p') = kf b'
156+
(x'2, q') = kg c'
157+
in (x'1 + x'2, (p', q'))))
158+
159+
public export
160+
sameFrom3 : {a : AddCont} -> ParaAddDLens a b ->
161+
ParaAddDLens a c ->
162+
ParaAddDLens a d ->
163+
ParaAddDLens a (b >< c >< d)
164+
sameFrom3 (MkPara p f) (MkPara q g) (MkPara r h) = MkPara
165+
(p >< q >< r)
166+
(!%+ \(x, (p, q, r)) =>
167+
let (b ** kf) = (%!) f (x, p)
168+
(c ** kg) = (%!) g (x, q)
169+
(d ** kh) = (%!) h (x, r)
170+
in ((b, c, d) ** \(b', c', d') =>
171+
let (x'1, p') = kf b'
172+
(x'2, q') = kg c'
173+
(x'3, r') = kh d'
174+
in (a.Plus x (a.Plus x x'1 x'2) x'3, (p', q', r'))))
175+
176+
public export
177+
sameFromConst3 : {a, b, c, d : Type} -> Num a => Num b => Num c => Num d =>
178+
ParaAddDLens (Const a) (Const b) ->
179+
ParaAddDLens (Const a) (Const c) ->
180+
ParaAddDLens (Const a) (Const d) ->
181+
ParaAddDLens (Const a) (Const (b, c, d))
182+
sameFromConst3 (MkPara p f) (MkPara q g) (MkPara r h) = MkPara
183+
(p >< q >< r)
184+
(!%+ \(x, (p, q, r)) =>
185+
let (b ** kf) = (%!) f (x, p)
186+
(c ** kg) = (%!) g (x, q)
187+
(d ** kh) = (%!) h (x, r)
188+
in ((b, c, d) ** \(b', c', d') =>
189+
let (x'1, p') = kf b'
190+
(x'2, q') = kg c'
191+
(x'3, r') = kh d'
192+
in (x'1 + x'2 + x'3, (p', q', r'))))
193+
194+
-- ||| N-ary probability intro and elimination
195+
-- NProbIntro : {ef : EffectType} ->
196+
-- {i : Nat} -> IsSucc i =>
197+
-- {ts : Vect i Ty} ->
198+
-- All (\t => Term ef ctx t) ts -> -- for now all the components need to run with the same effect
199+
-- -- Treating probability as logits
200+
-- Vect i (Term ef ctx Number) ->
201+
-- Term Prob ctx (NProb ts)
202+
-- NProbElim : {ef : EffectType} ->
203+
-- {i : Nat} -> IsSucc i =>
204+
-- {ts : Vect i Ty} ->
205+
-- Term ef ctx (NProb ts) ->
206+
-- All (\e => Term ef (e :: ctx) c) ts -> Term Prob ctx c

src/Data/CT/Category/Instances.idr

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,10 @@ public export
1010
TypeCat : Cat
1111
TypeCat = MkCat Type (\a, b => a -> b)
1212

13+
public export
14+
Kleisli : Monad m => Cat
15+
Kleisli = MkCat Type (\a, b => a -> m b)
16+
1317
public export
1418
Cat : Cat
1519
Cat = MkCat Cat Functor
@@ -35,3 +39,11 @@ AddDLens = MkCat AddCont (=%>)
3539
public export
3640
AddDChart : Cat
3741
AddDChart = MkCat AddCont (=&>)
42+
43+
public export
44+
MLens : Monad m => Cat
45+
MLens = MkCat Cont (MLens {m=m})
46+
47+
public export
48+
MAddLens : Monad m => Cat
49+
MAddLens = MkCat AddCont (MAddLens {m=m})

src/Data/CT/DependentAction/Instances.idr

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,4 +51,35 @@ namespace AddCont
5151
DPairAddCont = MkDepAct $ \c => MkFunctor
5252
(DepHancockProduct c)
5353
(\r => !%+ \(x ** p) => ((x ** (r x).fwd p) **
54-
\(x', p') => (x', (r x).bwd p p')))
54+
\(x', p') => (x', (r x).bwd p p')))
55+
56+
57+
namespace MLens
58+
public export
59+
PairMLens : Monad m => DepAct MLens (Const {c=MLens {m}})
60+
PairMLens = MkDepAct $ \c => MkFunctor
61+
(c ><)
62+
(hancockMap id)
63+
64+
public export
65+
DPairMLens : Monad m => DepAct MLens (FamMLens {m=m} {c=MLens {m}})
66+
DPairMLens = MkDepAct $ \c => MkFunctor
67+
(DepHancockProduct c)
68+
(\r => !%% \(x ** p) => do
69+
(y ** ky) <- (%%! r x) p
70+
pure ((x ** y) ** \x'y' => (fst x'y', ky (snd x'y'))))
71+
72+
namespace MAddLens
73+
public export
74+
PairMAddLens : Monad m => DepAct MAddLens (Const {c=MAddLens {m}})
75+
PairMAddLens = MkDepAct $ \c => MkFunctor
76+
(c ><)
77+
(id ><)
78+
79+
public export
80+
DPairMAddLens : Monad m => DepAct MAddLens (FamMAddLens {m=m} {c=MAddLens {m}})
81+
DPairMAddLens = MkDepAct $ \c => MkFunctor
82+
(DepHancockProduct c)
83+
(\r => !%%+ \(x ** p) => do
84+
(y ** ky) <- (%%!+ r x) p
85+
pure ((x ** y) ** \x'y' => (fst x'y', ky (snd x'y'))))

0 commit comments

Comments
 (0)