# Patch summary for third_party/amd/language/hip/librocshmem_device.py
#
# Hunk 1 (set_rocshmem_ctx): the extern symbol mapped from the
#   (tl.pointer_type(tl.void),) signature changes from "rocshmem_set_ctx"
#   to "rocshmem_set_rocshmem_ctx". Nothing else in that call is touched.
#
# Hunk 2 (after putmem_signal_nbi_wave):
#   * Removes the commented-out `wait_until` wrapper, which referenced
#     "rocshmem_wait_until_wrapper".
#   * Adds `signal_wait_until(sig_addr, cmp_, cmp_val)`: statically asserts
#     that sig_addr.dtype is pi_u64_t or pi_i64_t, casts sig_addr to pi_u64_t,
#     cmp_ to tl.int32 and cmp_val to tl.uint64, then issues an extern_call
#     to "rocshmem_ulong_wait_until_wrapper" with is_pure=False.
#   * Adds `ulong_put_signal(dest, source, nelems, sig_addr, signal, sig_op,
#     pe)`: casts dest/source/sig_addr to pi_u64_t, nelems/signal to
#     tl.uint64 and sig_op/pe to tl.int32, then issues an extern_call to
#     "rocshmem_ulong_put_signal_wrapper".
#
# NOTE(review): the trailing `# no cast` remark after the argument list of
#   signal_wait_until contradicts the three tl.cast calls directly above it;
#   it looks carried over from the removed commented-out block -- confirm and
#   drop it in the patched file.
# NOTE(review): ulong_put_signal performs no dtype static_assert on its
#   pointer arguments, unlike signal_wait_until -- confirm this is intended.
# NOTE(review): the line structure of this patch appears to have been
#   collapsed; the hunk counts (-50,7 +50,7 and -480,30 +480,50) cannot be
#   re-verified from this text, so the payload below is left untouched.
diff --git a/third_party/amd/language/hip/librocshmem_device.py b/third_party/amd/language/hip/librocshmem_device.py index 71bcceb81499..eae152d5b721 100644 --- a/third_party/amd/language/hip/librocshmem_device.py +++ b/third_party/amd/language/hip/librocshmem_device.py @@ -50,7 +50,7 @@ def set_rocshmem_ctx(ctx, _semantic=None): tl.cast(ctx, tl.pointer_type(tl.void), _semantic=_semantic), ], { - (tl.pointer_type(tl.void),): ("rocshmem_set_ctx", ()), + (tl.pointer_type(tl.void),): ("rocshmem_set_rocshmem_ctx", ()), }, is_pure=False, _semantic=_semantic, @@ -480,30 +480,50 @@ def putmem_signal_nbi_wave(dest, ) +@core.extern +def signal_wait_until(sig_addr, cmp_, cmp_val, _semantic=None): + tl.static_assert(sig_addr.dtype == pi_u64_t or sig_addr.dtype == pi_i64_t, + "sig_addr should be a pointer of uint64_t/int64_t", + _semantic=_semantic) + return extern_call( + "librocshmem_device", + "", + [ + tl.cast(sig_addr, pi_u64_t, _semantic=_semantic), + tl.cast(cmp_, tl.int32, _semantic=_semantic), + tl.cast(cmp_val, tl.uint64, _semantic=_semantic), + ], # no cast + { + (pi_u64_t, tl.int32, tl.uint64): ( + "rocshmem_ulong_wait_until_wrapper", + (), + ), + }, + is_pure=False, + _semantic=_semantic, + ) -# @core.extern -# def wait_until(sig_addr, cmp_, cmp_val, _semantic=None): -# tl.static_assert(sig_addr.dtype == pi_u64_t or sig_addr.dtype == pi_i64_t, -# "sig_addr should be a pointer of uint64_t/int64_t", -# _semantic=_semantic) -# return extern_call( -# "librocshmem_device", -# "", -# [ -# tl.cast(sig_addr, pi_u64_t, _semantic=_semantic), -# tl.cast(cmp_, tl.int32, _semantic=_semantic), -# tl.cast(cmp_val, tl.uint64, _semantic=_semantic), -# ], # no cast -# { -# (pi_u64_t, tl.int32, tl.uint64): ( -# "rocshmem_wait_until_wrapper", -# (), -# ), -# }, -# is_pure=False, -# _semantic=_semantic, -# ) +@core.extern +def ulong_put_signal(dest, source, nelems, sig_addr, signal, sig_op, pe, _semantic=None): + return extern_call( + "librocshmem_device", + "", + [ + tl.cast(dest, 
pi_u64_t, _semantic=_semantic), + tl.cast(source, pi_u64_t, _semantic=_semantic), + tl.cast(nelems, tl.uint64, _semantic=_semantic), + tl.cast(sig_addr, pi_u64_t, _semantic=_semantic), + tl.cast(signal, tl.uint64, _semantic=_semantic), + tl.cast(sig_op, tl.int32, _semantic=_semantic), + tl.cast(pe, tl.int32, _semantic=_semantic), + ], + { + (pi_u64_t, pi_u64_t, tl.uint64, pi_u64_t, tl.uint64, tl.int32, tl.int32): ( + "rocshmem_ulong_put_signal_wrapper", + (), + ), + }, is_pure=False, _semantic=_semantic) @core.extern