Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 61 additions & 35 deletions build_tools/math/generate_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,8 @@
mpmath_name="sqrt",
namespace="stablehlo",
passes="--stablehlo-complex-math-expander"),
dict(name="multiply", namespace="stablehlo", is_binary=True, is_full_expansion=True,
passes='--stablehlo-complex-math-expander="enable-full-expansion=true"'),
]


Expand Down Expand Up @@ -145,11 +147,15 @@ def main():

flush_subnormals = False
for op in operations:
is_full = op.get("is_full_expansion", False)
is_binary = op.get("is_binary", False)

opname = op["name"]
mpmath_opname = op.get("mpmath_name", opname)
namespace = op.get("namespace", "chlo")
size_re = size_im = op.get("size", default_size)
passes = op.get("passes", "--chlo-legalize-to-stablehlo")

for dtype in [np.complex64, np.complex128, np.float32, np.float64]:
params = fa.utils.function_validation_parameters(opname, dtype)
max_ulp_difference = op.get(
Expand All @@ -167,11 +173,6 @@ def main():
flush_subnormals=flush_subnormals,
)

fi = np.finfo(dtype)

float_dtype = to_float_dtype[dtype]
finfo = np.finfo(float_dtype)

if dtype in [np.complex64, np.complex128]:
samples = fa.utils.complex_samples(
size=(size_re, size_im),
Expand All @@ -185,57 +186,82 @@ def main():
include_subnormal=not flush_subnormals,
).flatten()

samples = np.concatenate((samples, fa.utils.extra_samples(opname, dtype)))

expected = getattr(nmp, mpmath_opname).call(samples,
enable_progressbar=True)
expected = np.array(expected, dtype)
if is_full:
# 1. Filter out non-finite inputs
samples = samples[np.isfinite(samples)]

if is_binary:
samples_lhs = samples
samples_rhs = np.roll(samples, 1)

with np.errstate(all='ignore'):
expected = getattr(np, opname)(samples_lhs, samples_rhs).astype(dtype)

# 2. Filter out pairs that result in non-finite outputs (Overflows)
finite_mask = np.isfinite(expected)
samples_lhs = samples_lhs[finite_mask]
samples_rhs = samples_rhs[finite_mask]
expected = expected[finite_mask]
else:
with np.errstate(all='ignore'):
expected = getattr(np, opname)(samples).astype(dtype)
finite_mask = np.isfinite(expected)
samples = samples[finite_mask]
expected = expected[finite_mask]
else:
# Accuracy check: Include Inf/NaN/Subnormals
samples = np.concatenate((samples, fa.utils.extra_samples(opname, dtype)))
expected = getattr(nmp, mpmath_opname).call(samples)
expected = np.array(expected, dtype)

module_name = f"{opname}_{dtype.__name__}"
m = SSA.make_module(module_name)

samples_func = m.make_function("samples", "", mlir_type(samples))
samples_func.assign(samples)
samples_func.return_last()
if is_binary and is_full:
m.make_function("samples_lhs", "", mlir_type(samples_lhs)).assign(samples_lhs)
m.blocks[-1].return_last()
m.make_function("samples_rhs", "", mlir_type(samples_rhs)).assign(samples_rhs)
m.blocks[-1].return_last()
else:
m.make_function("samples", "", mlir_type(samples)).assign(samples)
m.blocks[-1].return_last()

expected_func = m.make_function("expected", "", mlir_type(expected))
expected_func.assign(expected)
expected_func.return_last()

main_func = m.make_function("main", "", "", "public")

ref_samples = main_func.call("samples")
actual = main_func.composite(f"{namespace}.{opname}", ref_samples)
expected = main_func.call("expected")
if is_binary and is_full:
ref_lhs = main_func.call("samples_lhs")
ref_rhs = main_func.call("samples_rhs")
actual = main_func.composite(f"{namespace}.{opname}", ref_lhs, ref_rhs)
else:
ref_samples = main_func.call("samples")
actual = main_func.composite(f"{namespace}.{opname}", ref_samples)

expected_val = main_func.call("expected")
main_func.void_call(
"check.expect_close",
actual,
expected,
expected_val,
f"max_ulp_difference = {max_ulp_difference}",
atypes=", ".join(map(main_func.get_ref_type, [actual, expected])),
atypes=", ".join(map(main_func.get_ref_type, [actual, expected_val])),
)
main_func.void_call("func.return")

