Skip to content

Commit b8dc1c5

Browse files
committed
Edit plotting
1 parent 1905f0b commit b8dc1c5

File tree

1 file changed

+14
-5
lines changed

1 file changed

+14
-5
lines changed

examples/image/train.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,14 +29,14 @@
2929
# --------------------------------------------------------------------------------------
3030

3131
parser = argparse.ArgumentParser()
32-
parser.add_argument("--im", type=str, default="shepp", choices=["shepp", "leaf", "tree"])
32+
parser.add_argument("--im", type=str, default="tree", choices=["shepp", "leaf", "tree"]) # [to do] add other distributions
3333
parser.add_argument("--im-blur", type=float, default=0.0)
3434
parser.add_argument("--im-pad", type=int, default=0)
3535
parser.add_argument("--im-res", type=int, default=256)
36-
parser.add_argument("--nmeas", type=int, default=50)
36+
parser.add_argument("--nmeas", type=int, default=25)
3737
parser.add_argument("--angle-max", type=float, default=180.0)
3838
parser.add_argument("--angle-min", type=float, default=0.0)
39-
parser.add_argument("--iters", type=int, default=20)
39+
parser.add_argument("--iters", type=int, default=10)
4040
parser.add_argument("--lr", type=float, default=0.25)
4141
parser.add_argument("--prior-scale", type=float, default=10.0)
4242
parser.add_argument("--int-loop", type=int, default=0)
@@ -373,11 +373,20 @@ def evaluate_model(model: ment.MENT) -> dict:
373373
results["ment"]["sinogram"] = sinogram_pred.copy()
374374

375375

376-
# Compare
376+
# Compare
377+
scale = 1.0
378+
for name in results:
379+
image = results[name]["image"]
380+
image = image / np.sum(image)
381+
results[name]["image"] = np.copy(image)
382+
scale = max(scale, np.max(image))
383+
for name in results:
384+
results[name]["image"] /= scale
385+
377386
fig, axs = plt.subplots(ncols=4, figsize=(10, 2.5), sharex=True, sharey=True)
378387
for ax, key in zip(axs, results):
379388
image = results[key]["image"]
380-
ax.pcolormesh(image.T)
389+
ax.pcolormesh(image.T, vmin=0.0, vmax=1.0)
381390
ax.set_title(key.upper())
382391
plt.savefig(os.path.join(output_dir, "fig_compare_image.png"))
383392
plt.close()

0 commit comments

Comments
 (0)