
View-Interaction-With-Library-Node Bug in Python Frontend #2140

@ThrudPrimrose


In the doitgen SDFG:

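```python
import dace as dc
import numpy as np

# Symbolic array sizes
NR, NQ, NP = (dc.symbol(s, dtype=dc.int64) for s in ('NR', 'NQ', 'NP'))


@dc.program
def kernel(A: dc.float64[NR, NQ, NP], C4: dc.float64[NP, NP]):

    # Ideal, but not working because MatMul with more than 3 dimensions is unsupported
    # A[:] = np.reshape(np.reshape(A, (NR, NQ, 1, NP)) @ C4, (NR, NQ, NP))
    for r in range(NR):
        A[r, :, :] = np.reshape(np.reshape(A[r], (NQ, 1, NP)) @ C4, (NQ, NP))
```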
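To pin down the semantics the frontend has to preserve, here is a plain NumPy reference of the kernel. It is a sketch outside DaCe, and the sizes are arbitrary small test values. It checks the per-`r` loop against the "ideal" 4-D formulation (which NumPy itself accepts) and against the equivalent tensor contraction written with `np.einsum`.

```python
import numpy as np


def kernel_loop(A, C4):
    # Per-r formulation from the issue: (NQ, 1, NP) @ (NP, NP) for each r
    NR, NQ, NP = A.shape
    A = A.copy()
    for r in range(NR):
        A[r, :, :] = np.reshape(np.reshape(A[r], (NQ, 1, NP)) @ C4, (NQ, NP))
    return A


def kernel_ideal(A, C4):
    # "Ideal" 4-D formulation; NumPy broadcasts C4 over the (NR, NQ) batch dims
    NR, NQ, NP = A.shape
    return np.reshape(np.reshape(A, (NR, NQ, 1, NP)) @ C4, (NR, NQ, NP))


def kernel_einsum(A, C4):
    # The same computation as one tensor contraction over the last axis of A
    return np.einsum('rqp,ps->rqs', A, C4)


rng = np.random.default_rng(0)
A = rng.random((4, 5, 6))
C4 = rng.random((6, 6))
print(np.allclose(kernel_loop(A, C4), kernel_ideal(A, C4)))   # True
print(np.allclose(kernel_loop(A, C4), kernel_einsum(A, C4)))  # True
```

All three agree, so any of them is a valid target for the frontend lowering; the question is only which form DaCe's library nodes can express.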

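My understanding is that NumPy evaluates `(NQ, 1, NP) @ (NP, NP)` by broadcasting `C4` over the leading dimension, i.e., as a loop of GEMMs:

```python
Ar = np.reshape(A[r], (NQ, 1, NP))
for i in range(NQ):
    out[i] = Ar[i] @ C4  # (1, NP) @ (NP, NP) -> (1, NP); out: hypothetical (NQ, 1, NP) buffer
```

DaCe, however, generates a single GEMM call directly on the 3-dimensional slice.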
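The NumPy behavior described above can be checked directly. For a 3-D left operand and a 2-D right operand, `@` broadcasts `C4` over the leading axis, so the result is a stack of `NQ` independent `(1, NP) @ (NP, NP)` products, each of which is a proper 2-D GEMM. The sizes below are arbitrary; `Ar` stands in for `np.reshape(A[r], (NQ, 1, NP))`.

```python
import numpy as np

NQ, NP = 5, 6
rng = np.random.default_rng(0)
Ar = rng.random((NQ, 1, NP))   # plays the role of np.reshape(A[r], (NQ, 1, NP))
C4 = rng.random((NP, NP))

batched = Ar @ C4              # NumPy matmul with broadcasting: shape (NQ, 1, NP)

# Loop-over-GEMM: every operand is strictly 2-D
looped = np.empty_like(batched)
for i in range(NQ):
    looped[i] = Ar[i] @ C4     # (1, NP) @ (NP, NP) -> (1, NP)

print(batched.shape)                  # (5, 1, 6)
print(np.allclose(batched, looped))   # True
```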


Now, my question is how this should be fixed: should we generate a tensor contraction or a loop over GEMMs?
This is one of the cases where squeezing the unit dimension would fix it, but squeezing causes problems in other cases.
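For this particular shape the candidate fixes are numerically interchangeable. The sketch below (plain NumPy, arbitrary sizes, helper names are illustrative) compares squeezing the unit dimension, collapsing all leading dimensions into the row count of one GEMM, and a tensor contraction via `np.tensordot`. It also shows one way squeezing fails to generalize: once the middle dimension is larger than 1 there is nothing to squeeze, while collapsing still yields a single 2-D GEMM as long as the right operand is a plain matrix.

```python
import numpy as np

rng = np.random.default_rng(0)
C4 = rng.random((6, 6))


def squeeze_gemm(X, B):
    # Only valid when axis 1 has extent 1: (NQ, 1, K) -> (NQ, K), then restore the axis
    return (np.squeeze(X, axis=1) @ B)[:, None, :]


def collapse_gemm(X, B):
    # Fold all leading dims into rows: (..., M, K) -> (prod(...) * M, K), one 2-D GEMM
    rows = X.reshape(-1, X.shape[-1]) @ B
    return rows.reshape(*X.shape[:-1], B.shape[-1])


def contraction(X, B):
    # Tensor contraction over the last axis of X and the first axis of B
    return np.tensordot(X, B, axes=([X.ndim - 1], [0]))


X = rng.random((5, 1, 6))          # the shape from the issue
ref = X @ C4
for f in (squeeze_gemm, collapse_gemm, contraction):
    print(f.__name__, np.allclose(f(X, C4), ref))   # all True

Y = rng.random((5, 3, 6))          # middle dim != 1: np.squeeze(Y, axis=1) raises ValueError
print(np.allclose(collapse_gemm(Y, C4), Y @ C4))    # True
```

NumPy's `reshape` silently copies when the folded dimensions are not contiguous; in generated code, collapsing would additionally assume the leading dimensions are laid out contiguously in row-major order (as they are for the slice `A[r]` of a C-ordered array). When the right operand itself has batch dimensions, neither squeezing nor collapsing applies and a loop over GEMMs (or a batched GEMM) is needed.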

This fails, as expected, when the MatMul node is expanded and the resulting GEMM node is validated:

```
Traceback (most recent call last):
  File "/home/primrose/Work/npbench/doitgen_repr.py", line 23, in <module>
    sdfg.compile()
  File "/home/primrose/Work/dace/dace/sdfg/sdfg.py", line 2418, in compile
    program_objects = codegen.generate_code(sdfg, validate=validate)
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/primrose/Work/dace/dace/codegen/codegen.py", line 205, in generate_code
    sdfg.expand_library_nodes()
  File "/home/primrose/Work/dace/dace/sdfg/sdfg.py", line 2864, in expand_library_nodes
    impl_name = node.expand(self, state)
                ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/primrose/Work/dace/dace/sdfg/nodes.py", line 1584, in expand
    transformation.apply(actual_state, sdfg, **expansion_kwargs)
  File "/home/primrose/Work/dace/dace/transformation/transformation.py", line 707, in apply
    expansion = type(self).expansion(node, state, sdfg, *args, **kwargs)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/primrose/Work/dace/dace/libraries/blas/nodes/gemm.py", line 152, in expansion
    node.validate(sdfg, state)
  File "/home/primrose/Work/dace/dace/libraries/blas/nodes/gemm.py", line 1035, in validate
    raise ValueError("matrix-matrix product only supported on matrices")
ValueError: matrix-matrix product only supported on matrices
```
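The traceback shows the GEMM node's `validate` rejecting an operand that is not 2-D. One possible direction is to choose the lowering from the operand shapes in the frontend, before any GEMM node is created. The sketch below is a hypothetical shape-based planner: `plan_matmul`, `broadcast_batch`, and the plan labels are illustrative names, not DaCe API, and the sketch does not reflect how `gemm.py` performs its checks. It collapses leading dimensions when the right operand is a plain matrix and falls back to a loop over GEMMs when both operands carry batch dimensions.

```python
from math import prod


def broadcast_batch(x, y):
    # NumPy-style broadcasting of the leading (batch) dimensions
    n = max(len(x), len(y))
    x = (1,) * (n - len(x)) + tuple(x)
    y = (1,) * (n - len(y)) + tuple(y)
    out = []
    for p, q in zip(x, y):
        if p != q and p != 1 and q != 1:
            raise ValueError(f"batch dims {p} and {q} do not broadcast")
        out.append(q if p == 1 else p)
    return tuple(out)


def plan_matmul(a_shape, b_shape):
    """Hypothetical frontend decision for `a @ b`; returns (kind, detail)."""
    a_shape, b_shape = tuple(a_shape), tuple(b_shape)
    if len(a_shape) < 2 or len(b_shape) < 2:
        return ("vector_case", None)  # GEMV/dot cases are out of scope for this sketch
    if a_shape[-1] != b_shape[-2]:
        raise ValueError(f"inner dimensions differ: {a_shape} @ {b_shape}")
    if len(a_shape) == 2 and len(b_shape) == 2:
        return ("gemm", (a_shape, b_shape))
    if len(b_shape) == 2:
        # b is a plain matrix: fold every leading dim of a into the row count.
        # Assumes a's leading dims are contiguous in row-major order.
        rows = prod(a_shape[:-1])
        return ("gemm_collapsed", ((rows, a_shape[-1]), b_shape))
    # Both operands carry batch dims: one 2-D GEMM per broadcast batch index
    batch = broadcast_batch(a_shape[:-2], b_shape[:-2])
    return ("loop_over_gemm", (prod(batch), a_shape[-2:], b_shape[-2:]))


print(plan_matmul((4, 1, 6), (6, 6)))     # ('gemm_collapsed', ((4, 6), (6, 6)))
print(plan_matmul((4, 5, 1, 6), (6, 6)))  # ('gemm_collapsed', ((20, 6), (6, 6)))
print(plan_matmul((4, 3, 6), (4, 6, 6)))  # ('loop_over_gemm', (4, (3, 6), (6, 6)))
```

The second call corresponds to the "ideal" 4-D formulation `(NR, NQ, 1, NP) @ (NP, NP)`, which under these assumptions would also reduce to a single 2-D GEMM.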


Labels: bug (Something isn't working)
