@@ -1700,6 +1700,7 @@ def circle_image(
17001700 size: Spatial size in the order ``(X, Y)``.
17011701 shape: Spatial size in the order ``(Y, X)``.
17021702 num: Number ``N`` of images in batch.
1703+ If zero, return a single unbatched image tensor.
17031704 center: Coordinates of center pixel in the order ``(x, y)``.
17041705 radius: Radius of circle in pixel units.
17051706 sigma: Standard deviation of isotropic Gaussian blurring kernel in pixel units.
@@ -1712,7 +1713,7 @@ def circle_image(
17121713 device: Device on which to create image tensor.
17131714
17141715 Returns:
1715- Image tensor of shape ``(N, 1, Y, X)``.
1716+ Image tensor of shape ``(N, 1, Y, X)`` or ``(1, Y, X)`` (``num=0``) .
17161717
17171718 """
17181719 size = _image_size ("circle_image" , size , shape , ndim = 2 )
@@ -1769,6 +1770,7 @@ def cshape_image(
17691770 size: Spatial size in the order ``(X, Y)``.
17701771 shape: Spatial size in the order ``(Y, X)``.
17711772 num: Number ``N`` of images in batch.
1773+ If zero, return a single unbatched image tensor.
17721774 center: Coordinates of center pixel in the order ``(y, x)``.
17731775 radius: Radius of circle in pixel units.
17741776 width: Difference between outer and inner circle radius.
@@ -1784,7 +1786,7 @@ def cshape_image(
17841786 device: Device on which to create image tensor.
17851787
17861788 Returns:
1787- Image tensor of shape ``(N, 1, Y, X)``.
1789+ Image tensor of shape ``(N, 1, Y, X)`` or ``(1, Y, X)`` (``num=0``) .
17881790
17891791 """
17901792 size = _image_size ("cshape_image" , size , shape , ndim = 2 )
@@ -1832,17 +1834,21 @@ def empty_image(
18321834 size: Spatial size in the order ``(X, ...)``.
18331835 shape: Spatial size in the order ``(..., X)``.
18341836 num: Number of images in batch.
1837+ If zero, return a single unbatched image tensor.
18351838 channels: Number of channels per image.
18361839 dtype: Data type of image tensor.
18371840 device: Device on which to store image data.
18381841
18391842 Returns:
1840- Uninitialized image batch tensor .
1843+ Image tensor of shape ``(N, 1, ..., X)`` or ``(1, ..., X)`` (``num=0``) .
18411844
18421845 """
18431846 size = _image_size ("empty_image" , size , shape )
18441847 shape = (num or 1 , channels or 1 ) + tuple (reversed (size ))
1845- return torch .empty (shape , dtype = dtype , device = device )
1848+ data = torch .empty (shape , dtype = dtype , device = device )
1849+ if num == 0 :
1850+ data = data .squeeze_ (0 )
1851+ return data
18461852
18471853
18481854def grid_image (
@@ -1861,6 +1867,7 @@ def grid_image(
18611867 shape: Spatial size in the order ``(..., X)``.
18621868 num: Number of images in batch. When ``shape`` is not a ``Grid``, must
18631869 match the size of the first dimension in ``shape`` if not ``None``.
1870+ If zero, return a single unbatched image tensor.
18641871 stride: Spacing between grid lines. To draw in-plane grid lines on a
18651872 D-dimensional image where ``D>2``, specify a sequence of two stride
18661873 values, where the first stride applies to the last tensor dimension,
@@ -1870,7 +1877,7 @@ def grid_image(
18701877 device: Device on which to store image data.
18711878
18721879 Returns:
1873- Image tensor of shape ``(N, 1, ..., X)``. The default number of channels is 1 .
1880+ Image tensor of shape ``(N, 1, ..., X)`` or ``(1, ..., X)`` (``num=0``) .
18741881
18751882 """
18761883 size = _image_size ("grid_image" , size , shape )
@@ -1889,7 +1896,12 @@ def grid_image(
18891896 n = data .shape [dim ]
18901897 index = torch .arange ((n % step ) // 2 , n , step , dtype = torch .int64 , device = data .device )
18911898 data .index_fill_ (dim , index , 0 if inverted else 1 )
1892- return data .expand (num or 1 , * data .shape [1 :])
1899+ if num is not None :
1900+ if num == 0 :
1901+ data = data .squeeze_ (0 )
1902+ elif num > 1 :
1903+ data = data .expand ((1 ,) + data .shape [1 :])
1904+ return data
18931905
18941906
18951907def ones_image (
@@ -1911,7 +1923,7 @@ def ones_image(
19111923 device: Device on which to store image data.
19121924
19131925 Returns:
1914- Image batch tensor filled with ones.
1926+ Image tensor of shape ``(N, 1, ..., X)`` or ``(1, ..., X)`` (``num=0``) filled with ones.
19151927
19161928 """
19171929 size = _image_size ("ones_image" , size , shape )
@@ -1938,7 +1950,7 @@ def zeros_image(
19381950 device: Device on which to store image data.
19391951
19401952 Returns:
1941- Image batch tensor filled with zeros.
1953+ Image tensor of shape ``(N, 1, ..., X)`` or ``(1, ..., X)`` (``num=0``) filled with zeros.
19421954
19431955 """
19441956 size = _image_size ("zeros_image" , size , shape )
0 commit comments