Skip to content

Commit 2ff04aa

Browse files
committed
FIXES: Pre-commit, Formatting, Unit Tests
Signed-off-by: btrentini <brunoxtrentini@gmail.com>
1 parent 6d3dab2 commit 2ff04aa

File tree

4 files changed

+120
-94
lines changed

4 files changed

+120
-94
lines changed

sub-packages/bionemo-moco/examples/entropic_time_scheduler_tutorial_cfm.ipynb

Lines changed: 49 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -34,10 +34,8 @@
3434
"metadata": {},
3535
"outputs": [],
3636
"source": [
37-
"import torch \n",
38-
"import math\n",
39-
"import numpy as np \n",
4037
"import matplotlib.pyplot as plt\n",
38+
"import torch\n",
4139
"from sklearn.datasets import make_moons"
4240
]
4341
},
@@ -126,6 +124,8 @@
126124
],
127125
"source": [
128126
"normalize = True\n",
127+
"\n",
128+
"\n",
129129
"def sample_moons(n, normalize=False):\n",
130130
" x1, _ = make_moons(n_samples=n, noise=0.05)\n",
131131
" x1 = torch.Tensor(x1)\n",
@@ -134,6 +134,7 @@
134134
" x1 = (x1 - x1.mean(0)) / x1.std(0) * 2\n",
135135
" return x1\n",
136136
"\n",
137+
"\n",
137138
"x1 = sample_moons(1000)\n",
138139
"plt.scatter(x1[:, 0], x1[:, 1])"
139140
]
@@ -211,11 +212,10 @@
211212
"metadata": {},
212213
"outputs": [],
213214
"source": [
214-
"import torch\n",
215215
"import matplotlib.pyplot as plt\n",
216-
"from bionemo.moco.schedules.inference_time_schedules import (\n",
217-
" LinearInferenceSchedule, TimeDirection\n",
218-
")"
216+
"import torch\n",
217+
"\n",
218+
"from bionemo.moco.schedules.inference_time_schedules import LinearInferenceSchedule, TimeDirection"
219219
]
220220
},
221221
{
@@ -245,12 +245,13 @@
245245
"source": [
246246
"%%time\n",
247247
"# ---------- parameters ----------\n",
248-
"dim = 2\n",
249-
"shape = (batch_size, dim)\n",
250-
"_FLOW_STEPS = 100 # number of steps\n",
251-
"display_int = 10 # show every n-th step\n",
248+
"dim = 2\n",
249+
"shape = (batch_size, dim)\n",
250+
"_FLOW_STEPS = 100 # number of steps\n",
251+
"display_int = 10 # show every n-th step\n",
252252
"# --------------------------------\n",
253253
"\n",
254+
"\n",
254255
"def square_centre_limits(ax, pts, pad_frac: float = 0.05):\n",
255256
" \"\"\"Make the axes square and centred on the data.\n",
256257
"\n",
@@ -261,18 +262,16 @@
261262
" \"\"\"\n",
262263
" x, y = pts[:, 0], pts[:, 1]\n",
263264
" x_mid, y_mid = (x.max() + x.min()) / 2, (y.max() + y.min()) / 2\n",
264-
" half_range = max(x.max() - x.min(), y.max() - y.min()) / 2\n",
265-
" half_range *= (1 + pad_frac) # add a small margin\n",
265+
" half_range = max(x.max() - x.min(), y.max() - y.min()) / 2\n",
266+
" half_range *= 1 + pad_frac # add a small margin\n",
266267
" ax.set_xlim(x_mid - half_range, x_mid + half_range)\n",
267268
" ax.set_ylim(y_mid - half_range, y_mid + half_range)\n",
268269
"\n",
270+
"\n",
269271
"# define schedule\n",
270-
"inference_sched = LinearInferenceSchedule(\n",
271-
" nsteps=_FLOW_STEPS,\n",
272-
" direction=TimeDirection.UNIFIED\n",
273-
")\n",
274-
"schedule = inference_sched.generate_schedule().to(DEVICE) # len = _FLOW_STEPS\n",
275-
"dts = inference_sched.discretize().to(DEVICE) # len = _FLOW_STEPS\n",
272+
"inference_sched = LinearInferenceSchedule(nsteps=_FLOW_STEPS, direction=TimeDirection.UNIFIED)\n",
273+
"schedule = inference_sched.generate_schedule().to(DEVICE) # len = _FLOW_STEPS\n",
274+
"dts = inference_sched.discretize().to(DEVICE) # len = _FLOW_STEPS\n",
276275
"\n",
277276
"# always show t=0 and t=1\n",
278277
"display_indices = sorted(set(range(0, _FLOW_STEPS + 1, display_int)) | {0, _FLOW_STEPS})\n",
@@ -281,22 +280,23 @@
281280
"with torch.no_grad():\n",
282281
" # start from the prior used in training\n",
283282
" x = cfm.sample_prior(shape).to(DEVICE)\n",
284-
" \n",
283+
"\n",
285284
" fig, axes = plt.subplots(1, n_plots, figsize=(4 * n_plots, 4))\n",
286285
" for ax in axes:\n",
287-
" ax.set_aspect('equal', 'box')\n",
288-
" ax.set_xticks([]); ax.set_yticks([])\n",
286+
" ax.set_aspect(\"equal\", \"box\")\n",
287+
" ax.set_xticks([])\n",
288+
" ax.set_yticks([])\n",
289289
"\n",
290290
" plot_idx = 0\n",
291291
" axes[plot_idx].scatter(x[:, 0].cpu(), x[:, 1].cpu(), s=2)\n",
292-
" axes[plot_idx].set_title('t = 0.00')\n",
292+
" axes[plot_idx].set_title(\"t = 0.00\")\n",
293293
" square_centre_limits(axes[plot_idx], x.cpu())\n",
294294
"\n",
295295
" # sampling loop\n",
296296
" for step, (dt, t) in enumerate(zip(dts, schedule)):\n",
297297
" full_t = inference_sched.pad_time(batch_size, t, device=DEVICE)\n",
298-
" v_t = model(torch.cat([x, full_t[:, None]], dim=-1))\n",
299-
" x = cfm.step(v_t, x, dt, t=full_t)\n",
298+
" v_t = model(torch.cat([x, full_t[:, None]], dim=-1))\n",
299+
" x = cfm.step(v_t, x, dt, t=full_t)\n",
300300
"\n",
301301
" # time after the step (always exists, even at the very end)\n",
302302
" t_next = (t + dt).item()\n",
@@ -305,11 +305,11 @@
305305
" plot_idx += 1\n",
306306
" ax = axes[plot_idx]\n",
307307
" ax.scatter(x[:, 0].cpu(), x[:, 1].cpu(), s=2)\n",
308-
" ax.set_title(f't = {t_next:.2f}')\n",
308+
" ax.set_title(f\"t = {t_next:.2f}\")\n",
309309
" square_centre_limits(ax, x.cpu())\n",
310310
"\n",
311311
"plt.tight_layout(pad=0.8)\n",
312-
"plt.show()\n"
312+
"plt.show()"
313313
]
314314
},
315315
{
@@ -325,12 +325,11 @@
325325
"metadata": {},
326326
"outputs": [],
327327
"source": [
328+
"import matplotlib.pyplot as plt\n",
328329
"import torch\n",
329330
"from torch import Tensor\n",
330-
"import matplotlib.pyplot as plt\n",
331-
"from bionemo.moco.schedules.inference_time_schedules import (\n",
332-
" EntropicInferenceSchedule, TimeDirection\n",
333-
")"
331+
"\n",
332+
"from bionemo.moco.schedules.inference_time_schedules import EntropicInferenceSchedule, TimeDirection"
334333
]
335334
},
336335
{
@@ -367,10 +366,11 @@
367366
],
368367
"source": [
369368
"%%time\n",
370-
"_FLOW_STEPS = 100\n",
371-
"display_int = 10 # controls every \"n\" steps to display\n",
369+
"_FLOW_STEPS = 100\n",
370+
"display_int = 10 # controls every \"n\" steps to display\n",
372371
"shape = (batch_size, dim)\n",
373372
"\n",
373+
"\n",
374374
"# Predictor function wrapper.\n",
375375
"# The scheduler needs a function `model(t, x)` and this wrapper handles the formatting.\n",
376376
"def predictor_fn(t: Tensor, x: Tensor) -> Tensor:\n",
@@ -379,62 +379,65 @@
379379
" t = t.unsqueeze(-1)\n",
380380
" if t.shape[0] != x.shape[0]:\n",
381381
" t = t.expand(x.shape[0], -1)\n",
382-
" \n",
382+
"\n",
383383
" model_input = torch.cat([x, t], dim=-1)\n",
384384
" return model(model_input)\n",
385385
"\n",
386+
"\n",
386387
"def x_0_sampler_fn(n_samples: int) -> Tensor:\n",
387388
" return cfm.sample_prior((n_samples, dim))\n",
388389
"\n",
390+
"\n",
389391
"def x_1_sampler_fn(n_samples: int) -> Tensor:\n",
390392
" return sample_moons(n_samples)\n",
391393
"\n",
394+
"\n",
392395
"inference_sched = EntropicInferenceSchedule(\n",
393396
" predictor=predictor_fn,\n",
394397
" x_0_sampler=x_0_sampler_fn,\n",
395398
" x_1_sampler=x_1_sampler_fn,\n",
396399
" nsteps=_FLOW_STEPS,\n",
397-
" n_approx_entropy_points=30, # More points -> more accurate schedule, but slower to generate\n",
400+
" n_approx_entropy_points=30, # More points -> more accurate schedule, but slower to generate\n",
398401
" batch_size=batch_size,\n",
399402
" direction=TimeDirection.UNIFIED,\n",
400403
" device=DEVICE,\n",
401404
")\n",
402405
"print(\"Generating entropic schedule...\")\n",
403-
"schedule = inference_sched.generate_schedule().to(DEVICE) \n",
404-
"dts = inference_sched.discretize().to(DEVICE)\n",
406+
"schedule = inference_sched.generate_schedule().to(DEVICE)\n",
407+
"dts = inference_sched.discretize().to(DEVICE)\n",
405408
"print(\"Schedule generated.\")\n",
406409
"\n",
407410
"display_indices = sorted(set(range(0, _FLOW_STEPS + 1, display_int)) | {0, _FLOW_STEPS})\n",
408411
"n_plots = len(display_indices)\n",
409412
"\n",
410413
"with torch.no_grad():\n",
411-
"\n",
412414
" x = cfm.sample_prior((batch_size, dim)).to(DEVICE)\n",
413415
"\n",
414416
" fig, axes = plt.subplots(1, n_plots, figsize=(4 * n_plots, 4))\n",
415417
" for ax in axes:\n",
416-
" ax.set_aspect('equal', 'box')\n",
417-
" ax.set_xticks([]); ax.set_yticks([])\n",
418+
" ax.set_aspect(\"equal\", \"box\")\n",
419+
" ax.set_xticks([])\n",
420+
" ax.set_yticks([])\n",
418421
"\n",
419422
" plot_idx = 0\n",
420423
" axes[plot_idx].scatter(x[:, 0].cpu(), x[:, 1].cpu(), s=2)\n",
421-
" axes[plot_idx].set_title('t = 0.00')\n",
422-
" square_centre_limits(axes[plot_idx], x.cpu()) \n",
424+
" axes[plot_idx].set_title(\"t = 0.00\")\n",
425+
" square_centre_limits(axes[plot_idx], x.cpu())\n",
423426
"\n",
424427
" # integration loop & viz\n",
425428
" for step, (dt, t) in enumerate(zip(dts, schedule)):\n",
426429
" full_t = inference_sched.pad_time(batch_size, t, device=DEVICE)\n",
427-
" v_t = model(torch.cat([x, full_t[:, None]], dim=-1))\n",
428-
" x = cfm.step(v_t, x, dt, t=full_t)\n",
430+
" v_t = model(torch.cat([x, full_t[:, None]], dim=-1))\n",
431+
" x = cfm.step(v_t, x, dt, t=full_t)\n",
429432
"\n",
430433
" t_next = (t + dt).item()\n",
431434
"\n",
432435
" if (step + 1) in display_indices:\n",
433436
" plot_idx += 1\n",
434437
" ax = axes[plot_idx]\n",
435438
" ax.scatter(x[:, 0].cpu(), x[:, 1].cpu(), s=2)\n",
436-
" ax.set_title(f't = {t_next:.2f}')\n",
437-
" square_centre_limits(ax, x.cpu()) \n",
439+
" ax.set_title(f\"t = {t_next:.2f}\")\n",
440+
" square_centre_limits(ax, x.cpu())\n",
438441
"\n",
439442
"plt.tight_layout(pad=0.8)\n",
440443
"plt.show()"

0 commit comments

Comments
 (0)