@@ -184,9 +184,7 @@ def max(input, axis=None, return_indices=False, return_indices_tie_break_left=Tr
184184 else :
185185 if core .constexpr (input .dtype .primitive_bitwidth ) < core .constexpr (32 ):
186186 if core .constexpr (input .dtype .is_floating ()):
187- # Do not promote f16 to f32 as it has native hardware support
188- if not core .constexpr (input .dtype == core .float16 ):
189- input = input .to (core .float32 )
187+ input = input .to (core .float32 )
190188 else :
191189 assert input .dtype .is_int (), "Expecting input to be integer type"
192190 input = input .to (core .int32 )
@@ -243,11 +241,9 @@ def min(input, axis=None, return_indices=False, return_indices_tie_break_left=Tr
243241 else :
244242 return core ._reduce_with_indices (input , axis , _argmin_combine_tie_break_fast , keep_dims = keep_dims )
245243 else :
246- if core .constexpr (input .dtype .primitive_bitwidth ) < core . constexpr ( 32 ) :
244+ if core .constexpr (input .dtype .primitive_bitwidth ) < 32 :
247245 if core .constexpr (input .dtype .is_floating ()):
248- # Do not promote f16 to f32 as it has native hardware support
249- if not core .constexpr (input .dtype == core .float16 ):
250- input = input .to (core .float32 )
246+ input = input .to (core .float32 )
251247 else :
252248 assert input .dtype .is_int (), "Expecting input to be integer type"
253249 input = input .to (core .int32 )
0 commit comments