Status: Open
Labels: err:performance, stat:awaiting openxla-eng
Description
First reported at jax-ml/jax#34751.
The StableHLO `convert` spec says:

> For boolean-to-any-supported-type conversions, the value false is converted to zero, and the value true is converted to one.
However, for non-standard boolean representations (storage bytes other than `0x00` and `0x01`), XLA does not conform to this when converting to an integer type:
```python
>>> import jax
>>> import numpy as np
>>> nonstandard_bool = jax.numpy.asarray(np.frombuffer(b'\x00\x01\xff', dtype=bool))
>>> print(nonstandard_bool)
[False  True  True]
>>> print(jax.lax.convert_element_type(nonstandard_bool, 'int32'))
[  0   1 255]
# StableHLO of the above:
>>> print(jax.jit(lambda x: jax.lax.convert_element_type(x, 'int32')).lower(nonstandard_bool).as_text())
module @jit__lambda attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<3xi1>) -> (tensor<3xi32> {jax.result_info = "result"}) {
    %0 = stablehlo.convert %arg0 : (tensor<3xi1>) -> tensor<3xi32>
    return %0 : tensor<3xi32>
  }
}
```
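The `255` in the output is consistent with the raw storage byte being widened to `int32` instead of first being normalized to 0 or 1. The NumPy sketch below contrasts the two interpretations on the same buffer. It assumes, as NumPy's `[False True True]` printout suggests, that any nonzero byte denotes true; the variable names are illustrative only and this is not XLA's implementation.

```python
import numpy as np

# Same non-standard boolean buffer as in the report: the third byte is 0xff.
raw = np.frombuffer(b"\x00\x01\xff", dtype=bool)

# Reinterpret the underlying storage bytes without converting them.
storage = raw.view(np.uint8)

# Widening the storage byte directly reproduces the value XLA returned above.
zero_extended = storage.astype(np.int32)

# Spec-conformant conversion (assumption: any nonzero byte means true):
# normalize to a 0/1 truth value first, then widen.
spec_conformant = (storage != 0).astype(np.int32)

print(zero_extended.tolist())    # [0, 1, 255]
print(spec_conformant.tolist())  # [0, 1, 1]
```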