Skip to content

Commit 9e732d4

Browse files
Fix mypy error in cutlass_gemm example
1 parent 33f14ec commit 9e732d4

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

python/examples/cutlass_gemm.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,11 @@ def as_core_Stream(cs: nvbench.CudaStream) -> core.Stream:
3434
return core.Stream.from_handle(cs.addressof())
3535

3636

37-
def make_cp_array(arr_h: np.ndarray, dev_buf: core.Buffer, dev_id: int) -> cp.ndarray:
37+
def make_cp_array(
38+
arr_h: np.ndarray, dev_buf: core.Buffer, dev_id: int | None
39+
) -> cp.ndarray:
3840
cp_memview = cp.cuda.UnownedMemory(
39-
int(dev_buf.handle), dev_buf.size, dev_buf, dev_id
41+
int(dev_buf.handle), dev_buf.size, dev_buf, -1 if dev_id is None else dev_id
4042
)
4143
zero_offset = 0
4244
return cp.ndarray(

0 commit comments

Comments
 (0)