Commit 53f352f

[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent f3a8526 commit 53f352f

File tree: 4 files changed (+39, -37 lines)

Diff for: notebooks/02-jax-idioms/04-optimized-learning.ipynb

+35 -35

@@ -3,7 +3,7 @@
 {
 "cell_type": "code",
 "execution_count": null,
-"id": "frequent-field",
+"id": "0",
 "metadata": {
 "tags": []
 },
@@ -18,7 +18,7 @@
 {
 "cell_type": "code",
 "execution_count": null,
-"id": "opened-virgin",
+"id": "1",
 "metadata": {
 "tags": []
 },
@@ -31,7 +31,7 @@
 },
 {
 "cell_type": "markdown",
-"id": "lasting-express",
+"id": "2",
 "metadata": {},
 "source": [
 "# Optimized Learning\n",
@@ -41,7 +41,7 @@
 },
 {
 "cell_type": "markdown",
-"id": "forward-process",
+"id": "3",
 "metadata": {},
 "source": [
 "## Autograd to JAX\n",
@@ -52,7 +52,7 @@
 },
 {
 "cell_type": "markdown",
-"id": "correct-cyprus",
+"id": "4",
 "metadata": {},
 "source": [
 "## Example: Transforming a function into its derivative\n",
@@ -68,7 +68,7 @@
 {
 "cell_type": "code",
 "execution_count": null,
-"id": "demanding-opportunity",
+"id": "5",
 "metadata": {
 "tags": []
 },
@@ -90,7 +90,7 @@
 },
 {
 "cell_type": "markdown",
-"id": "forty-lindsay",
+"id": "6",
 "metadata": {},
 "source": [
 "Here's another example using a polynomial function:\n",
@@ -105,7 +105,7 @@
 {
 "cell_type": "code",
 "execution_count": null,
-"id": "neutral-neighbor",
+"id": "7",
 "metadata": {
 "tags": []
 },
@@ -128,7 +128,7 @@
 },
 {
 "cell_type": "markdown",
-"id": "steady-bikini",
+"id": "8",
 "metadata": {},
 "source": [
 "## Using grad to solve minimization problems\n",
@@ -147,7 +147,7 @@
 {
 "cell_type": "code",
 "execution_count": null,
-"id": "opponent-modification",
+"id": "9",
 "metadata": {
 "tags": []
 },
@@ -163,7 +163,7 @@
 },
 {
 "cell_type": "markdown",
-"id": "beautiful-theory",
+"id": "10",
 "metadata": {},
 "source": [
 "We know from calculus that the sign of the second derivative tells us whether we have a minima or maxima at a point.\n",
@@ -178,7 +178,7 @@
 {
 "cell_type": "code",
 "execution_count": null,
-"id": "former-syracuse",
+"id": "11",
 "metadata": {},
 "outputs": [],
 "source": [
@@ -189,15 +189,15 @@
 },
 {
 "cell_type": "markdown",
-"id": "surrounded-plain",
+"id": "12",
 "metadata": {},
 "source": [
 "Grad is composable an arbitrary number of times. You can keep calling grad as many times as you like."
 ]
 },
 {
 "cell_type": "markdown",
-"id": "brazilian-atlas",
+"id": "13",
 "metadata": {},
 "source": [
 "## Maximum likelihood estimation\n",
@@ -216,7 +216,7 @@
 {
 "cell_type": "code",
 "execution_count": null,
-"id": "confidential-sympathy",
+"id": "14",
 "metadata": {
 "tags": []
 },
@@ -236,7 +236,7 @@
 },
 {
 "cell_type": "markdown",
-"id": "atlantic-excellence",
+"id": "15",
 "metadata": {},
 "source": [
 "Our estimation task will necessitate calculating the total joint log likelihood of our data under a Gaussian model.\n",
@@ -248,7 +248,7 @@
 {
 "cell_type": "code",
 "execution_count": null,
-"id": "known-terrain",
+"id": "16",
 "metadata": {
 "tags": []
 },
@@ -263,7 +263,7 @@
 },
 {
 "cell_type": "markdown",
-"id": "terminal-census",
+"id": "17",
 "metadata": {},
 "source": [
 "If you're wondering why we use `log_sigma` rather than `sigma`, it is a choice made for practical reasons.\n",
@@ -280,7 +280,7 @@
 {
 "cell_type": "code",
 "execution_count": null,
-"id": "dominant-delight",
+"id": "18",
 "metadata": {
 "tags": []
 },
@@ -293,7 +293,7 @@
 },
 {
 "cell_type": "markdown",
-"id": "equal-brazilian",
+"id": "19",
 "metadata": {},
 "source": [
 "Now, we can create the gradient function of our negative log likelihood.\n",
@@ -307,7 +307,7 @@
 {
 "cell_type": "code",
 "execution_count": null,
-"id": "meaning-scanning",
+"id": "20",
 "metadata": {
 "tags": []
 },
@@ -322,7 +322,7 @@
 },
 {
 "cell_type": "markdown",
-"id": "hourly-miller",
+"id": "21",
 "metadata": {},
 "source": [
 "Now, we can do the gradient descent step!"
@@ -331,7 +331,7 @@
 {
 "cell_type": "code",
 "execution_count": null,
-"id": "cosmetic-perception",
+"id": "22",
 "metadata": {
 "tags": []
 },
@@ -347,15 +347,15 @@
 },
 {
 "cell_type": "markdown",
-"id": "defensive-family",
+"id": "23",
 "metadata": {},
 "source": [
 "And voila! We have gradient descended our way to the maximum likelihood parameters :)."
 ]
 },
 {
 "cell_type": "markdown",
-"id": "constant-account",
+"id": "24",
 "metadata": {},
 "source": [
 "## Exercise: Where is the gold? It's at the minima!\n",
@@ -368,7 +368,7 @@
 {
 "cell_type": "code",
 "execution_count": null,
-"id": "focal-climate",
+"id": "25",
 "metadata": {
 "tags": []
 },
@@ -383,7 +383,7 @@
 },
 {
 "cell_type": "markdown",
-"id": "massive-corps",
+"id": "26",
 "metadata": {},
 "source": [
 "It should be evident from here that there are two minima in the function.\n",
@@ -398,7 +398,7 @@
 {
 "cell_type": "code",
 "execution_count": null,
-"id": "opened-beads",
+"id": "27",
 "metadata": {
 "tags": []
 },
@@ -420,7 +420,7 @@
 },
 {
 "cell_type": "markdown",
-"id": "brown-violation",
+"id": "28",
 "metadata": {},
 "source": [
 "Now, implement the optimization loop!"
@@ -429,7 +429,7 @@
 {
 "cell_type": "code",
 "execution_count": null,
-"id": "alternative-wisdom",
+"id": "29",
 "metadata": {
 "tags": []
 },
@@ -450,7 +450,7 @@
 },
 {
 "cell_type": "markdown",
-"id": "alternative-iraqi",
+"id": "30",
 "metadata": {},
 "source": [
 "## Exercise: programming a robot that only moves along one axis\n",
@@ -464,7 +464,7 @@
 {
 "cell_type": "code",
 "execution_count": null,
-"id": "operational-advantage",
+"id": "31",
 "metadata": {
 "tags": []
 },
@@ -490,7 +490,7 @@
 },
 {
 "cell_type": "markdown",
-"id": "ecological-asian",
+"id": "32",
 "metadata": {},
 "source": [
 "For your reference we have the function plotted below."
@@ -499,7 +499,7 @@
 {
 "cell_type": "code",
 "execution_count": null,
-"id": "convenient-optics",
+"id": "33",
 "metadata": {
 "tags": []
 },
@@ -531,7 +531,7 @@
 {
 "cell_type": "code",
 "execution_count": null,
-"id": "loaded-labor",
+"id": "34",
 "metadata": {},
 "outputs": [],
 "source": []

Diff for: src/dl_workshop/answers.py

+2 -1

@@ -1,6 +1,7 @@
 """
 Answers to the main tutorial notebooks.
 """
+
 import jax.numpy as np
 import numpy.random as npr
 from jax import grad
@@ -81,7 +82,7 @@ def logistic_loss(params, model, x, y):
 
 
 def f(w):
-    return w ** 2 + 3 * w - 5
+    return w**2 + 3 * w - 5
 
 
 def df(w):
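The only substantive line in this hunk is the restyling of `f` to the formatter's power-operator convention (no spaces around `**` between simple operands), which does not change behavior. As a minimal sketch: the body of `df` is not shown in this diff, so the `2 * w + 3` below is the analytic derivative of `f`, assumed here purely for illustration.

```python
def f(w):
    # formatter restyles w ** 2 as w**2; the value is identical
    return w**2 + 3 * w - 5


def df(w):
    # assumed body (not shown in the diff): analytic derivative of f
    return 2 * w + 3
```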

Diff for: src/dl_workshop/jax_idioms.py

+1 -1

@@ -93,7 +93,7 @@ def randomness_ex_3(key, num_realizations: int, grw_draw: Callable):
 
 def goldfield(x, y):
     """All credit to https://www.analyzemath.com/calculus/multivariable/maxima_minima.html for this function."""
-    return (2 * x ** 2) - (4 * x * y) + (y ** 4 + 2)
+    return (2 * x**2) - (4 * x * y) + (y**4 + 2)
 
 
 def grad_ex_1():
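The same `**` spacing fix applies to `goldfield`, the function behind the notebook's "two minima" exercise. A quick numerical check (a hedged sketch using central differences, not the workshop's own code; the helper `partials` is introduced here for illustration) confirms the reformatted expression still has a stationary point at (1, 1), where both 4x - 4y and -4x + 4y^3 vanish:

```python
def goldfield(x, y):
    # reformatted expression from the diff; the value is unchanged
    return (2 * x**2) - (4 * x * y) + (y**4 + 2)


def partials(f, x, y, h=1e-6):
    # central-difference estimates of df/dx and df/dy at (x, y)
    dfdx = (f(x + h, y) - f(x - h, y)) / (2 * h)
    dfdy = (f(x, y + h) - f(x, y - h)) / (2 * h)
    return dfdx, dfdy
```

At (1.0, 1.0) both partial derivatives evaluate to approximately zero, consistent with the minimum described in the notebook text.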

Diff for: src/setup.py

+1 -0

@@ -1,4 +1,5 @@
 """Setup script."""
+
 from setuptools import find_packages, setup
 
 setup(name="dl_workshop", version="0.1", packages=find_packages())

0 commit comments