
ConvertElementType bool-to-int returns incorrect result for nonstandard booleans #37159

@jakevdp

Description


First reported at jax-ml/jax#34751.

The StableHLO spec for the convert op says:

For boolean-to-any-supported-type conversions, the value false is converted to zero, and the value true is converted to one.

However, for nonstandard boolean representations (storage bytes other than 0x00 and 0x01), XLA does not conform to this when converting to an integer type:

>>> import jax
>>> import numpy as np
>>> nonstandard_bool = jax.numpy.asarray(np.frombuffer(b'\x00\x01\xff', dtype=bool))
>>> print(nonstandard_bool)
[False  True  True]

>>> print(jax.lax.convert_element_type(nonstandard_bool, 'int32'))
[  0   1 255]

# StableHLO emitted for the above:
>>> print(jax.jit(lambda x: jax.lax.convert_element_type(x, 'int32')).lower(nonstandard_bool).as_text())
module @jit__lambda attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<3xi1>) -> (tensor<3xi32> {jax.result_info = "result"}) {
    %0 = stablehlo.convert %arg0 : (tensor<3xi1>) -> tensor<3xi32>
    return %0 : tensor<3xi32>
  }
}
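To make the expected behaviour concrete, here is a NumPy-only reference sketch (not XLA code). It assumes NumPy's one-byte bool storage and reads the spec rule quoted above as "any nonzero byte is true". The helper names canonicalize_bool and spec_bool_to_int32 are hypothetical. Zero-extending the raw storage byte reproduces the [0 1 255] result reported above, while the spec-conforming result is [0 1 1].

```python
import numpy as np


def canonicalize_bool(x: np.ndarray) -> np.ndarray:
    """Hypothetical helper: map every nonzero byte of a bool array to 0x01."""
    # Reinterpret the 1-byte bool storage as uint8 and compare against zero.
    # The comparison allocates a fresh bool array holding only 0x00 / 0x01 bytes.
    return x.view(np.uint8) != 0


def spec_bool_to_int32(x: np.ndarray) -> np.ndarray:
    """Hypothetical reference for the spec rule: false -> 0, true -> 1."""
    return canonicalize_bool(x).astype(np.int32)


nonstandard_bool = np.frombuffer(b'\x00\x01\xff', dtype=bool)

# Zero-extending the raw storage byte matches the output reported in the issue.
print(nonstandard_bool.view(np.uint8).astype(np.int32))  # [  0   1 255]

# Spec-conforming conversion: every true value becomes exactly 1.
print(spec_bool_to_int32(nonstandard_bool))  # [0 1 1]

# After canonicalization the underlying bytes are only 0x00 / 0x01.
print(canonicalize_bool(nonstandard_bool).view(np.uint8))  # [0 1 1]
```

Passing canonicalize_bool(x) to jax.numpy.asarray before calling jax.lax.convert_element_type may serve as a host-side workaround until the conversion is fixed; that workaround has not been verified against XLA here.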
