Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 89 additions & 3 deletions Snakefile
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,44 @@ XDSL_LINALG_OPT_VARIANTS = [
"linalg_5_xdsl", # should run the same passes as linalg_xdsl but via a fully expanded pipeline instead of xdsl-opt test passes/mini-pipelines
]

# Max-bits-lost variants: the suffix of linalg_xdsl_b<N> encodes max_bits_lost = N,
# i.e. the number of low-order mantissa bits of the result that the polynomial
# is allowed to corrupt:
#   N = -1 -> correctly-rounded
#   N =  0 -> libm-grade
#   N >  0 -> relaxed accuracy bound
# One entry is generated for every N from -1 through 16, in ascending order.
XDSL_LINALG_MAX_BITS_LOST_VARIANTS = [
    f"linalg_xdsl_b{bits}" for bits in range(-1, 17)
]

# Subset of the max-bits-lost variants that is usable with f16.
# f16 has 11 mantissa bits, so a bN variant with N >= 11 would ask the
# polynomial to corrupt more bits than the type has; the Chebyshev fit then
# degenerates (degree 0 -> ZeroDivisionError in the pass). Only N <= 10 is kept.
XDSL_LINALG_MAX_BITS_LOST_VARIANTS_F16 = [
    variant
    for variant in XDSL_LINALG_MAX_BITS_LOST_VARIANTS
    # The text after the last "_b" is the signed max_bits_lost value.
    if int(variant.rpartition("_b")[2]) <= 10
]

# Every linalg-based xDSL variant: the plain lowering first, followed by the
# optimisation variants and then the max-bits-lost variants.
XDSL_LINALG_VARIANTS = (
    ["linalg_xdsl"]  # xDSL lowering from linalg on tensors
    + XDSL_LINALG_OPT_VARIANTS
    + XDSL_LINALG_MAX_BITS_LOST_VARIANTS
)

