Skip to content

How to use graph.reshape() in Python #157

@imxKyrie

Description

@imxKyrie

When I use reshape in python with cudnn to build a computational graph, it always throws an error, not sure if it's because of my usage method. The following is the specific code, mainly used in create_matmul_graph with reshape.

import cudnn
import torch
import sys

handle = cudnn.create_handle()

input_type = torch.float16

# input tensors
# a = torch.randn(16, 128, 256, dtype=input_type, device="cuda")
a = torch.randn(16, 32768, dtype=input_type, device="cuda")
b = torch.randn(16, 256, 2048, dtype=input_type, device="cuda")

# reference output
c_ref = torch.matmul(a.reshape(16, 128, 256), b)

# place holder for cudnn output
c = torch.randn_like(c_ref, device="cuda")

def matmul_cache_key(handle, a, b):
    """Custom key function for matmul"""
    return (
        tuple(a.shape),
        tuple(b.shape),
        tuple(a.stride()),
        tuple(b.stride()),
        a.dtype,
        b.dtype,
    )
@cudnn.jit(heur_modes=[cudnn.heur_mode.A, cudnn.heur_mode.B])
@cudnn.graph_cache(key_fn=matmul_cache_key)
def create_matmul_graph(handle, a, b):
    with cudnn.graph(handle) as (g, _):
        a_cudnn = g.tensor_like(a)
        print(a_cudnn.get_dim(), a_cudnn.get_stride())
        a_cudnn = g.reshape(a_cudnn, name="reshape")
        print(a_cudnn.get_dim(), a_cudnn.get_stride())
        a_cudnn.set_dim([16, 128, 256]).set_stride([32768, 256, 1])
        print(a_cudnn.get_dim(), a_cudnn.get_stride())
        b_cudnn = g.tensor_like(b)
        c_cudnn = g.matmul(name="matmul", A=a_cudnn, B=b_cudnn)
        c_cudnn.set_output(True).set_data_type(cudnn.data_type.HALF)

    return g, [a_cudnn, b_cudnn, c_cudnn]  # Return raw graph and tensors

g, uids = create_matmul_graph(handle, a, b)

a_uid, b_uid, out_uid = uids

variant_pack = {
    a_uid: a,
    b_uid: b,
    out_uid: c,
}

workspace = torch.empty(g.get_workspace_size(), device="cuda", dtype=torch.uint8)
g.execute(variant_pack, workspace)
torch.cuda.synchronize()
torch.testing.assert_close(c, c_ref, rtol=5e-3, atol=5e-3)

The error message is as follows:

Traceback (most recent call last):
  File "/opt/tiger/efficientvit/test.py", line 46, in <module>
    g, uids = create_matmul_graph(handle, a, b)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tiger/.local/lib/python3.11/site-packages/cudnn/graph.py", line 68, in wrapper
    g.build(heur_modes)  # Build the graph
    ^^^^^^^^^^^^^^^^^^^
cudnn._compiled_module.cudnnGraphNotSupportedError: No valid engine configs for Reshape_Matmul_
{"engineId":0,"smVersion":890,"knobChoices":{}}

{"engineId":7,"smVersion":890,"knobChoices":{"CUDNN_KNOB_TYPE_SPLIT_K_SLC":-1,"CUDNN_KNOB_TYPE_KERNEL_CFG":1}}

{"engineId":7,"smVersion":890,"knobChoices":{"CUDNN_KNOB_TYPE_SPLIT_K_SLC":-1,"CUDNN_KNOB_TYPE_KERNEL_CFG":21}}

{"engineId":7,"smVersion":890,"knobChoices":{"CUDNN_KNOB_TYPE_SPLIT_K_SLC":-1,"CUDNN_KNOB_TYPE_KERNEL_CFG":10}}

{"engineId":7,"smVersion":890,"knobChoices":{"CUDNN_KNOB_TYPE_SPLIT_K_SLC":-1,"CUDNN_KNOB_TYPE_KERNEL_CFG":32}}

{"engineId":7,"smVersion":890,"knobChoices":{"CUDNN_KNOB_TYPE_SPLIT_K_SLC":-1,"CUDNN_KNOB_TYPE_KERNEL_CFG":25}}

{"engineId":7,"smVersion":890,"knobChoices":{"CUDNN_KNOB_TYPE_SPLIT_K_SLC":-1,"CUDNN_KNOB_TYPE_KERNEL_CFG":22}}

{"engineId":7,"smVersion":890,"knobChoices":{"CUDNN_KNOB_TYPE_SPLIT_K_SLC":-1,"CUDNN_KNOB_TYPE_KERNEL_CFG":24}}

{"engineId":7,"smVersion":890,"knobChoices":{"CUDNN_KNOB_TYPE_SPLIT_K_SLC":-1,"CUDNN_KNOB_TYPE_KERNEL_CFG":16}}

{"engineId":7,"smVersion":890,"knobChoices":{"CUDNN_KNOB_TYPE_SPLIT_K_SLC":-1,"CUDNN_KNOB_TYPE_KERNEL_CFG":12}}

{"engineId":1,"smVersion":890,"knobChoices":{"CUDNN_KNOB_TYPE_SPLIT_K_SLC":-1,"CUDNN_KNOB_TYPE_KERNEL_CFG":27}}

{"engineId":1,"smVersion":890,"knobChoices":{"CUDNN_KNOB_TYPE_SPLIT_K_SLC":-1,"CUDNN_KNOB_TYPE_KERNEL_CFG":11}}

{"engineId":1,"smVersion":890,"knobChoices":{"CUDNN_KNOB_TYPE_SPLIT_K_SLC":-1,"CUDNN_KNOB_TYPE_KERNEL_CFG":10}}

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions