@@ -141,6 +141,30 @@ def __init__(self, **data):
141141 elif self .image_array is not None :
142142 if not isinstance (self .image_array , np .ndarray ):
143143 gs .raise_exception ("`image_array` needs to be an numpy array." )
144+ if self .image_array .dtype != np .uint8 :
145+ if self .image_array .dtype in (np .float32 , np .float64 ):
146+ if self .image_array .max () <= 1.0 :
147+ self .image_array = (self .image_array * 255.0 ).round ()
148+ self .image_array = np .clip (self .image_array , 0.0 , 255.0 ).astype (np .uint8 )
149+ elif self .image_array .dtype == np .bool_ :
150+ self .image_array = self .image_array .astype (np .uint8 ) * 255
151+
152+ elif np .issubdtype (self .image_array .dtype , np .integer ):
153+ self .image_array = np .clip (self .image_array , 0 , 255 ).astype (np .uint8 )
154+ else :
155+ gs .raise_exception (
156+ f"Unsupported image dtype { self .image_array .dtype } . Only uint8 or float32/64 are supported."
157+ )
158+ if self .image_array .ndim == 2 :
159+ self .image_array = np .stack ([self .image_array ] * 3 , axis = - 1 )
160+
161+ elif self .image_array .shape [2 ] == 1 :
162+ self .image_array = np .repeat (self .image_array , 3 , axis = 2 )
163+
164+ elif self .image_array .shape [2 ] == 2 :
165+ L = self .image_array [..., 0 ]
166+ A = self .image_array [..., 1 ]
167+ self .image_array = np .stack ([L , L , L , A ], axis = - 1 )
144168
145169 # calculate channel
146170 if self .image_array is None :
@@ -163,8 +187,6 @@ def __init__(self, **data):
163187 if self .encoding not in ["srgb" , "linear" ]:
164188 gs .raise_exception (f"Invalid image encoding: { self .encoding } ." )
165189
166- assert self .image_array is None or self .image_array .dtype == np .uint8
167-
168190 def check_dim (self , dim ):
169191 if self .image_array is not None :
170192 if self ._channel > dim :
0 commit comments