|
34 | 34 | "metadata": {}, |
35 | 35 | "outputs": [], |
36 | 36 | "source": [ |
37 | | - "import torch \n", |
38 | | - "import math\n", |
39 | | - "import numpy as np \n", |
40 | 37 | "import matplotlib.pyplot as plt\n", |
| 38 | + "import torch\n", |
41 | 39 | "from sklearn.datasets import make_moons" |
42 | 40 | ] |
43 | 41 | }, |
|
126 | 124 | ], |
127 | 125 | "source": [ |
128 | 126 | "normalize = True\n", |
| 127 | + "\n", |
| 128 | + "\n", |
129 | 129 | "def sample_moons(n, normalize=False):\n", |
130 | 130 | " x1, _ = make_moons(n_samples=n, noise=0.05)\n", |
131 | 131 | " x1 = torch.Tensor(x1)\n", |
|
134 | 134 | " x1 = (x1 - x1.mean(0)) / x1.std(0) * 2\n", |
135 | 135 | " return x1\n", |
136 | 136 | "\n", |
| 137 | + "\n", |
137 | 138 | "x1 = sample_moons(1000)\n", |
138 | 139 | "plt.scatter(x1[:, 0], x1[:, 1])" |
139 | 140 | ] |
|
211 | 212 | "metadata": {}, |
212 | 213 | "outputs": [], |
213 | 214 | "source": [ |
214 | | - "import torch\n", |
215 | 215 | "import matplotlib.pyplot as plt\n", |
216 | | - "from bionemo.moco.schedules.inference_time_schedules import (\n", |
217 | | - " LinearInferenceSchedule, TimeDirection\n", |
218 | | - ")" |
| 216 | + "import torch\n", |
| 217 | + "\n", |
| 218 | + "from bionemo.moco.schedules.inference_time_schedules import LinearInferenceSchedule, TimeDirection" |
219 | 219 | ] |
220 | 220 | }, |
221 | 221 | { |
|
245 | 245 | "source": [ |
246 | 246 | "%%time\n", |
247 | 247 | "# ---------- parameters ----------\n", |
248 | | - "dim = 2\n", |
249 | | - "shape = (batch_size, dim)\n", |
250 | | - "_FLOW_STEPS = 100 # number of steps\n", |
251 | | - "display_int = 10 # show every n-th step\n", |
| 248 | + "dim = 2\n", |
| 249 | + "shape = (batch_size, dim)\n", |
| 250 | + "_FLOW_STEPS = 100 # number of steps\n", |
| 251 | + "display_int = 10 # show every n-th step\n", |
252 | 252 | "# --------------------------------\n", |
253 | 253 | "\n", |
| 254 | + "\n", |
254 | 255 | "def square_centre_limits(ax, pts, pad_frac: float = 0.05):\n", |
255 | 256 | " \"\"\"Make the axes square and centred on the data.\n", |
256 | 257 | "\n", |
|
261 | 262 | " \"\"\"\n", |
262 | 263 | " x, y = pts[:, 0], pts[:, 1]\n", |
263 | 264 | " x_mid, y_mid = (x.max() + x.min()) / 2, (y.max() + y.min()) / 2\n", |
264 | | - " half_range = max(x.max() - x.min(), y.max() - y.min()) / 2\n", |
265 | | - " half_range *= (1 + pad_frac) # add a small margin\n", |
| 265 | + " half_range = max(x.max() - x.min(), y.max() - y.min()) / 2\n", |
| 266 | + " half_range *= 1 + pad_frac # add a small margin\n", |
266 | 267 | " ax.set_xlim(x_mid - half_range, x_mid + half_range)\n", |
267 | 268 | " ax.set_ylim(y_mid - half_range, y_mid + half_range)\n", |
268 | 269 | "\n", |
| 270 | + "\n", |
269 | 271 | "# define schedule\n", |
270 | | - "inference_sched = LinearInferenceSchedule(\n", |
271 | | - " nsteps=_FLOW_STEPS,\n", |
272 | | - " direction=TimeDirection.UNIFIED\n", |
273 | | - ")\n", |
274 | | - "schedule = inference_sched.generate_schedule().to(DEVICE) # len = _FLOW_STEPS\n", |
275 | | - "dts = inference_sched.discretize().to(DEVICE) # len = _FLOW_STEPS\n", |
| 272 | + "inference_sched = LinearInferenceSchedule(nsteps=_FLOW_STEPS, direction=TimeDirection.UNIFIED)\n", |
| 273 | + "schedule = inference_sched.generate_schedule().to(DEVICE) # len = _FLOW_STEPS\n", |
| 274 | + "dts = inference_sched.discretize().to(DEVICE) # len = _FLOW_STEPS\n", |
276 | 275 | "\n", |
277 | 276 | "# always show t=0 and t=1\n", |
278 | 277 | "display_indices = sorted(set(range(0, _FLOW_STEPS + 1, display_int)) | {0, _FLOW_STEPS})\n", |
|
281 | 280 | "with torch.no_grad():\n", |
282 | 281 | " # start from the prior used in training\n", |
283 | 282 | " x = cfm.sample_prior(shape).to(DEVICE)\n", |
284 | | - " \n", |
| 283 | + "\n", |
285 | 284 | " fig, axes = plt.subplots(1, n_plots, figsize=(4 * n_plots, 4))\n", |
286 | 285 | " for ax in axes:\n", |
287 | | - " ax.set_aspect('equal', 'box')\n", |
288 | | - " ax.set_xticks([]); ax.set_yticks([])\n", |
| 286 | + " ax.set_aspect(\"equal\", \"box\")\n", |
| 287 | + " ax.set_xticks([])\n", |
| 288 | + " ax.set_yticks([])\n", |
289 | 289 | "\n", |
290 | 290 | " plot_idx = 0\n", |
291 | 291 | " axes[plot_idx].scatter(x[:, 0].cpu(), x[:, 1].cpu(), s=2)\n", |
292 | | - " axes[plot_idx].set_title('t = 0.00')\n", |
| 292 | + " axes[plot_idx].set_title(\"t = 0.00\")\n", |
293 | 293 | " square_centre_limits(axes[plot_idx], x.cpu())\n", |
294 | 294 | "\n", |
295 | 295 | " # sampling loop\n", |
296 | 296 | " for step, (dt, t) in enumerate(zip(dts, schedule)):\n", |
297 | 297 | " full_t = inference_sched.pad_time(batch_size, t, device=DEVICE)\n", |
298 | | - " v_t = model(torch.cat([x, full_t[:, None]], dim=-1))\n", |
299 | | - " x = cfm.step(v_t, x, dt, t=full_t)\n", |
| 298 | + " v_t = model(torch.cat([x, full_t[:, None]], dim=-1))\n", |
| 299 | + " x = cfm.step(v_t, x, dt, t=full_t)\n", |
300 | 300 | "\n", |
301 | 301 | " # time after the step (always exists, even at the very end)\n", |
302 | 302 | " t_next = (t + dt).item()\n", |
|
305 | 305 | " plot_idx += 1\n", |
306 | 306 | " ax = axes[plot_idx]\n", |
307 | 307 | " ax.scatter(x[:, 0].cpu(), x[:, 1].cpu(), s=2)\n", |
308 | | - " ax.set_title(f't = {t_next:.2f}')\n", |
| 308 | + " ax.set_title(f\"t = {t_next:.2f}\")\n", |
309 | 309 | " square_centre_limits(ax, x.cpu())\n", |
310 | 310 | "\n", |
311 | 311 | "plt.tight_layout(pad=0.8)\n", |
312 | | - "plt.show()\n" |
| 312 | + "plt.show()" |
313 | 313 | ] |
314 | 314 | }, |
315 | 315 | { |
|
325 | 325 | "metadata": {}, |
326 | 326 | "outputs": [], |
327 | 327 | "source": [ |
| 328 | + "import matplotlib.pyplot as plt\n", |
328 | 329 | "import torch\n", |
329 | 330 | "from torch import Tensor\n", |
330 | | - "import matplotlib.pyplot as plt\n", |
331 | | - "from bionemo.moco.schedules.inference_time_schedules import (\n", |
332 | | - " EntropicInferenceSchedule, TimeDirection\n", |
333 | | - ")" |
| 331 | + "\n", |
| 332 | + "from bionemo.moco.schedules.inference_time_schedules import EntropicInferenceSchedule, TimeDirection" |
334 | 333 | ] |
335 | 334 | }, |
336 | 335 | { |
|
367 | 366 | ], |
368 | 367 | "source": [ |
369 | 368 | "%%time\n", |
370 | | - "_FLOW_STEPS = 100\n", |
371 | | - "display_int = 10 # controls every \"n\" steps to display\n", |
| 369 | + "_FLOW_STEPS = 100\n", |
| 370 | + "display_int = 10 # controls every \"n\" steps to display\n", |
372 | 371 | "shape = (batch_size, dim)\n", |
373 | 372 | "\n", |
| 373 | + "\n", |
374 | 374 | "# Predictor function wrapper.\n", |
375 | 375 | "# The scheduler needs a function `model(t, x)` and this wrapper handles the formatting.\n", |
376 | 376 | "def predictor_fn(t: Tensor, x: Tensor) -> Tensor:\n", |
|
379 | 379 | " t = t.unsqueeze(-1)\n", |
380 | 380 | " if t.shape[0] != x.shape[0]:\n", |
381 | 381 | " t = t.expand(x.shape[0], -1)\n", |
382 | | - " \n", |
| 382 | + "\n", |
383 | 383 | " model_input = torch.cat([x, t], dim=-1)\n", |
384 | 384 | " return model(model_input)\n", |
385 | 385 | "\n", |
| 386 | + "\n", |
386 | 387 | "def x_0_sampler_fn(n_samples: int) -> Tensor:\n", |
387 | 388 | " return cfm.sample_prior((n_samples, dim))\n", |
388 | 389 | "\n", |
| 390 | + "\n", |
389 | 391 | "def x_1_sampler_fn(n_samples: int) -> Tensor:\n", |
390 | 392 | " return sample_moons(n_samples)\n", |
391 | 393 | "\n", |
| 394 | + "\n", |
392 | 395 | "inference_sched = EntropicInferenceSchedule(\n", |
393 | 396 | " predictor=predictor_fn,\n", |
394 | 397 | " x_0_sampler=x_0_sampler_fn,\n", |
395 | 398 | " x_1_sampler=x_1_sampler_fn,\n", |
396 | 399 | " nsteps=_FLOW_STEPS,\n", |
397 | | - " n_approx_entropy_points=30, # More points -> more accurate schedule, but slower to generate\n", |
| 400 | + " n_approx_entropy_points=30, # More points -> more accurate schedule, but slower to generate\n", |
398 | 401 | " batch_size=batch_size,\n", |
399 | 402 | " direction=TimeDirection.UNIFIED,\n", |
400 | 403 | " device=DEVICE,\n", |
401 | 404 | ")\n", |
402 | 405 | "print(\"Generating entropic schedule...\")\n", |
403 | | - "schedule = inference_sched.generate_schedule().to(DEVICE) \n", |
404 | | - "dts = inference_sched.discretize().to(DEVICE)\n", |
| 406 | + "schedule = inference_sched.generate_schedule().to(DEVICE)\n", |
| 407 | + "dts = inference_sched.discretize().to(DEVICE)\n", |
405 | 408 | "print(\"Schedule generated.\")\n", |
406 | 409 | "\n", |
407 | 410 | "display_indices = sorted(set(range(0, _FLOW_STEPS + 1, display_int)) | {0, _FLOW_STEPS})\n", |
408 | 411 | "n_plots = len(display_indices)\n", |
409 | 412 | "\n", |
410 | 413 | "with torch.no_grad():\n", |
411 | | - "\n", |
412 | 414 | " x = cfm.sample_prior((batch_size, dim)).to(DEVICE)\n", |
413 | 415 | "\n", |
414 | 416 | " fig, axes = plt.subplots(1, n_plots, figsize=(4 * n_plots, 4))\n", |
415 | 417 | " for ax in axes:\n", |
416 | | - " ax.set_aspect('equal', 'box')\n", |
417 | | - " ax.set_xticks([]); ax.set_yticks([])\n", |
| 418 | + " ax.set_aspect(\"equal\", \"box\")\n", |
| 419 | + " ax.set_xticks([])\n", |
| 420 | + " ax.set_yticks([])\n", |
418 | 421 | "\n", |
419 | 422 | " plot_idx = 0\n", |
420 | 423 | " axes[plot_idx].scatter(x[:, 0].cpu(), x[:, 1].cpu(), s=2)\n", |
421 | | - " axes[plot_idx].set_title('t = 0.00')\n", |
422 | | - " square_centre_limits(axes[plot_idx], x.cpu()) \n", |
| 424 | + " axes[plot_idx].set_title(\"t = 0.00\")\n", |
| 425 | + " square_centre_limits(axes[plot_idx], x.cpu())\n", |
423 | 426 | "\n", |
424 | 427 | " # integration loop & viz\n", |
425 | 428 | " for step, (dt, t) in enumerate(zip(dts, schedule)):\n", |
426 | 429 | " full_t = inference_sched.pad_time(batch_size, t, device=DEVICE)\n", |
427 | | - " v_t = model(torch.cat([x, full_t[:, None]], dim=-1))\n", |
428 | | - " x = cfm.step(v_t, x, dt, t=full_t)\n", |
| 430 | + " v_t = model(torch.cat([x, full_t[:, None]], dim=-1))\n", |
| 431 | + " x = cfm.step(v_t, x, dt, t=full_t)\n", |
429 | 432 | "\n", |
430 | 433 | " t_next = (t + dt).item()\n", |
431 | 434 | "\n", |
432 | 435 | " if (step + 1) in display_indices:\n", |
433 | 436 | " plot_idx += 1\n", |
434 | 437 | " ax = axes[plot_idx]\n", |
435 | 438 | " ax.scatter(x[:, 0].cpu(), x[:, 1].cpu(), s=2)\n", |
436 | | - " ax.set_title(f't = {t_next:.2f}')\n", |
437 | | - " square_centre_limits(ax, x.cpu()) \n", |
| 439 | + " ax.set_title(f\"t = {t_next:.2f}\")\n", |
| 440 | + " square_centre_limits(ax, x.cpu())\n", |
438 | 441 | "\n", |
439 | 442 | "plt.tight_layout(pad=0.8)\n", |
440 | 443 | "plt.show()" |
|
0 commit comments