-
Notifications
You must be signed in to change notification settings - Fork 135
Open
Description
When I use reshape in python with cudnn to build a computational graph, it always throws an error, not sure if it's because of my usage method. The following is the specific code, mainly used in create_matmul_graph with reshape.
import cudnn
import torch
import sys
handle = cudnn.create_handle()
input_type = torch.float16
# input tensors
# a = torch.randn(16, 128, 256, dtype=input_type, device="cuda")
a = torch.randn(16, 32768, dtype=input_type, device="cuda")
b = torch.randn(16, 256, 2048, dtype=input_type, device="cuda")
# reference output
c_ref = torch.matmul(a.reshape(16, 128, 256), b)
# place holder for cudnn output
c = torch.randn_like(c_ref, device="cuda")
def matmul_cache_key(handle, a, b):
"""Custom key function for matmul"""
return (
tuple(a.shape),
tuple(b.shape),
tuple(a.stride()),
tuple(b.stride()),
a.dtype,
b.dtype,
)
@cudnn.jit(heur_modes=[cudnn.heur_mode.A, cudnn.heur_mode.B])
@cudnn.graph_cache(key_fn=matmul_cache_key)
def create_matmul_graph(handle, a, b):
with cudnn.graph(handle) as (g, _):
a_cudnn = g.tensor_like(a)
print(a_cudnn.get_dim(), a_cudnn.get_stride())
a_cudnn = g.reshape(a_cudnn, name="reshape")
print(a_cudnn.get_dim(), a_cudnn.get_stride())
a_cudnn.set_dim([16, 128, 256]).set_stride([32768, 256, 1])
print(a_cudnn.get_dim(), a_cudnn.get_stride())
b_cudnn = g.tensor_like(b)
c_cudnn = g.matmul(name="matmul", A=a_cudnn, B=b_cudnn)
c_cudnn.set_output(True).set_data_type(cudnn.data_type.HALF)
return g, [a_cudnn, b_cudnn, c_cudnn] # Return raw graph and tensors
g, uids = create_matmul_graph(handle, a, b)
a_uid, b_uid, out_uid = uids
variant_pack = {
a_uid: a,
b_uid: b,
out_uid: c,
}
workspace = torch.empty(g.get_workspace_size(), device="cuda", dtype=torch.uint8)
g.execute(variant_pack, workspace)
torch.cuda.synchronize()
torch.testing.assert_close(c, c_ref, rtol=5e-3, atol=5e-3)
The error message is as follows:
Traceback (most recent call last):
File "/opt/tiger/efficientvit/test.py", line 46, in <module>
g, uids = create_matmul_graph(handle, a, b)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/tiger/.local/lib/python3.11/site-packages/cudnn/graph.py", line 68, in wrapper
g.build(heur_modes) # Build the graph
^^^^^^^^^^^^^^^^^^^
cudnn._compiled_module.cudnnGraphNotSupportedError: No valid engine configs for Reshape_Matmul_
{"engineId":0,"smVersion":890,"knobChoices":{}}
{"engineId":7,"smVersion":890,"knobChoices":{"CUDNN_KNOB_TYPE_SPLIT_K_SLC":-1,"CUDNN_KNOB_TYPE_KERNEL_CFG":1}}
{"engineId":7,"smVersion":890,"knobChoices":{"CUDNN_KNOB_TYPE_SPLIT_K_SLC":-1,"CUDNN_KNOB_TYPE_KERNEL_CFG":21}}
{"engineId":7,"smVersion":890,"knobChoices":{"CUDNN_KNOB_TYPE_SPLIT_K_SLC":-1,"CUDNN_KNOB_TYPE_KERNEL_CFG":10}}
{"engineId":7,"smVersion":890,"knobChoices":{"CUDNN_KNOB_TYPE_SPLIT_K_SLC":-1,"CUDNN_KNOB_TYPE_KERNEL_CFG":32}}
{"engineId":7,"smVersion":890,"knobChoices":{"CUDNN_KNOB_TYPE_SPLIT_K_SLC":-1,"CUDNN_KNOB_TYPE_KERNEL_CFG":25}}
{"engineId":7,"smVersion":890,"knobChoices":{"CUDNN_KNOB_TYPE_SPLIT_K_SLC":-1,"CUDNN_KNOB_TYPE_KERNEL_CFG":22}}
{"engineId":7,"smVersion":890,"knobChoices":{"CUDNN_KNOB_TYPE_SPLIT_K_SLC":-1,"CUDNN_KNOB_TYPE_KERNEL_CFG":24}}
{"engineId":7,"smVersion":890,"knobChoices":{"CUDNN_KNOB_TYPE_SPLIT_K_SLC":-1,"CUDNN_KNOB_TYPE_KERNEL_CFG":16}}
{"engineId":7,"smVersion":890,"knobChoices":{"CUDNN_KNOB_TYPE_SPLIT_K_SLC":-1,"CUDNN_KNOB_TYPE_KERNEL_CFG":12}}
{"engineId":1,"smVersion":890,"knobChoices":{"CUDNN_KNOB_TYPE_SPLIT_K_SLC":-1,"CUDNN_KNOB_TYPE_KERNEL_CFG":27}}
{"engineId":1,"smVersion":890,"knobChoices":{"CUDNN_KNOB_TYPE_SPLIT_K_SLC":-1,"CUDNN_KNOB_TYPE_KERNEL_CFG":11}}
{"engineId":1,"smVersion":890,"knobChoices":{"CUDNN_KNOB_TYPE_SPLIT_K_SLC":-1,"CUDNN_KNOB_TYPE_KERNEL_CFG":10}}
Metadata
Metadata
Assignees
Labels
No labels