You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Copy file name to clipboardExpand all lines: docs/notebooks/introduction.ipynb
+68-15Lines changed: 68 additions & 15 deletions
Original file line number
Diff line number
Diff line change
@@ -3,7 +3,11 @@
3
3
{
4
4
"cell_type": "markdown",
5
5
"id": "4f6f4229-6b15-4e2b-89af-8957708479d7",
6
-
"metadata": {},
6
+
"metadata": {
7
+
"pycharm": {
8
+
"name": "#%% md\n"
9
+
}
10
+
},
7
11
"source": [
8
12
"# Constructing normalizing flows\n",
9
13
"\n",
@@ -15,7 +19,10 @@
15
19
"execution_count": 4,
16
20
"id": "9497e202-3f4e-4602-9e90-669545f18816",
17
21
"metadata": {
18
-
"tags": []
22
+
"tags": [],
23
+
"pycharm": {
24
+
"name": "#%%\n"
25
+
}
19
26
},
20
27
"outputs": [],
21
28
"source": [
@@ -28,7 +35,11 @@
28
35
{
29
36
"cell_type": "markdown",
30
37
"id": "bfd96e3f-ef84-454f-a611-0cdc0a629d2d",
31
-
"metadata": {},
38
+
"metadata": {
39
+
"pycharm": {
40
+
"name": "#%% md\n"
41
+
}
42
+
},
32
43
"source": [
33
44
"## How to construct a Haiku module\n",
34
45
"\n",
@@ -40,7 +51,10 @@
40
51
"execution_count": 30,
41
52
"id": "c9db6163-6cb8-4968-8e3d-90d35a5094cf",
42
53
"metadata": {
43
-
"tags": []
54
+
"tags": [],
55
+
"pycharm": {
56
+
"name": "#%%\n"
57
+
}
44
58
},
45
59
"outputs": [],
46
60
"source": [
@@ -53,7 +67,10 @@
53
67
"execution_count": 61,
54
68
"id": "f78c2fc0-f4f6-476d-8b15-12b4a874ce1d",
55
69
"metadata": {
56
-
"tags": []
70
+
"tags": [],
71
+
"pycharm": {
72
+
"name": "#%%\n"
73
+
}
57
74
},
58
75
"outputs": [],
59
76
"source": [
@@ -88,7 +105,10 @@
88
105
"cell_type": "markdown",
89
106
"id": "b41e89a1-6df4-4069-a606-4f2412b9dc6a",
90
107
"metadata": {
91
-
"tags": []
108
+
"tags": [],
109
+
"pycharm": {
110
+
"name": "#%% md\n"
111
+
}
92
112
},
93
113
"source": [
94
114
"Constructing a Haiku module needs to be done within a `hk.transform` block. This can either be done by providing a function like here and an object. In our case we are using `hk.transform` on `pushforward(**kwargs)` which calls\n",
@@ -98,7 +118,11 @@
98
118
{
99
119
"cell_type": "markdown",
100
120
"id": "689f867c-2259-4713-81b6-5352837cb342",
101
-
"metadata": {},
121
+
"metadata": {
122
+
"pycharm": {
123
+
"name": "#%% md\n"
124
+
}
125
+
},
102
126
"source": [
103
127
"We can now initialize the flow. Let's define a random data set first and then initialize the parameters."
104
128
]
@@ -108,7 +132,10 @@
108
132
"execution_count": 65,
109
133
"id": "cf91f2ce-6ba9-438c-95a7-ec9dcfbfea17",
110
134
"metadata": {
111
-
"tags": []
135
+
"tags": [],
136
+
"pycharm": {
137
+
"name": "#%%\n"
138
+
}
112
139
},
113
140
"outputs": [
114
141
{
@@ -150,7 +177,11 @@
150
177
{
151
178
"cell_type": "markdown",
152
179
"id": "4c8f8b8e-4e81-4b54-a1c4-72027db33e1c",
153
-
"metadata": {},
180
+
"metadata": {
181
+
"pycharm": {
182
+
"name": "#%% md\n"
183
+
}
184
+
},
154
185
"source": [
155
186
"The only trainable paramaters that are flow defines are the weights of the MLP. The MLP is used to compute the conditional probability density inside the `decoder_fn` function. \n",
156
187
"The `Slice` surjector itself doesn't have paramters."
@@ -159,7 +190,11 @@
159
190
{
160
191
"cell_type": "markdown",
161
192
"id": "84d8368a-2c75-4e72-a3b4-45f4cfc71fbf",
162
-
"metadata": {},
193
+
"metadata": {
194
+
"pycharm": {
195
+
"name": "#%% md\n"
196
+
}
197
+
},
163
198
"source": [
164
199
"We can now test the flow. Let's sample some data first."
165
200
]
@@ -169,7 +204,10 @@
169
204
"execution_count": 71,
170
205
"id": "77ae282d-9e30-49c9-98b1-6f9494b34a21",
171
206
"metadata": {
172
-
"tags": []
207
+
"tags": [],
208
+
"pycharm": {
209
+
"name": "#%%\n"
210
+
}
173
211
},
174
212
"outputs": [
175
213
{
@@ -196,7 +234,11 @@
196
234
{
197
235
"cell_type": "markdown",
198
236
"id": "04fd204d-f0c2-468b-a64d-ec292b071f34",
199
-
"metadata": {},
237
+
"metadata": {
238
+
"pycharm": {
239
+
"name": "#%% md\n"
240
+
}
241
+
},
200
242
"source": [
201
243
"As mentioned above, in order to dispatch to a method, we just provide a keyword argument. In this case this is `method='sample'`. Computing the log probability of the data can be done, by changing the method argument to `log_prob`."
202
244
]
@@ -206,7 +248,10 @@
206
248
"execution_count": 72,
207
249
"id": "2190c369-4242-45ec-9ac9-b9fa5bccc1dd",
208
250
"metadata": {
209
-
"tags": []
251
+
"tags": [],
252
+
"pycharm": {
253
+
"name": "#%%\n"
254
+
}
210
255
},
211
256
"outputs": [
212
257
{
@@ -227,7 +272,11 @@
227
272
{
228
273
"cell_type": "markdown",
229
274
"id": "8c046ffd-a2b9-4348-b8f3-85e978b7998f",
230
-
"metadata": {},
275
+
"metadata": {
276
+
"pycharm": {
277
+
"name": "#%% md\n"
278
+
}
279
+
},
231
280
"source": [
232
281
"## How to construct `TransformedDistribution` objects\n",
0 commit comments