|
29 | 29 | # -------------------------------------------------------------------------------------- |
30 | 30 |
|
31 | 31 | parser = argparse.ArgumentParser() |
32 | | -parser.add_argument("--im", type=str, default="shepp", choices=["shepp", "leaf", "tree"]) |
| 32 | +parser.add_argument("--im", type=str, default="tree", choices=["shepp", "leaf", "tree"]) # [to do] add other distributions |
33 | 33 | parser.add_argument("--im-blur", type=float, default=0.0) |
34 | 34 | parser.add_argument("--im-pad", type=int, default=0) |
35 | 35 | parser.add_argument("--im-res", type=int, default=256) |
36 | | -parser.add_argument("--nmeas", type=int, default=50) |
| 36 | +parser.add_argument("--nmeas", type=int, default=25) |
37 | 37 | parser.add_argument("--angle-max", type=float, default=180.0) |
38 | 38 | parser.add_argument("--angle-min", type=float, default=0.0) |
39 | | -parser.add_argument("--iters", type=int, default=20) |
| 39 | +parser.add_argument("--iters", type=int, default=10) |
40 | 40 | parser.add_argument("--lr", type=float, default=0.25) |
41 | 41 | parser.add_argument("--prior-scale", type=float, default=10.0) |
42 | 42 | parser.add_argument("--int-loop", type=int, default=0) |
@@ -373,11 +373,20 @@ def evaluate_model(model: ment.MENT) -> dict: |
373 | 373 | results["ment"]["sinogram"] = sinogram_pred.copy() |
374 | 374 |
|
375 | 375 |
|
376 | | -# Compare |
| 376 | +# Compare |
| 377 | +scale = 1.0 |
| 378 | +for name in results: |
| 379 | + image = results[name]["image"] |
| 380 | + image = image / np.sum(image) |
| 381 | + results[name]["image"] = np.copy(image) |
| 382 | + scale = max(scale, np.max(image)) |
| 383 | +for name in results: |
| 384 | + results[name]["image"] /= scale |
| 385 | + |
377 | 386 | fig, axs = plt.subplots(ncols=4, figsize=(10, 2.5), sharex=True, sharey=True) |
378 | 387 | for ax, key in zip(axs, results): |
379 | 388 | image = results[key]["image"] |
380 | | - ax.pcolormesh(image.T) |
| 389 | + ax.pcolormesh(image.T, vmin=0.0, vmax=1.0) |
381 | 390 | ax.set_title(key.upper()) |
382 | 391 | plt.savefig(os.path.join(output_dir, "fig_compare_image.png")) |
383 | 392 | plt.close() |
|
0 commit comments