XDSL_VARIANTS = [
Expand Down Expand Up @@ -146,7 +181,7 @@ TESTSET_FAST = [
"exp_micro/{N}xf{precision}/{variant}",
N=range(16, 65, 16),
precision=[16, 32, 64],
variant=["baseline"],
variant=["baseline", "linalg_xdsl_b4"],
),
*expand(
"exp_macro/{N}xf{precision}/{variant}",
Expand Down Expand Up @@ -210,11 +245,23 @@ TESTSET_LOW_LEVEL_REPRESENTATION = [
TESTSET_EXP_MICRO = [
*expand(
"exp_micro/{N}xf{precision}/{variant}",
N=range(26, 129, 16),
N=range(16, 129, 16),
precision=[16, 32, 64],
variant=["baseline"],
),
*expand(
"exp_micro/{N}xf16/{variant}",
N=range(16, 129, 16),
variant=XDSL_LINALG_MAX_BITS_LOST_VARIANTS_F16,
),
*expand(
"exp_micro/{N}xf{precision}/{variant}",
N=range(16, 129, 16),
precision=[32, 64],
variant=XDSL_LINALG_MAX_BITS_LOST_VARIANTS,
),
]

TESTSET_EXP_MACRO = [
*expand(
"exp_macro/{N}xf{precision}/{variant}",
Expand Down Expand Up @@ -740,15 +787,54 @@ rule xdsl_kernel_generate_source:
"kernels/{kernel}/{shape}/{variant}.xdsl.mlir",
wildcard_constraints:
kernel="|".join(KERNEL_TEMPLATES),
variant="|".join(XDSL_LINALG_VARIANTS),
variant="|".join(v for v in XDSL_LINALG_VARIANTS if v not in XDSL_LINALG_MAX_BITS_LOST_VARIANTS),
params:
format_template="scripts/format.py",
xdsl_opt=config["xdsl-opt"],
mlir_opt=config["mlir-opt"],
mlir_opt_flags_linalg=config["mlir-opt-flags-linalg"],
shell:
"""
python3 {params.format_template} {input.template} {input.json} \
| {params.mlir_opt} {params.mlir_opt_flags_linalg} \
| sed 's/arith.maxf/arith.maximumf/g' \
| {params.xdsl_opt} -p arith-add-fastmath \
| sed 's/arith.maximumf/arith.maxf/g' > {output}
"""


def get_exp_attrs_from_variant(wildcards):
    """Build the `math.exp` attribute string for a `linalg_xdsl_b<N>` variant.

    The trailing `_b<N>` of `wildcards.variant` is parsed as a signed integer
    and emitted as `max_bits_lost`; the approximation interval is always
    reported as [-2.0, 0.0].

    Args:
        wildcards: object exposing a `variant` string attribute (this
            function is used as a Snakemake `params` callable).

    Returns:
        A comma-separated attribute list without the surrounding braces.

    Raises:
        ValueError: if the variant name does not end in `_b<N>`.
    """
    import re

    variant = wildcards.variant
    suffix = re.search(r"_b(-?\d+)$", variant)
    if suffix is None:
        raise ValueError(f"Cannot extract exp attributes from variant: {variant}")

    max_bits_lost = int(suffix.group(1))
    # NOTE(review): these bounds appear to mirror the input range clamp in
    # kernels/exp_micro/gendata.py -- confirm they are kept in sync.
    bounds = "lower_bound = -2.0 : f64, upper_bound = 0.0 : f64"
    return f"max_bits_lost = {max_bits_lost} : i64, {bounds}"


rule xdsl_kernel_generate_source_exp_attrs:
input:
json="kernels/{kernel}/{shape}/params.json",
template="kernels/{kernel}/linalg.mlir.template",
output:
"kernels/{kernel}/{shape}/{variant}.xdsl.mlir",
wildcard_constraints:
kernel="|".join(KERNEL_TEMPLATES),
variant="|".join(XDSL_LINALG_MAX_BITS_LOST_VARIANTS),
params:
format_template="scripts/format.py",
xdsl_opt=config["xdsl-opt"],
mlir_opt=config["mlir-opt"],
mlir_opt_flags_linalg=config["mlir-opt-flags-linalg"],
exp_attrs=get_exp_attrs_from_variant,
shell:
"""
python3 {params.format_template} {input.template} {input.json} \
| sed 's/math.exp %\\([^ ]*\\) :/math.exp %\\1 {{{params.exp_attrs}}} :/g' \
| {params.mlir_opt} {params.mlir_opt_flags_linalg} \
| sed 's/arith.maxf/arith.maximumf/g' \
| {params.xdsl_opt} -p arith-add-fastmath \
Expand Down
4 changes: 2 additions & 2 deletions kernels/exp_micro/gendata.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ def exp_data(
t = {16: np.float16, 32: np.float32, 64: np.float64}[precision]

# Clamp range to avoid overflow in exp
rmin = max(rmin, -10.0)
rmax = min(rmax, 10.0)
rmin = max(rmin, -2.0)
rmax = min(rmax, 0.0)

np.random.seed(0)
x = np.random.uniform(rmin, rmax, N).astype(t)
Expand Down
17 changes: 17 additions & 0 deletions kernels/exp_micro/linalg.mlir.template
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
// Element-wise exponential micro-kernel template.
// The N and precision placeholders (double-brace markers below) are filled in
// before the file reaches the MLIR tools -- presumably by scripts/format.py
// from the per-shape params.json; confirm against the Snakefile rules.
module {
// Reads the 1-D input buffer %arg0, writes exp of each element into %arg1
// and returns %arg1. Both buffers are marked llvm.noalias.
func.func public @exp_kernel(
%arg0: memref<{{N}}xf{{precision}}> {llvm.noalias},
%arg1: memref<{{N}}xf{{precision}}> {llvm.noalias}) -> memref<{{N}}xf{{precision}}> {
// Identity indexing maps on both operands with a single parallel iterator:
// output element d0 is computed from input element d0.
linalg.generic {
indexing_maps = [
affine_map<(d0) -> (d0)>,
affine_map<(d0) -> (d0)>],
iterator_types = ["parallel"]
} ins(%arg0 : memref<{{N}}xf{{precision}}>) outs(%arg1 : memref<{{N}}xf{{precision}}>) {
// %in is the current input element; the previous output value %out is unused.
^bb0(%in: f{{precision}}, %out: f{{precision}}):
%0 = math.exp %in : f{{precision}}
linalg.yield %0 : f{{precision}}
}
return %arg1 : memref<{{N}}xf{{precision}}>
}
}
7 changes: 3 additions & 4 deletions kernels/exp_micro/main.c
Original file line number Diff line number Diff line change
Expand Up @@ -43,14 +43,13 @@ int main() {
snrt_fpu_fence();
(void)snrt_mcycle();

// Correctness check
// Correctness check: we omit the tolerance test because some variants are allowed to produce largely or even fully wrong results
int nerr = 0;
for (int i = 0; i < N; i++) {
DTYPE d = FABSF(local_z[i] - G[i]);
DTYPE ref = FABSF(G[i]);
// Use relative error for large values, absolute for small
DTYPE tol = ref > (DTYPE)1.0 ? ref * (DTYPE)1E-2 : (DTYPE)1E-2;
nerr += !(d <= tol);
// DTYPE tol = ref > (DTYPE)1.0 ? ref * (DTYPE)1E-2 : (DTYPE)1E-2;
// nerr += !(d <= tol);
}
return nerr;
}
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,5 @@ dependencies = [
"seaborn",
"numpy",
"snakemake==8.14.0",
"xdsl @ git+https://github.com/xdslproject/xdsl.git@3c59f76216deed2674bdd2db64b572831b8b034a",
"xdsl @ git+https://github.com/szerdick/xdsl.git@b7fd02b6c010d1e601d0f66d8daa0cd85994ea37",
]
Loading
Loading