source = str(m).rstrip() + "\n"
fname = os.path.join(target_dir, f"{module_name}.mlir")
if os.path.isfile(fname):
f = open(fname, "r")
content = f.read()
f.close()
if content.endswith(source):
print(f"{fname} is up-to-date.")
continue

f = open(fname, "w")
f.write(
f"// RUN: stablehlo-opt {passes} %s |"
" stablehlo-translate --interpret\n"
)
f.write(
"// This file is generated, see build_tools/math/README.md for more"
" information.\n")
f.write(source)
f.close()
with open(fname, "r") as f:
if f.read().endswith(source):
print(f"{fname} is up-to-date.")
continue

with open(fname, "w") as f:
f.write(f"// RUN: stablehlo-opt {passes} %s | stablehlo-translate --interpret\n")
f.write("// This file is generated, see build_tools/math/README.md for more information.\n")
f.write(source)
print(f"Created {fname}")

# Testing ULP difference
Expand Down
2 changes: 1 addition & 1 deletion stablehlo/tests/chlo/chlo_legalize_to_stablehlo.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -5273,4 +5273,4 @@ func.func @scan(%arg0: tensor<2xi32>, %arg1: tensor<i32>) -> (tensor<2xi32>, ten
stablehlo.return %1, %1 : tensor<i32>, tensor<i32>
} : (tensor<2xi32>, tensor<i32>) -> (tensor<2xi32>, tensor<i32>)
func.return %0#0, %0#1 : tensor<2xi32>, tensor<i32>
}
}
24 changes: 24 additions & 0 deletions stablehlo/tests/math/multiply_complex128.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
// RUN: stablehlo-opt --stablehlo-complex-math-expander="enable-full-expansion=true" %s | stablehlo-translate --interpret
// This file is generated, see build_tools/math/README.md for more information.
module @multiply_complex128 {
func.func private @samples_lhs() -> tensor<42xcomplex<f64>> {
%0 = stablehlo.constant dense<"0x000000000000FC9F000000000000F8BF0100000000000080000000000000F8BF0000000000000000000000000000F8BF0100000000000000000000000000F8BF000000000000FC1F000000000000F8BF000000000000F83F000000000000F8BF000000000000FC9F000000000000FC9F0100000000000080000000000000FC9F0000000000000000000000000000FC9F0100000000000000000000000000FC9F000000000000FC1F000000000000FC9F000000000000F83F000000000000FC9F000000000000FC9F0100000000000080010000000000008001000000000000800000000000000000010000000000008001000000000000000100000000000080000000000000FC1F0100000000000080000000000000F83F0100000000000080000000000000FC9F0000000000000000010000000000008000000000000000000000000000000000000000000000000001000000000000000000000000000000000000000000FC1F0000000000000000000000000000F83F0000000000000000000000000000FC9F0100000000000000010000000000008001000000000000000000000000000000010000000000000001000000000000000100000000000000000000000000FC1F0100000000000000000000000000F83F0100000000000000000000000000FC9F000000000000FC1F0100000000000080000000000000FC1F0000000000000000000000000000FC1F0100000000000000000000000000FC1F000000000000FC1F000000000000FC1F000000000000F83F000000000000FC1F000000000000FC9F000000000000F83F0100000000000080000000000000F83F0000000000000000000000000000F83F0100000000000000000000000000F83F000000000000FC1F000000000000F83F000000000000F83F000000000000F83F"> : tensor<42xcomplex<f64>>
return %0 : tensor<42xcomplex<f64>>
}
func.func private @samples_rhs() -> tensor<42xcomplex<f64>> {
%0 = stablehlo.constant dense<"0x000000000000F8BF000000000000F8BF000000000000FC9F000000000000F8BF0100000000000080000000000000F8BF0000000000000000000000000000F8BF0100000000000000000000000000F8BF000000000000FC1F000000000000F8BF000000000000F8BF000000000000FC9F000000000000FC9F000000000000FC9F0100000000000080000000000000FC9F0000000000000000000000000000FC9F0100000000000000000000000000FC9F000000000000FC1F000000000000FC9F000000000000F8BF0100000000000080000000000000FC9F0100000000000080010000000000008001000000000000800000000000000000010000000000008001000000000000000100000000000080000000000000FC1F0100000000000080000000000000F8BF0000000000000000000000000000FC9F0000000000000000010000000000008000000000000000000000000000000000000000000000000001000000000000000000000000000000000000000000FC1F0000000000000000000000000000F8BF0100000000000000000000000000FC9F0100000000000000010000000000008001000000000000000000000000000000010000000000000001000000000000000100000000000000000000000000FC1F0100000000000000000000000000F8BF000000000000FC1F000000000000FC9F000000000000FC1F0100000000000080000000000000FC1F0000000000000000000000000000FC1F0100000000000000000000000000FC1F000000000000FC1F000000000000FC1F000000000000F8BF000000000000F83F000000000000FC9F000000000000F83F0100000000000080000000000000F83F0000000000000000000000000000F83F0100000000000000000000000000F83F000000000000FC1F000000000000F83F"> : tensor<42xcomplex<f64>>
return %0 : tensor<42xcomplex<f64>>
}
func.func private @expected() -> tensor<42xcomplex<f64>> {
%0 = stablehlo.constant dense<"0x00000000000002C0000000000000024000000000000002C0000000000000052000000000000002C0020000000000000000000000000002C0020000000000008000000000000002C000000000000005A000000000000002C000000000000002C0000000000000052000000000000005200000000000400C800000000000400C000000000000400C8000000000000000000000000000400C8000000000000000800000000000400C800000000000400C80000000000000052000000000000005A0000000000000052002000000000000000000000000000000000000000000000000000000000000800000000000000000000000000000000000000000000000800000000000000000000000000000008000000000000005200200000000000080000000000000052000000000000000800000000000000000000000000000008000000000000000800000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000005200000000000000000000000000000052002000000000000800000000000000000000000000000008000000000000000800000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000005200200000000000000000000000000052000000000000005A00000000000400C800000000000400C800000000000400C8000000000000000000000000000400C8000000000000000000000000000400C800000000000400C000000000000000520000000000000052000000000000002C000000000000002C000000000000002C000000000000005A000000000000002C0020000000000008000000000000002C0020000000000000000000000000002C0000000000000052000000000000002C00000000000000240"> : tensor<42xcomplex<f64>>
return %0 : tensor<42xcomplex<f64>>
}
func.func public @main() {
%0 = call @samples_lhs() : () -> tensor<42xcomplex<f64>>
%1 = call @samples_rhs() : () -> tensor<42xcomplex<f64>>
%2 = "stablehlo.multiply"(%0, %1) : (tensor<42xcomplex<f64>>, tensor<42xcomplex<f64>>) -> tensor<42xcomplex<f64>>
%3 = call @expected() : () -> tensor<42xcomplex<f64>>
check.expect_close %2, %3, max_ulp_difference = 3 : tensor<42xcomplex<f64>>, tensor<42xcomplex<f64>>
func.return
}
}
24 changes: 24 additions & 0 deletions stablehlo/tests/math/multiply_complex64.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
// RUN: stablehlo-opt --stablehlo-complex-math-expander="enable-full-expansion=true" %s | stablehlo-translate --interpret
// This file is generated, see build_tools/math/README.md for more information.
module @multiply_complex64 {
func.func private @samples_lhs() -> tensor<42xcomplex<f32>> {
%0 = stablehlo.constant dense<"0x0000E09F0000C0BF010000800000C0BF000000000000C0BF010000000000C0BF0000E01F0000C0BF0000C03F0000C0BF0000E09F0000E09F010000800000E09F000000000000E09F010000000000E09F0000E01F0000E09F0000C03F0000E09F0000E09F010000800100008001000080000000000100008001000000010000800000E01F010000800000C03F010000800000E09F000000000100008000000000000000000000000001000000000000000000E01F000000000000C03F000000000000E09F010000000100008001000000000000000100000001000000010000000000E01F010000000000C03F010000000000E09F0000E01F010000800000E01F000000000000E01F010000000000E01F0000E01F0000E01F0000C03F0000E01F0000E09F0000C03F010000800000C03F000000000000C03F010000000000C03F0000E01F0000C03F0000C03F0000C03F"> : tensor<42xcomplex<f32>>
return %0 : tensor<42xcomplex<f32>>
}
func.func private @samples_rhs() -> tensor<42xcomplex<f32>> {
%0 = stablehlo.constant dense<"0x0000C0BF0000C0BF0000E09F0000C0BF010000800000C0BF000000000000C0BF010000000000C0BF0000E01F0000C0BF0000C0BF0000E09F0000E09F0000E09F010000800000E09F000000000000E09F010000000000E09F0000E01F0000E09F0000C0BF010000800000E09F010000800100008001000080000000000100008001000000010000800000E01F010000800000C0BF000000000000E09F000000000100008000000000000000000000000001000000000000000000E01F000000000000C0BF010000000000E09F010000000100008001000000000000000100000001000000010000000000E01F010000000000C0BF0000E01F0000E09F0000E01F010000800000E01F000000000000E01F010000000000E01F0000E01F0000E01F0000C0BF0000C03F0000E09F0000C03F010000800000C03F000000000000C03F010000000000C03F0000E01F0000C03F"> : tensor<42xcomplex<f32>>
return %0 : tensor<42xcomplex<f32>>
}
func.func private @expected() -> tensor<42xcomplex<f32>> {
%0 = stablehlo.constant dense<"0x000010C000001040000010C000002820000010C002000000000010C002000080000010C0000028A0000010C0000010C00000282000002820000062800000620000006280000000000000628000000080000062800000628000002820000028A000002820020000000000000000000000000000800000000000000000000000800000000000000080000028200200008000002820000000800000000000000080000000800000000000000000000000000000000000000000000028200000000000002820020000800000000000000080000000800000000000000000000000000000000000000000000028200200000000002820000028A000006280000062800000628000000000000062800000000000006280000062000000282000002820000010C0000010C0000010C0000028A0000010C002000080000010C002000000000010C000002820000010C000001040"> : tensor<42xcomplex<f32>>
return %0 : tensor<42xcomplex<f32>>
}
func.func public @main() {
%0 = call @samples_lhs() : () -> tensor<42xcomplex<f32>>
%1 = call @samples_rhs() : () -> tensor<42xcomplex<f32>>
%2 = "stablehlo.multiply"(%0, %1) : (tensor<42xcomplex<f32>>, tensor<42xcomplex<f32>>) -> tensor<42xcomplex<f32>>
%3 = call @expected() : () -> tensor<42xcomplex<f32>>
check.expect_close %2, %3, max_ulp_difference = 3 : tensor<42xcomplex<f32>>, tensor<42xcomplex<f32>>
func.return
}
}
24 changes: 24 additions & 0 deletions stablehlo/tests/math/multiply_float32.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
// RUN: stablehlo-opt --stablehlo-complex-math-expander="enable-full-expansion=true" %s | stablehlo-translate --interpret
// This file is generated, see build_tools/math/README.md for more information.
module @multiply_float32 {
func.func private @samples_lhs() -> tensor<124xf32> {
%0 = stablehlo.constant dense<"0x44ED4ADDC7E0BCDB4AD42EDACDC7A0D850BB12D7D3AE84D557A2F6D3DA9568D25D89DAD0E07C4CCF6370BECDE66330CC6957A2CAED4A14C9703E86C7F331F8C576256AC4F918DCC27C0C4EC10000C0BF83F331BE06E7A3BC89DA15BB0CCE87B98FC1F9B712B56BB696A8DDB4199C4FB39C8FC1B11F8333B0A276A5AE256A17ADA85D89AB2C51FBA9AF446DA83238DFA6B52B51A5381FC3A3BB1235A23E06A7A0C2F9189F45ED8A9DC8E0FC9B4BD46E9ACEC7E09851BB5297D4AEC49558A23694DB95A8925E891A91E17C8C8F6470FE8DE763708C6A57E28AEE4A5489713EC687F43138867725AA84FA181C837D0C8E810100008000000000010000007D0C8E01FA181C037725AA04F4313806713EC607EE4A54096A57E20AE763700C6470FE0DE17C8C0F5E891A11DB95A81258A23614D4AEC41551BB5217CEC7E0184BD46E1AC8E0FC1B45ED8A1DC2F9181F3E06A720BB123522381FC323B52B51253238DF26AF446D282C51FB29A85D892B256A172DA276A52E1F8333309C8FC131199C4F3396A8DD3412B56B368FC1F9370CCE873989DA153B06E7A33C83F3313E0000C03F7C0C4E41F918DC4276256A44F331F845703E8647ED4A14496957A24AE663304C6370BE4DE07C4C4F5D89DA50DA95685257A2F653D3AE845550BB1257CDC7A0584AD42E5AC7E0BC5B44ED4A5DC1F9D85E"> : tensor<124xf32>
return %0 : tensor<124xf32>
}
func.func private @samples_rhs() -> tensor<124xf32> {
%0 = stablehlo.constant dense<"0xC1F9D8DE44ED4ADDC7E0BCDB4AD42EDACDC7A0D850BB12D7D3AE84D557A2F6D3DA9568D25D89DAD0E07C4CCF6370BECDE66330CC6957A2CAED4A14C9703E86C7F331F8C576256AC4F918DCC27C0C4EC10000C0BF83F331BE06E7A3BC89DA15BB0CCE87B98FC1F9B712B56BB696A8DDB4199C4FB39C8FC1B11F8333B0A276A5AE256A17ADA85D89AB2C51FBA9AF446DA83238DFA6B52B51A5381FC3A3BB1235A23E06A7A0C2F9189F45ED8A9DC8E0FC9B4BD46E9ACEC7E09851BB5297D4AEC49558A23694DB95A8925E891A91E17C8C8F6470FE8DE763708C6A57E28AEE4A5489713EC687F43138867725AA84FA181C837D0C8E810100008000000000010000007D0C8E01FA181C037725AA04F4313806713EC607EE4A54096A57E20AE763700C6470FE0DE17C8C0F5E891A11DB95A81258A23614D4AEC41551BB5217CEC7E0184BD46E1AC8E0FC1B45ED8A1DC2F9181F3E06A720BB123522381FC323B52B51253238DF26AF446D282C51FB29A85D892B256A172DA276A52E1F8333309C8FC131199C4F3396A8DD3412B56B368FC1F9370CCE873989DA153B06E7A33C83F3313E0000C03F7C0C4E41F918DC4276256A44F331F845703E8647ED4A14496957A24AE663304C6370BE4DE07C4C4F5D89DA50DA95685257A2F653D3AE845550BB1257CDC7A0584AD42E5AC7E0BC5B44ED4A5D"> : tensor<124xf32>
return %0 : tensor<124xf32>
}
func.func private @expected() -> tensor<124xf32> {
%0 = stablehlo.constant dense<"0x2BFEAB7C6BB8957968FD8076439A5B732E4F38709319186D45A8FF697D13E066688CC6631090AE60741E985D9537835AE4B65F5719143C54C6861B51AE26024EFA01E34A094FC947D426B1445D899A41A276853E47DD633BC2E23F38B6FD1E35187E043256F5E52E8916CC2B7AC2B32826F99C258EBA8722660D681F29BB431C657E221962DA061693EDE812EAE2CE0FFD62B60CCD6D9F0959038A0645476C03A8CE6300124C01004A0400000F0000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000800000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000F0000004A040000124C0100A8CE630045476C0359038A06CD6D9F09FD62B60CEAE2CE0F93EDE81262DA0616657E221929BB431C660D681F8EBA872226F99C257AC2B3288916CC2B56F5E52E187E0432B6FD1E35C2E23F3847DD633BA276853E5D899A41D426B144094FC947FA01E34AAE26024EC6861B5119143C54E4B65F579537835A741E985D1090AE60688CC6637D13E06645A8FF699319186D2E4F3870439A5B7368FD80766BB895792BFEAB7C"> : tensor<124xf32>
return %0 : tensor<124xf32>
}
func.func public @main() {
%0 = call @samples_lhs() : () -> tensor<124xf32>
%1 = call @samples_rhs() : () -> tensor<124xf32>
%2 = "stablehlo.multiply"(%0, %1) : (tensor<124xf32>, tensor<124xf32>) -> tensor<124xf32>
%3 = call @expected() : () -> tensor<124xf32>
check.expect_close %2, %3, max_ulp_difference = 3 : tensor<124xf32>, tensor<124xf32>
func.return
}
}
Loading