Skip to content

Commit 1eb4946

Browse files
committed
Add brain image
1 parent d3126b5 commit 1eb4946

File tree

2 files changed

+18
-6
lines changed

2 files changed

+18
-6
lines changed

examples/image/train.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,12 @@
2929
# --------------------------------------------------------------------------------------
3030

3131
parser = argparse.ArgumentParser()
32-
parser.add_argument("--im", type=str, default="tree", choices=["shepp", "leaf", "tree"]) # [to do] add other distributions
32+
parser.add_argument(
33+
"--im",
34+
type=str,
35+
default="tree",
36+
choices=["shepp", "leaf", "tree", "brain"]
37+
)
3338
parser.add_argument("--im-blur", type=float, default=0.0)
3439
parser.add_argument("--im-pad", type=int, default=0)
3540
parser.add_argument("--im-res", type=int, default=256)
@@ -409,4 +414,4 @@ def evaluate_model(model: ment.MENT) -> dict:
409414
ax.set_xticks([])
410415
ax.set_yticks([])
411416
plt.savefig(os.path.join(output_dir, "fig_compare_all.png"))
412-
plt.close()
417+
plt.close()

examples/image/utils.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,19 @@ def get_grid_points(coords: list[np.ndarray]) -> np.ndarray:
77
return np.vstack([c.ravel() for c in np.meshgrid(*coords, indexing="ij")]).T
88

99

10-
def gen_image(key: str, res: int, blur: float = 0.0, pad: int = 0) -> None:
10+
def gen_image(key: str, res: int = None, blur: float = 0.0, pad: int = 0) -> None:
1111
images = None
1212

1313
if key == "shepp":
1414
image = skimage.data.shepp_logan_phantom()
1515
image = image[::-1, :]
16-
# image = image.T
16+
image = image.T
17+
18+
elif key == "brain":
19+
images = skimage.data.brain()
20+
image = images[len(images) // 2]
21+
image = image[::-1, :]
22+
image = image.T
1723

1824
else:
1925
filenames = {
@@ -40,8 +46,9 @@ def gen_image(key: str, res: int, blur: float = 0.0, pad: int = 0) -> None:
4046
new_image[pad:-pad, pad:-pad] = image.copy()
4147
image = new_image.copy()
4248

43-
shape = (res, res)
44-
image = skimage.transform.resize(image, shape, anti_aliasing=True)
49+
if res:
50+
shape = (res, res)
51+
image = skimage.transform.resize(image, shape, anti_aliasing=True)
4552

4653
if blur:
4754
image = skimage.filters.gaussian(image, blur)

0 commit comments

Comments
 (0)