Skip to content

Commit 8d6e431

Browse files
Glitchfixshi-eric
authored andcommitted
fix(tile): reassign shared tiles to register [GH-1440]
* fix(tile): reassign shared tiles to register Generated CUDA tile kernels can emit wp::assign() when an existing register-backed tile variable is reassigned from a same-shape shared-backed tile. Add the missing assignment and adjoint paths so the forward copy succeeds, source gradients accumulate into the shared tile, and overwritten register destination gradients are cleared. (GH-1440) Signed-off-by: Shivanjan Chakravorty <shivanjanc@nvidia.com> Approved-by: Zach Corse <zcorse@nvidia.com> See merge request omniverse/warp!2367
1 parent 51b9dc3 commit 8d6e431

3 files changed

Lines changed: 88 additions & 0 deletions

File tree

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,8 @@
125125
now well-defined (and gradient-stable) for near-parallel edges, where it
126126
could previously return geometrically valid but unstable barycentric
127127
weights ([GH-1437](https://github.com/NVIDIA/warp/issues/1437)).
128+
- Fix tile reassignment from shared storage to register storage
129+
([GH-1440](https://github.com/NVIDIA/warp/issues/1440)).
128130

129131
### Documentation
130132

warp/native/tile.h

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5602,6 +5602,34 @@ inline CUDA_CALLABLE void adj_assign(
56025602
WP_TILE_SYNC();
56035603
}
56045604

5605+
template <typename T, typename SharedLayout, bool Owner>
5606+
inline CUDA_CALLABLE void assign(
5607+
tile_register_t<T, tile_layout_register_t<typename SharedLayout::Shape>>& dest,
5608+
const tile_shared_t<T, SharedLayout, Owner>& src
5609+
)
5610+
{
5611+
dest.assign(src.copy_to_register());
5612+
}
5613+
5614+
template <typename T, typename SharedLayout, bool Owner>
5615+
inline CUDA_CALLABLE void adj_assign(
5616+
tile_register_t<T, tile_layout_register_t<typename SharedLayout::Shape>>& dest,
5617+
const tile_shared_t<T, SharedLayout, Owner>& src,
5618+
tile_register_t<T, tile_layout_register_t<typename SharedLayout::Shape>>& adj_dest,
5619+
tile_shared_t<T, SharedLayout, Owner>& adj_src
5620+
)
5621+
{
5622+
(void)dest;
5623+
(void)src;
5624+
5625+
if (adj_src.grad.ptr != nullptr) {
5626+
adj_src.grad_add(adj_dest);
5627+
}
5628+
5629+
// Overwritten destinations do not contribute to the pre-assignment dest value.
5630+
adj_dest.zero();
5631+
}
5632+
56055633

56065634
template <typename TileA, typename Scalar> inline CUDA_CALLABLE void assign(TileA& dest, int i, const Scalar& src)
56075635
{

warp/tests/tile/test_tile_shared_memory.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -600,6 +600,58 @@ def compute(
600600
np.testing.assert_allclose(out.numpy(), inp_np)
601601

602602

603+
def test_tile_register_from_shared_reassign(test, device):
604+
TILE_SIZE = 8
605+
BLOCK_DIM = 64
606+
607+
@wp.kernel(module="unique")
608+
def compute(
609+
src: wp.array[float],
610+
overwritten: wp.array[float],
611+
reassigned: wp.array[float],
612+
direct: wp.array[float],
613+
iters: int,
614+
):
615+
t = wp.tile_load(overwritten, shape=TILE_SIZE, offset=0, storage="register")
616+
s = wp.tile_load(src, shape=TILE_SIZE, offset=0, storage="shared")
617+
618+
for _ in range(iters):
619+
t = s
620+
621+
wp.tile_store(reassigned, t, offset=0)
622+
wp.tile_store(direct, s, offset=0)
623+
624+
src_np = np.arange(TILE_SIZE, dtype=np.float32) + 1.0
625+
overwritten_np = np.arange(TILE_SIZE, dtype=np.float32) + 101.0
626+
627+
src = wp.array(src_np, requires_grad=True, device=device)
628+
overwritten = wp.array(overwritten_np, requires_grad=True, device=device)
629+
reassigned = wp.zeros(TILE_SIZE, dtype=float, requires_grad=True, device=device)
630+
direct = wp.zeros(TILE_SIZE, dtype=float, requires_grad=True, device=device)
631+
632+
with wp.Tape() as tape:
633+
wp.launch_tiled(
634+
compute,
635+
dim=[1],
636+
inputs=[src, overwritten, reassigned, direct, 2],
637+
block_dim=BLOCK_DIM,
638+
device=device,
639+
)
640+
641+
np.testing.assert_allclose(reassigned.numpy(), src_np)
642+
np.testing.assert_allclose(direct.numpy(), src_np)
643+
644+
tape.backward(
645+
grads={
646+
reassigned: wp.ones_like(reassigned, device=device),
647+
direct: wp.ones_like(direct, device=device),
648+
}
649+
)
650+
651+
np.testing.assert_allclose(src.grad.numpy(), np.full(TILE_SIZE, 2.0, dtype=np.float32))
652+
np.testing.assert_allclose(overwritten.grad.numpy(), np.zeros(TILE_SIZE, dtype=np.float32))
653+
654+
603655
def test_tile_scatter_masked_basic(test, device):
604656
"""Each thread writes its index; verify all values are visible after the call."""
605657
TILE_SIZE = 64
@@ -889,6 +941,12 @@ class TestTileSharedMemory(unittest.TestCase):
889941
test_tile_shared_coalesced_mat44,
890942
devices=devices,
891943
)
944+
add_function_test(
945+
TestTileSharedMemory,
946+
"test_tile_register_from_shared_reassign",
947+
test_tile_register_from_shared_reassign,
948+
devices=devices,
949+
)
892950
add_function_test(
893951
TestTileSharedMemory, "test_tile_scatter_masked_basic", test_tile_scatter_masked_basic, devices=devices
894952
)

0 commit comments

Comments
 (0)