Skip to content

Commit d51bbf2

Browse files
authored
[core] Create unbatched image tensors when num=0 (#103)
1 parent 125b3a9 commit d51bbf2

File tree

2 files changed

+22
-11
lines changed

2 files changed

+22
-11
lines changed

src/deepali/core/flow.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -429,11 +429,10 @@ def warp_image(
429429
def zeros_flow(
430430
size: Optional[Union[int, Size, Grid]] = None,
431431
shape: Optional[Shape] = None,
432-
num: int = 1,
433-
named: bool = False,
432+
num: Optional[int] = None,
434433
dtype: Optional[DType] = None,
435434
device: Optional[Device] = None,
436435
) -> Tensor:
437436
r"""Create batch of flow fields filled with zeros for given image batch size or grid."""
438437
size = _image_size("zeros_flow", size, shape)
439-
return zeros_image(size, num=num, channels=len(size), named=named, dtype=dtype, device=device)
438+
return zeros_image(size, num=num, channels=len(size), dtype=dtype, device=device)

src/deepali/core/image.py

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1700,6 +1700,7 @@ def circle_image(
17001700
size: Spatial size in the order ``(X, Y)``.
17011701
shape: Spatial size in the order ``(Y, X)``.
17021702
num: Number ``N`` of images in batch.
1703+
If zero, return a single unbatched image tensor.
17031704
center: Coordinates of center pixel in the order ``(x, y)``.
17041705
radius: Radius of circle in pixel units.
17051706
sigma: Standard deviation of isotropic Gaussian blurring kernel in pixel units.
@@ -1712,7 +1713,7 @@ def circle_image(
17121713
device: Device on which to create image tensor.
17131714
17141715
Returns:
1715-
Image tensor of shape ``(N, 1, Y, X)``.
1716+
Image tensor of shape ``(N, 1, Y, X)`` or ``(1, Y, X)`` (``num=0``).
17161717
17171718
"""
17181719
size = _image_size("circle_image", size, shape, ndim=2)
@@ -1769,6 +1770,7 @@ def cshape_image(
17691770
size: Spatial size in the order ``(X, Y)``.
17701771
shape: Spatial size in the order ``(Y, X)``.
17711772
num: Number ``N`` of images in batch.
1773+
If zero, return a single unbatched image tensor.
17721774
center: Coordinates of center pixel in the order ``(y, x)``.
17731775
radius: Radius of circle in pixel units.
17741776
width: Difference between outer and inner circle radius.
@@ -1784,7 +1786,7 @@ def cshape_image(
17841786
device: Device on which to create image tensor.
17851787
17861788
Returns:
1787-
Image tensor of shape ``(N, 1, Y, X)``.
1789+
Image tensor of shape ``(N, 1, Y, X)`` or ``(1, Y, X)`` (``num=0``).
17881790
17891791
"""
17901792
size = _image_size("cshape_image", size, shape, ndim=2)
@@ -1832,17 +1834,21 @@ def empty_image(
18321834
size: Spatial size in the order ``(X, ...)``.
18331835
shape: Spatial size in the order ``(..., X)``.
18341836
num: Number of images in batch.
1837+
If zero, return a single unbatched image tensor.
18351838
channels: Number of channels per image.
18361839
dtype: Data type of image tensor.
18371840
device: Device on which to store image data.
18381841
18391842
Returns:
1840-
Uninitialized image batch tensor.
1843+
Image tensor of shape ``(N, 1, ..., X)`` or ``(1, ..., X)`` (``num=0``).
18411844
18421845
"""
18431846
size = _image_size("empty_image", size, shape)
18441847
shape = (num or 1, channels or 1) + tuple(reversed(size))
1845-
return torch.empty(shape, dtype=dtype, device=device)
1848+
data = torch.empty(shape, dtype=dtype, device=device)
1849+
if num == 0:
1850+
data = data.squeeze_(0)
1851+
return data
18461852

18471853

18481854
def grid_image(
@@ -1861,6 +1867,7 @@ def grid_image(
18611867
shape: Spatial size in the order ``(..., X)``.
18621868
num: Number of images in batch. When ``shape`` is not a ``Grid``, must
18631869
match the size of the first dimension in ``shape`` if not ``None``.
1870+
If zero, return a single unbatched image tensor.
18641871
stride: Spacing between grid lines. To draw in-plane grid lines on a
18651872
D-dimensional image where ``D>2``, specify a sequence of two stride
18661873
values, where the first stride applies to the last tensor dimension,
@@ -1870,7 +1877,7 @@ def grid_image(
18701877
device: Device on which to store image data.
18711878
18721879
Returns:
1873-
Image tensor of shape ``(N, 1, ..., X)``. The default number of channels is 1.
1880+
Image tensor of shape ``(N, 1, ..., X)`` or ``(1, ..., X)`` (``num=0``).
18741881
18751882
"""
18761883
size = _image_size("grid_image", size, shape)
@@ -1889,7 +1896,12 @@ def grid_image(
18891896
n = data.shape[dim]
18901897
index = torch.arange((n % step) // 2, n, step, dtype=torch.int64, device=data.device)
18911898
data.index_fill_(dim, index, 0 if inverted else 1)
1892-
return data.expand(num or 1, *data.shape[1:])
1899+
if num is not None:
1900+
if num == 0:
1901+
data = data.squeeze_(0)
1902+
elif num > 1:
1903+
data = data.expand((1,) + data.shape[1:])
1904+
return data
18931905

18941906

18951907
def ones_image(
@@ -1911,7 +1923,7 @@ def ones_image(
19111923
device: Device on which to store image data.
19121924
19131925
Returns:
1914-
Image batch tensor filled with ones.
1926+
Image tensor of shape ``(N, 1, ..., X)`` or ``(1, ..., X)`` (``num=0``) filled with ones.
19151927
19161928
"""
19171929
size = _image_size("ones_image", size, shape)
@@ -1938,7 +1950,7 @@ def zeros_image(
19381950
device: Device on which to store image data.
19391951
19401952
Returns:
1941-
Image batch tensor filled with zeros.
1953+
Image tensor of shape ``(N, 1, ..., X)`` or ``(1, ..., X)`` (``num=0``) filled with zeros.
19421954
19431955
"""
19441956
size = _image_size("zeros_image", size, shape)

0 commit comments

Comments
 (0)