Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 26 additions & 3 deletions lmcache_ascend/v1/npu_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -979,9 +979,31 @@ def _remote_batched_to_gpu(self, memory_objs, starts, ends, **kwargs):
prev_read_event = None
prev_batch = None

# Per-pool scatter-done events: prevent the next RDMA
# write into a pool from racing with a scatter that is
# still reading from the same pool on load_stream.
# Events are pre-allocated and re-recorded each iteration.
channel = proxy_items[0][0]._transfer_channel
transport_stream = getattr(
channel, "transport_stream", None
)
pool_scatter_events = [
torch.npu.Event(),
torch.npu.Event(),
]
pool_scatter_recorded = [False, False]

for batch_idx, batch in enumerate(micro_batches):
pool = pools[current_pool]

# Ensure the previous scatter from this pool has
# finished before RDMA overwrites the pool buffers.
if pool_scatter_recorded[current_pool] \
and transport_stream is not None:
transport_stream.wait_event(
pool_scatter_events[current_pool]
)

# Assign backing buffers from current pool to proxies
for i, (proxy, _, _) in enumerate(batch):
proxy.set_backing_obj(pool[i])
Expand All @@ -1000,9 +1022,10 @@ def _remote_batched_to_gpu(self, memory_objs, starts, ends, **kwargs):
prev_read_event,
**kwargs,
)
# TODO (gingfung): investigate whether
# we need to record scatter-done event on load_stream
# so the next RDMA into the same pool waits for it.
pool_scatter_events[1 - current_pool].record(
self.load_stream
)
pool_scatter_recorded[1 - current_pool] = True
self._clear_proxy_batch(prev_batch)

prev_read_event = cur_read_event
Expand Down
Loading