diff --git a/lmcache_ascend/v1/npu_connector.py b/lmcache_ascend/v1/npu_connector.py index 885b637..18165e8 100755 --- a/lmcache_ascend/v1/npu_connector.py +++ b/lmcache_ascend/v1/npu_connector.py @@ -885,6 +885,14 @@ def batched_to_gpu(self, memory_objs, starts, ends, **kwargs): assert not is_310p(), "Batched P2P transfer is not supported on 310P." self._remote_batched_to_gpu(memory_objs, starts, ends, **kwargs) + + # NOTE (gingfung): Ensure the compute stream waits for + # load_stream's KV scatter to complete before attention + # reads the same pages. + # load_stream.synchronize() in _remote_batched_to_gpu is + # host-side only, the compute stream has no knowledge of it + # and can race ahead. + torch.npu.current_stream().wait_stream(self.load_stream) else: with torch.cuda.stream(self.load_stream): for memory_obj, start, end in zip( @@ -979,9 +987,29 @@ def _remote_batched_to_gpu(self, memory_objs, starts, ends, **kwargs): prev_read_event = None prev_batch = None + # Per-pool scatter-done events: prevent the next RDMA + # write into a pool from racing with a scatter that is + # still reading from the same pool on load_stream. + # Events are pre-allocated and re-recorded each iteration. + channel = proxy_items[0][0]._transfer_channel + transport_stream = getattr(channel, "transport_stream", None) + pool_scatter_events = [ + torch.npu.Event(), + torch.npu.Event(), + ] + pool_scatter_recorded = [False, False] + for batch_idx, batch in enumerate(micro_batches): pool = pools[current_pool] + # Ensure the previous scatter from this pool has + # finished before RDMA overwrites the pool buffers. + if ( + pool_scatter_recorded[current_pool] + and transport_stream is not None + ): + transport_stream.wait_event(pool_scatter_events[current_pool]) + # Assign backing buffers from current pool to proxies for i, (proxy, _, _) in enumerate(batch): proxy.set_backing_obj(pool[i]) @@ -1000,9 +1028,8 @@ def _remote_batched_to_gpu(self, memory_objs, starts, ends, **kwargs): prev_read_event, **kwargs, ) - # TODO (gingfung): investigate whether - # we need to record scatter-done event on load_stream - # so the next RDMA into the same pool waits for it. + pool_scatter_events[1 - current_pool].record(self.load_stream) + pool_scatter_recorded[1 - current_pool] = True self._clear_proxy_batch(prev_batch) prev_read_event = cur_read_event