|
3 | 3 | {
|
4 | 4 | "cell_type": "code",
|
5 | 5 | "execution_count": null,
|
6 |
| - "id": "frequent-field", |
| 6 | + "id": "0", |
7 | 7 | "metadata": {
|
8 | 8 | "tags": []
|
9 | 9 | },
|
|
18 | 18 | {
|
19 | 19 | "cell_type": "code",
|
20 | 20 | "execution_count": null,
|
21 |
| - "id": "opened-virgin", |
| 21 | + "id": "1", |
22 | 22 | "metadata": {
|
23 | 23 | "tags": []
|
24 | 24 | },
|
|
31 | 31 | },
|
32 | 32 | {
|
33 | 33 | "cell_type": "markdown",
|
34 |
| - "id": "lasting-express", |
| 34 | + "id": "2", |
35 | 35 | "metadata": {},
|
36 | 36 | "source": [
|
37 | 37 | "# Optimized Learning\n",
|
|
41 | 41 | },
|
42 | 42 | {
|
43 | 43 | "cell_type": "markdown",
|
44 |
| - "id": "forward-process", |
| 44 | + "id": "3", |
45 | 45 | "metadata": {},
|
46 | 46 | "source": [
|
47 | 47 | "## Autograd to JAX\n",
|
|
52 | 52 | },
|
53 | 53 | {
|
54 | 54 | "cell_type": "markdown",
|
55 |
| - "id": "correct-cyprus", |
| 55 | + "id": "4", |
56 | 56 | "metadata": {},
|
57 | 57 | "source": [
|
58 | 58 | "## Example: Transforming a function into its derivative\n",
|
|
68 | 68 | {
|
69 | 69 | "cell_type": "code",
|
70 | 70 | "execution_count": null,
|
71 |
| - "id": "demanding-opportunity", |
| 71 | + "id": "5", |
72 | 72 | "metadata": {
|
73 | 73 | "tags": []
|
74 | 74 | },
|
|
90 | 90 | },
|
91 | 91 | {
|
92 | 92 | "cell_type": "markdown",
|
93 |
| - "id": "forty-lindsay", |
| 93 | + "id": "6", |
94 | 94 | "metadata": {},
|
95 | 95 | "source": [
|
96 | 96 | "Here's another example using a polynomial function:\n",
|
|
105 | 105 | {
|
106 | 106 | "cell_type": "code",
|
107 | 107 | "execution_count": null,
|
108 |
| - "id": "neutral-neighbor", |
| 108 | + "id": "7", |
109 | 109 | "metadata": {
|
110 | 110 | "tags": []
|
111 | 111 | },
|
|
128 | 128 | },
|
129 | 129 | {
|
130 | 130 | "cell_type": "markdown",
|
131 |
| - "id": "steady-bikini", |
| 131 | + "id": "8", |
132 | 132 | "metadata": {},
|
133 | 133 | "source": [
|
134 | 134 | "## Using grad to solve minimization problems\n",
|
|
147 | 147 | {
|
148 | 148 | "cell_type": "code",
|
149 | 149 | "execution_count": null,
|
150 |
| - "id": "opponent-modification", |
| 150 | + "id": "9", |
151 | 151 | "metadata": {
|
152 | 152 | "tags": []
|
153 | 153 | },
|
|
163 | 163 | },
|
164 | 164 | {
|
165 | 165 | "cell_type": "markdown",
|
166 |
| - "id": "beautiful-theory", |
| 166 | + "id": "10", |
167 | 167 | "metadata": {},
|
168 | 168 | "source": [
|
169 | 169 | "We know from calculus that the sign of the second derivative tells us whether we have a minima or maxima at a point.\n",
|
|
178 | 178 | {
|
179 | 179 | "cell_type": "code",
|
180 | 180 | "execution_count": null,
|
181 |
| - "id": "former-syracuse", |
| 181 | + "id": "11", |
182 | 182 | "metadata": {},
|
183 | 183 | "outputs": [],
|
184 | 184 | "source": [
|
|
189 | 189 | },
|
190 | 190 | {
|
191 | 191 | "cell_type": "markdown",
|
192 |
| - "id": "surrounded-plain", |
| 192 | + "id": "12", |
193 | 193 | "metadata": {},
|
194 | 194 | "source": [
|
195 | 195 | "Grad is composable an arbitrary number of times. You can keep calling grad as many times as you like."
|
196 | 196 | ]
|
197 | 197 | },
|
198 | 198 | {
|
199 | 199 | "cell_type": "markdown",
|
200 |
| - "id": "brazilian-atlas", |
| 200 | + "id": "13", |
201 | 201 | "metadata": {},
|
202 | 202 | "source": [
|
203 | 203 | "## Maximum likelihood estimation\n",
|
|
216 | 216 | {
|
217 | 217 | "cell_type": "code",
|
218 | 218 | "execution_count": null,
|
219 |
| - "id": "confidential-sympathy", |
| 219 | + "id": "14", |
220 | 220 | "metadata": {
|
221 | 221 | "tags": []
|
222 | 222 | },
|
|
236 | 236 | },
|
237 | 237 | {
|
238 | 238 | "cell_type": "markdown",
|
239 |
| - "id": "atlantic-excellence", |
| 239 | + "id": "15", |
240 | 240 | "metadata": {},
|
241 | 241 | "source": [
|
242 | 242 | "Our estimation task will necessitate calculating the total joint log likelihood of our data under a Gaussian model.\n",
|
|
248 | 248 | {
|
249 | 249 | "cell_type": "code",
|
250 | 250 | "execution_count": null,
|
251 |
| - "id": "known-terrain", |
| 251 | + "id": "16", |
252 | 252 | "metadata": {
|
253 | 253 | "tags": []
|
254 | 254 | },
|
|
263 | 263 | },
|
264 | 264 | {
|
265 | 265 | "cell_type": "markdown",
|
266 |
| - "id": "terminal-census", |
| 266 | + "id": "17", |
267 | 267 | "metadata": {},
|
268 | 268 | "source": [
|
269 | 269 | "If you're wondering why we use `log_sigma` rather than `sigma`, it is a choice made for practical reasons.\n",
|
|
280 | 280 | {
|
281 | 281 | "cell_type": "code",
|
282 | 282 | "execution_count": null,
|
283 |
| - "id": "dominant-delight", |
| 283 | + "id": "18", |
284 | 284 | "metadata": {
|
285 | 285 | "tags": []
|
286 | 286 | },
|
|
293 | 293 | },
|
294 | 294 | {
|
295 | 295 | "cell_type": "markdown",
|
296 |
| - "id": "equal-brazilian", |
| 296 | + "id": "19", |
297 | 297 | "metadata": {},
|
298 | 298 | "source": [
|
299 | 299 | "Now, we can create the gradient function of our negative log likelihood.\n",
|
|
307 | 307 | {
|
308 | 308 | "cell_type": "code",
|
309 | 309 | "execution_count": null,
|
310 |
| - "id": "meaning-scanning", |
| 310 | + "id": "20", |
311 | 311 | "metadata": {
|
312 | 312 | "tags": []
|
313 | 313 | },
|
|
322 | 322 | },
|
323 | 323 | {
|
324 | 324 | "cell_type": "markdown",
|
325 |
| - "id": "hourly-miller", |
| 325 | + "id": "21", |
326 | 326 | "metadata": {},
|
327 | 327 | "source": [
|
328 | 328 | "Now, we can do the gradient descent step!"
|
|
331 | 331 | {
|
332 | 332 | "cell_type": "code",
|
333 | 333 | "execution_count": null,
|
334 |
| - "id": "cosmetic-perception", |
| 334 | + "id": "22", |
335 | 335 | "metadata": {
|
336 | 336 | "tags": []
|
337 | 337 | },
|
|
347 | 347 | },
|
348 | 348 | {
|
349 | 349 | "cell_type": "markdown",
|
350 |
| - "id": "defensive-family", |
| 350 | + "id": "23", |
351 | 351 | "metadata": {},
|
352 | 352 | "source": [
|
353 | 353 | "And voila! We have gradient descended our way to the maximum likelihood parameters :)."
|
354 | 354 | ]
|
355 | 355 | },
|
356 | 356 | {
|
357 | 357 | "cell_type": "markdown",
|
358 |
| - "id": "constant-account", |
| 358 | + "id": "24", |
359 | 359 | "metadata": {},
|
360 | 360 | "source": [
|
361 | 361 | "## Exercise: Where is the gold? It's at the minima!\n",
|
|
368 | 368 | {
|
369 | 369 | "cell_type": "code",
|
370 | 370 | "execution_count": null,
|
371 |
| - "id": "focal-climate", |
| 371 | + "id": "25", |
372 | 372 | "metadata": {
|
373 | 373 | "tags": []
|
374 | 374 | },
|
|
383 | 383 | },
|
384 | 384 | {
|
385 | 385 | "cell_type": "markdown",
|
386 |
| - "id": "massive-corps", |
| 386 | + "id": "26", |
387 | 387 | "metadata": {},
|
388 | 388 | "source": [
|
389 | 389 | "It should be evident from here that there are two minima in the function.\n",
|
|
398 | 398 | {
|
399 | 399 | "cell_type": "code",
|
400 | 400 | "execution_count": null,
|
401 |
| - "id": "opened-beads", |
| 401 | + "id": "27", |
402 | 402 | "metadata": {
|
403 | 403 | "tags": []
|
404 | 404 | },
|
|
420 | 420 | },
|
421 | 421 | {
|
422 | 422 | "cell_type": "markdown",
|
423 |
| - "id": "brown-violation", |
| 423 | + "id": "28", |
424 | 424 | "metadata": {},
|
425 | 425 | "source": [
|
426 | 426 | "Now, implement the optimization loop!"
|
|
429 | 429 | {
|
430 | 430 | "cell_type": "code",
|
431 | 431 | "execution_count": null,
|
432 |
| - "id": "alternative-wisdom", |
| 432 | + "id": "29", |
433 | 433 | "metadata": {
|
434 | 434 | "tags": []
|
435 | 435 | },
|
|
450 | 450 | },
|
451 | 451 | {
|
452 | 452 | "cell_type": "markdown",
|
453 |
| - "id": "alternative-iraqi", |
| 453 | + "id": "30", |
454 | 454 | "metadata": {},
|
455 | 455 | "source": [
|
456 | 456 | "## Exercise: programming a robot that only moves along one axis\n",
|
|
464 | 464 | {
|
465 | 465 | "cell_type": "code",
|
466 | 466 | "execution_count": null,
|
467 |
| - "id": "operational-advantage", |
| 467 | + "id": "31", |
468 | 468 | "metadata": {
|
469 | 469 | "tags": []
|
470 | 470 | },
|
|
490 | 490 | },
|
491 | 491 | {
|
492 | 492 | "cell_type": "markdown",
|
493 |
| - "id": "ecological-asian", |
| 493 | + "id": "32", |
494 | 494 | "metadata": {},
|
495 | 495 | "source": [
|
496 | 496 | "For your reference we have the function plotted below."
|
|
499 | 499 | {
|
500 | 500 | "cell_type": "code",
|
501 | 501 | "execution_count": null,
|
502 |
| - "id": "convenient-optics", |
| 502 | + "id": "33", |
503 | 503 | "metadata": {
|
504 | 504 | "tags": []
|
505 | 505 | },
|
|
531 | 531 | {
|
532 | 532 | "cell_type": "code",
|
533 | 533 | "execution_count": null,
|
534 |
| - "id": "loaded-labor", |
| 534 | + "id": "34", |
535 | 535 | "metadata": {},
|
536 | 536 | "outputs": [],
|
537 | 537 | "source": []
|
|
0 commit comments