Open
Labels: bug (Something isn't working)
Description
In the doitgen SDFG:
```python
import dace as dc
import numpy as np

NR, NQ, NP = (dc.symbol(s, dtype=dc.int64) for s in ('NR', 'NQ', 'NP'))


@dc.program
def kernel(A: dc.float64[NR, NQ, NP], C4: dc.float64[NP, NP]):
    # Ideal, but not working because MatMul with dim > 3 is unsupported:
    # A[:] = np.reshape(np.reshape(A, (NR, NQ, 1, NP)) @ C4, (NR, NQ, NP))
    for r in range(NR):
        A[r, :, :] = np.reshape(np.reshape(A[r], (NQ, 1, NP)) @ C4, (NQ, NP))
```

We got this. What I understand is that NumPy runs a loop of GEMMs:
```
for i in range(NQ):
    (1, NP) @ (NP, NP)
```

whereas DaCe generates a single GEMM call directly on the 3-dimensional slice of shape `(NQ, 1, NP)`.
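A small NumPy check of the semantics described above (the sizes are arbitrary and chosen only for illustration): `np.matmul` treats the leading `NQ` axis of the `(NQ, 1, NP)` operand as a batch axis and broadcasts the 2D `C4` against it, which gives the same result as looping over `NQ` independent `(1, NP) @ (NP, NP)` GEMMs.

```python
import numpy as np

# Arbitrary small sizes for illustration.
NQ, NP = 4, 5
rng = np.random.default_rng(0)
A_r = rng.standard_normal((NQ, NP))   # stands in for one slice A[r]
C4 = rng.standard_normal((NP, NP))

# What the kernel writes: a 3D @ 2D matmul, broadcast over the leading axis.
batched = np.reshape(A_r, (NQ, 1, NP)) @ C4
print(batched.shape)  # (4, 1, 5)

# The equivalent loop of 2D GEMMs, one (1, NP) @ (NP, NP) product per i.
looped = np.empty((NQ, 1, NP))
for i in range(NQ):
    looped[i] = np.reshape(A_r[i], (1, NP)) @ C4

print(np.allclose(batched, looped))  # True
```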
Now, my question is: how should this be fixed? Should we generate a tensor contraction or a loop-over-GEMM?
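For comparison, a sketch of the tensor-contraction option in plain NumPy (this is not DaCe code, and the sizes are arbitrary): because `C4` is 2D, the product only contracts the last axis of `A` with the first axis of `C4`, so the whole "ideal" 4D expression from the commented-out line can be written as one `np.tensordot` or `np.einsum` call with no batch loop.

```python
import numpy as np

NR, NQ, NP = 3, 4, 5  # arbitrary sizes for illustration
rng = np.random.default_rng(1)
A = rng.standard_normal((NR, NQ, NP))
C4 = rng.standard_normal((NP, NP))

# Reference: the "ideal" 4D matmul from the commented-out line (NumPy accepts it).
ideal = np.reshape(np.reshape(A, (NR, NQ, 1, NP)) @ C4, (NR, NQ, NP))

# Tensor contraction: out[r, q, s] = sum_p A[r, q, p] * C4[p, s].
contracted = np.tensordot(A, C4, axes=([2], [0]))
via_einsum = np.einsum('rqp,ps->rqs', A, C4)

print(contracted.shape)  # (3, 4, 5)
print(np.allclose(ideal, contracted), np.allclose(ideal, via_einsum))  # True True
```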
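This is one of the cases where squeezing the singleton dimension would fix it: dropping the 1 turns the `(NQ, 1, NP) @ (NP, NP)` product into an ordinary 2D GEMM `(NQ, NP) @ (NP, NP)`. However, squeezing causes problems in other cases.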
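One way to generalize the squeeze, sketched in NumPy under two assumptions: the right operand is 2D, and the left operand can be reshaped without a copy (contiguous). Every leading axis of the left operand, including the singleton, is folded into the GEMM's row dimension, a single 2D GEMM is run, and the result is reshaped back. `fold_batch_into_gemm` is a hypothetical helper, not part of DaCe; it shows the shape bookkeeping only and does not address the other cases where squeezing is problematic.

```python
import numpy as np


def fold_batch_into_gemm(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Hypothetical lowering of a @ b for a.ndim >= 2 and a 2D b.

    All leading axes of `a` (batch axes plus its row axis) are flattened
    into one row dimension, so the product becomes a single 2D GEMM.
    """
    if b.ndim != 2:
        raise ValueError("folding only applies when the right operand is 2D")
    if a.shape[-1] != b.shape[0]:
        raise ValueError("inner dimensions do not match")
    lead_shape = a.shape[:-1]                    # e.g. (NQ, 1)
    a2d = np.reshape(a, (-1, a.shape[-1]))       # (NQ * 1, NP)
    c2d = a2d @ b                                # one plain 2D GEMM
    return np.reshape(c2d, lead_shape + (b.shape[1],))


NQ, NP = 4, 5
rng = np.random.default_rng(2)
a = rng.standard_normal((NQ, 1, NP))
C4 = rng.standard_normal((NP, NP))
out = fold_batch_into_gemm(a, C4)
print(out.shape, np.allclose(out, a @ C4))  # (4, 1, 5) True
```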
This fails, of course, when the MatMul node is expanded:
```
Traceback (most recent call last):
  File "/home/primrose/Work/npbench/doitgen_repr.py", line 23, in <module>
    sdfg.compile()
  File "/home/primrose/Work/dace/dace/sdfg/sdfg.py", line 2418, in compile
    program_objects = codegen.generate_code(sdfg, validate=validate)
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/primrose/Work/dace/dace/codegen/codegen.py", line 205, in generate_code
    sdfg.expand_library_nodes()
  File "/home/primrose/Work/dace/dace/sdfg/sdfg.py", line 2864, in expand_library_nodes
    impl_name = node.expand(self, state)
                ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/primrose/Work/dace/dace/sdfg/nodes.py", line 1584, in expand
    transformation.apply(actual_state, sdfg, **expansion_kwargs)
  File "/home/primrose/Work/dace/dace/transformation/transformation.py", line 707, in apply
    expansion = type(self).expansion(node, state, sdfg, *args, **kwargs)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/primrose/Work/dace/dace/libraries/blas/nodes/gemm.py", line 152, in expansion
    node.validate(sdfg, state)
  File "/home/primrose/Work/dace/dace/libraries/blas/nodes/gemm.py", line 1035, in validate
    raise ValueError("matrix-matrix product only supported on matrices")
ValueError: matrix-matrix product only supported on matrices
```
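A hypothetical sketch of how an expansion could pick a lowering from the operand ranks instead of handing an N-D operand straight to GEMM, which is what trips the validation error above. `choose_lowering` and its string labels are made-up names in plain Python; this is not DaCe's API or its actual expansion logic.

```python
def choose_lowering(a_shape: tuple, b_shape: tuple) -> str:
    """Hypothetical: pick a lowering for a @ b from operand shapes alone."""
    if len(a_shape) == 2 and len(b_shape) == 2:
        return "gemm"                    # both matrices: GEMM applies directly
    if len(a_shape) > 2 and len(b_shape) == 2:
        return "fold_batch_then_gemm"    # reshape to 2D, one GEMM, reshape back
    if len(a_shape) > 2 or len(b_shape) > 2:
        return "batched_gemm_loop"       # loop over broadcast batch indices
    return "other"                       # vector cases (GEMV, dot) not covered here


print(choose_lowering((4, 1, 5), (5, 5)))     # fold_batch_then_gemm
print(choose_lowering((4, 5), (5, 5)))        # gemm
print(choose_lowering((3, 4, 5), (3, 5, 6)))  # batched_gemm_loop
```

The doitgen slice `(NQ, 1, NP) @ (NP, NP)` falls into the second branch, so under these assumptions it would never reach the GEMM validator with a 3D operand.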