@@ -1469,10 +1469,22 @@ def _float_int_cast(
14691469 dst_element_type = ir .IntegerType (_element_type (dst_type ))
14701470 if dst_element_type .width == 1 :
14711471 return _not_equal (src , _full (src .type , 0 ), signed = signed )
1472- elif signed :
1473- return arith_dialect .fptosi (dst_type , src )
14741472 else :
1475- return arith_dialect .fptoui (dst_type , src )
1473+ # We clamp the float value to the min/max integer destination value
1474+ # in order to match JAX/XLA casting behavior. Note that this differs
1475+ # from numpy casting behavior.
1476+ if signed :
1477+ maxint = 2 ** (dst_element_type .width - 1 ) - 1
1478+ minint = - 2 ** (dst_element_type .width - 1 )
1479+ else :
1480+ maxint = 2 ** dst_element_type .width - 1
1481+ minint = 0
1482+ src = arith_dialect .minimumf (src , _full (src .type , maxint ))
1483+ src = arith_dialect .maximumf (src , _full (src .type , minint ))
1484+ if signed :
1485+ return arith_dialect .fptosi (dst_type , src )
1486+ else :
1487+ return arith_dialect .fptoui (dst_type , src )
14761488
14771489
14781490def _int_float_cast (
@@ -1499,10 +1511,12 @@ def _cast(
14991511 src ,
15001512 _dtype_to_ir_type (dst_type ),
15011513 signed = jnp .issubdtype (src_type , jnp .signedinteger ),
1514+ dst_signed = jnp .issubdtype (dst_type , jnp .signedinteger ),
15021515 )
15031516
15041517
1505- def _ir_cast (src : ir .Value , dst_type : ir .Type , * , signed : bool ) -> ir .Value :
1518+ def _ir_cast (src : ir .Value , dst_type : ir .Type , * ,
1519+ signed : bool , dst_signed : bool = False ) -> ir .Value :
15061520 if ir .RankedTensorType .isinstance (
15071521 src .type
15081522 ) and not ir .RankedTensorType .isinstance (dst_type ):
@@ -1527,7 +1541,8 @@ def _ir_cast(src: ir.Value, dst_type: ir.Type, *, signed: bool) -> ir.Value:
15271541 dst_element_type , ir .F32Type
15281542 ):
15291543 return _ir_cast (
1530- _ir_cast (src , ir .F32Type .get (), signed = False ), dst_type , signed = False
1544+ _ir_cast (src , ir .F32Type .get (), signed = False ),
1545+ dst_type , signed = False , dst_signed = dst_signed
15311546 )
15321547
15331548 if isinstance (src_element_type , ir .FloatType ) and isinstance (
@@ -1543,7 +1558,7 @@ def _ir_cast(src: ir.Value, dst_type: ir.Type, *, signed: bool) -> ir.Value:
15431558 if isinstance (src_element_type , ir .FloatType ) and isinstance (
15441559 dst_element_type , ir .IntegerType
15451560 ):
1546- return _float_int_cast (src , dst_type , signed = signed )
1561+ return _float_int_cast (src , dst_type , signed = dst_signed )
15471562 if isinstance (src_element_type , ir .IntegerType ) and isinstance (
15481563 dst_element_type , ir .FloatType
15491564 ):
0 commit comments