Skip to content

New Codegen: Boolean Arguments Passed to NestedSDFGs #2393

Description

@philip-paul-mueller

Description

There is a bug in the new GPU code generator that affects `if` blocks.

The affected code is generated by a where() expression:

    ret = where(c, a, b)

where a, b and c are all arrays.

The SDFG contains a Map with a nested SDFG, such as:

Image

The problem now is that this code is lowered as:

DACE_DFI void if_stmt_0_0_0_5(const float* __restrict__ __arg1, const float* __restrict__ __arg2, const bool * __restrict__ __cond, float&  __output)
{
    if (__cond)
    {
        __output = __arg1[0];
    } else
    {
        __output = __arg2[0];
    }
}

As you can see, `__cond` (the condition, i.e. `c[...]`) is passed as a pointer, but the generated condition is `if (__cond)`, which tests the pointer itself rather than the value it points to.
Thus the true branch is selected whenever `c` is allocated, which is almost always the case, and therefore the else branch is never taken.

In the old code generator, the signature of the function was different: instead of pointers, the arguments were passed as references.
So my guess is that the translation of the condition still assumes that it is passed as a variable rather than as a pointer.

See the reproducer:

import dace
from dace.sdfg import nodes as dace_nodes


def _make_nested_sdfg():
    """Build the nested SDFG that picks one of two scalars based on a flag.

    The SDFG declares four non-transient scalars: ``__cond`` (bool) and
    ``__arg1``, ``__arg2``, ``__output`` (all float64).  A conditional block
    holds two branches: the one guarded by the code ``__cond`` copies
    ``__arg1`` into ``__output``; the one registered with a ``None``
    condition copies ``__arg2`` into ``__output``.

    Returns:
        The validated nested SDFG.
    """
    nested = dace.SDFG("nested")

    nested.add_scalar("__cond", dtype=dace.bool_, transient=False)
    for scalar_name in ("__arg1", "__arg2", "__output"):
        nested.add_scalar(scalar_name, dtype=dace.float64, transient=False)

    cond_block = dace.sdfg.state.ConditionalBlock("if")
    nested.add_node(cond_block, ensure_unique_name=True)

    # (region label, state label, condition code or None) per branch, listed
    # in the order the branches are registered on the conditional block.
    branch_specs = (
        ("then_body", "true_branch", "__cond"),
        ("else_body", "false_branch", None),
    )
    branch_states = []
    for region_label, state_label, cond_code in branch_specs:
        region = dace.sdfg.state.ControlFlowRegion(region_label, sdfg=nested)
        branch_states.append(region.add_state(state_label, is_start_block=True))
        if cond_code is None:
            guard = None
        else:
            guard = dace.sdfg.state.CodeBlock(cond_code)
        cond_block.add_branch(guard, region)

    # Each branch state consists of a single copy ``<source> -> __output``.
    for branch_state, source in zip(branch_states, ("__arg1", "__arg2")):
        src_node = branch_state.add_access(source)
        dst_node = branch_state.add_access("__output")
        branch_state.add_nedge(
            src_node,
            dst_node,
            dace.Memlet(data=source, subset="0"),
        )

    nested.validate()

    return nested


def _make_sdfg():
    """Build the top-level ``where`` SDFG used by the reproducer.

    Four non-transient arrays of shape ``(10,)`` with ``GPU_Global`` storage
    are declared: ``a``, ``b`` and ``d`` as float64 and ``c`` as bool.  A map
    over ``__i`` in ``0:10`` wraps the nested SDFG from
    ``_make_nested_sdfg``; element ``__i`` of ``a``, ``b`` and ``c`` is
    routed to the connectors ``__arg1``, ``__arg2`` and ``__cond``, and the
    ``__output`` connector is written to ``d[__i]``.  GPU transformations are
    applied before the final validation.

    Returns:
        The validated SDFG.
    """
    outer = dace.SDFG("where")
    body = outer.add_state()

    for array_name in "abcd":
        # Only the condition array ``c`` is boolean.
        if array_name == "c":
            elem_type = dace.bool_
        else:
            elem_type = dace.float64
        outer.add_array(
            array_name,
            shape=(10,),
            dtype=elem_type,
            storage=dace.dtypes.StorageType.GPU_Global,
            transient=False,
        )

    node_a, node_b, node_c, node_d = [body.add_access(array_name) for array_name in "abcd"]
    map_entry, map_exit = body.add_map("map", ndrange={"__i": "0:10"})
    nested_node = body.add_nested_sdfg(
        sdfg=_make_nested_sdfg(),
        inputs={"__arg1", "__arg2", "__cond"},
        outputs={"__output"},
    )

    # Route every input array through the map entry into its connector.
    input_nodes = (node_a, node_b, node_c)
    input_connectors = ("__arg1", "__arg2", "__cond")
    for src, connector in zip(input_nodes, input_connectors):
        outer_memlet = dace.Memlet(data=src.data, subset="0:10")
        body.add_edge(src, None, map_entry, f"IN_{src.data}", outer_memlet)

        inner_memlet = dace.Memlet(data=src.data, subset="__i")
        body.add_edge(map_entry, f"OUT_{src.data}", nested_node, connector, inner_memlet)

        map_entry.add_scope_connectors(src.data)

    # Route the nested SDFG result through the map exit into ``d``.
    body.add_edge(nested_node, "__output", map_exit, "IN_d", dace.Memlet("d[__i]"))
    body.add_edge(map_exit, "OUT_d", node_d, None, dace.Memlet("d[0:10]"))
    map_exit.add_scope_connectors("d")

    outer.apply_gpu_transformations()

    outer.validate()

    return outer


def main():
    """Build the reproducer SDFG and compile it."""
    _make_sdfg().compile()

# Run the reproducer only when this file is executed as a script.
if __name__ == "__main__":
    main()

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Fields

    No fields configured for issues without a type.

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions