-
Notifications
You must be signed in to change notification settings - Fork 81
[Feat] use host-pinned memory with dual CPU/device addresses for transport buffers #1024
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: feature_26h1
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -119,12 +119,15 @@ Status AsuTransportImpl::BuildSubBatchSendBuffers( | |
| continue; | ||
| } | ||
|
|
||
| if (subBatchContext.flagBuffer.addr == 0 || subBatchContext.flagBuffer.length == 0) { | ||
| if (subBatchContext.sendSge.device_addr == 0 || subBatchContext.flagBuffer.addr == 0 || | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [Critical] Missing validation for sendSge.length The validation now checks sendSge.device_addr, flagBuffer.addr, flagBuffer.device_addr, and flagBuffer.length, but sendSge.length is not validated. If sendSge.length is 0 or invalid, the SendIoBatch will have an incorrect length parameter, potentially causing buffer overflow or underflow in RDMA operations. Add: sendSge.length == 0 to the validation condition. |
||
| subBatchContext.flagBuffer.device_addr == 0 || subBatchContext.flagBuffer.length == 0) { | ||
| const auto status = | ||
| Status::Error(StatusCode::NOT_INITIALIZED, "sub-batch flag buffer is not ready"); | ||
| Status::Error(StatusCode::NOT_INITIALIZED, "sub-batch transport buffers are not ready"); | ||
| UC_ERROR( | ||
| "Sub-batch flag buffer is not ready index={} cid={} flag_addr={} flag_length={}", | ||
| index, subBatchContext.cid, subBatchContext.flagBuffer.addr, | ||
| "Sub-batch transport buffers are not ready index={} cid={} send_device_addr={} " | ||
| "flag_addr={} flag_device_addr={} flag_length={}", | ||
| index, subBatchContext.cid, subBatchContext.sendSge.device_addr, | ||
| subBatchContext.flagBuffer.addr, subBatchContext.flagBuffer.device_addr, | ||
| subBatchContext.flagBuffer.length); | ||
| SetSubBatchSendFailed(subBatchContext, status); | ||
| if (finalStatus.ok()) { finalStatus = status; } | ||
|
|
@@ -134,8 +137,8 @@ Status AsuTransportImpl::BuildSubBatchSendBuffers( | |
|
|
||
| ioBatches.push_back( | ||
| TransProvider::SendIoBatch{subBatchContext.channel->GetConnection(), | ||
| reinterpret_cast<void*>(subBatchContext.sendSge.addr), | ||
| reinterpret_cast<void*>(subBatchContext.flagBuffer.addr), | ||
| reinterpret_cast<void*>(subBatchContext.sendSge.device_addr), | ||
| reinterpret_cast<void*>(subBatchContext.flagBuffer.device_addr), | ||
| subBatchContext.sendSge.length}); | ||
| subBatchIndexes.emplace_back(index); | ||
| } | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -28,6 +28,79 @@ | |
| #include "trans/ascend/ascend_buffer.h" | ||
|
|
||
| namespace UC::ASU { | ||
| namespace { | ||
|
|
||
| struct BufferRegion { | ||
| std::shared_ptr<void> owner; | ||
| void* localAddr{nullptr}; | ||
| void* deviceAddr{nullptr}; | ||
| TransProvider::MemType providerMemType{TransProvider::MemType::MEM_HOST}; | ||
| }; | ||
|
|
||
| class BufferRegionCreator : public Trans::AscendBuffer { | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Naming problem, the parent class is a Buffer and the child class is a Creator? |
||
| public: | ||
| Status MakeRegion(MemoryType type, std::size_t size, BufferRegion& region) | ||
| { | ||
| switch (type) { | ||
| case MemoryType::HOST: { | ||
| auto owner = MakeHostBuffer(size); | ||
| if (!owner) { return AllocationFailed("host"); } | ||
| region = {owner, owner.get(), owner.get(), TransProvider::MemType::MEM_HOST}; | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [Suggestion] device_addr semantic inconsistency for HOST memory type For MemoryType::HOST, deviceAddr is set equal to localAddr (owner.get()). This is semantically ambiguous because plain host memory does not have a device-visible mapping. The comment in buffer_manager.h states device_addr is for device-visible address used by HCOMM/RDMA, but setting it equal to addr for HOST type contradicts this. Consider:
|
||
| return Status::OK(); | ||
| } | ||
| case MemoryType::HOST_PINNED: return MakeHostPinnedBuffer(size, region); | ||
| case MemoryType::ASCEND_DEVICE: { | ||
| auto owner = MakeDeviceBuffer(size); | ||
| if (!owner) { return AllocationFailed("device"); } | ||
| region = {owner, owner.get(), owner.get(), TransProvider::MemType::MEM_DEVICE}; | ||
| return Status::OK(); | ||
| } | ||
| default: | ||
| return Status::Error(StatusCode::INVALID_ARGUMENT, "unsupported memory type"); | ||
| } | ||
| } | ||
|
|
||
| private: | ||
| static Status AllocationFailed(const char* type) | ||
| { | ||
| return Status::Error(StatusCode::INTERNAL_ERROR, | ||
| std::string("failed to allocate ") + type + " memory"); | ||
| } | ||
|
|
||
| Status MakeHostPinnedBuffer(std::size_t size, BufferRegion& region) | ||
| { | ||
| void* hostAddr = nullptr; | ||
| auto ret = aclrtMallocHost(&hostAddr, size); | ||
| if (ret != ACL_SUCCESS) { return AllocationFailed("host-pinned"); } | ||
|
|
||
| ret = aclrtHostRegisterV2(hostAddr, size, ACL_HOST_REG_MAPPED | ACL_HOST_REG_PINNED); | ||
| if (ret != ACL_SUCCESS) { | ||
| aclrtFreeHost(hostAddr); | ||
| return Status::Error(StatusCode::INTERNAL_ERROR, | ||
| "failed to register host-pinned memory with ACL"); | ||
| } | ||
|
|
||
| void* deviceAddr = nullptr; | ||
| ret = aclrtHostGetDevicePointer(hostAddr, &deviceAddr, 0); | ||
| if (ret != ACL_SUCCESS) { | ||
| aclrtHostUnregister(hostAddr); | ||
| aclrtFreeHost(hostAddr); | ||
| return Status::Error(StatusCode::INTERNAL_ERROR, | ||
| "failed to get host-pinned device address"); | ||
| } | ||
|
|
||
| // The owner keeps the ACL registration alive until after HCOMM has | ||
| // unregistered the region in BufferManager's destructor. | ||
| auto owner = std::shared_ptr<void>(hostAddr, [](void* addr) { | ||
| aclrtHostUnregister(addr); | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [Warning] Custom deleter lacks error handling The custom deleter calls aclrtHostUnregister(addr) and aclrtFreeHost(addr) without checking return values. If aclrtHostUnregister fails, calling aclrtFreeHost on still-registered memory may cause undefined behavior or resource leaks. Consider:
|
||
| aclrtFreeHost(addr); | ||
| }); | ||
| region = {owner, hostAddr, deviceAddr, TransProvider::MemType::MEM_DEVICE}; | ||
| return Status::OK(); | ||
| } | ||
| }; | ||
|
|
||
| } // namespace | ||
|
|
||
| BufferManager::~BufferManager() | ||
| { | ||
|
|
@@ -38,6 +111,7 @@ BufferManager::~BufferManager() | |
| provider_->UnregisterMemory(descs); | ||
| } | ||
| memory_.reset(); | ||
| device_memory_ = nullptr; | ||
| slot_size_ = 0; | ||
| slot_num_ = 0; | ||
| } | ||
|
|
@@ -60,22 +134,18 @@ Status BufferManager::Init(std::string name, MemoryType type, std::size_t slot_s | |
|
|
||
| std::size_t total = slot_size * slot_num; | ||
|
|
||
| Trans::AscendBuffer allocator; | ||
| switch (memory_type_) { | ||
| case MemoryType::HOST: memory_ = allocator.MakeHostBuffer(total); break; | ||
| case MemoryType::HOST_PINNED: memory_ = allocator.MakeHostBuffer4DirectIo(total); break; | ||
| case MemoryType::ASCEND_DEVICE: memory_ = allocator.MakeDeviceBuffer(total); break; | ||
| default: | ||
| return Status::Error(StatusCode::INVALID_ARGUMENT, name_ + ": unsupported memory type"); | ||
| } | ||
|
|
||
| if (!memory_) { | ||
| return Status::Error(StatusCode::INTERNAL_ERROR, name_ + ": failed to allocate memory"); | ||
| } | ||
| BufferRegionCreator regionCreator; | ||
| BufferRegion region; | ||
| auto allocStatus = regionCreator.MakeRegion(memory_type_, total, region); | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Violating the inherit spirit, the child class has its own public function and will be called directly. |
||
| if (!allocStatus.ok()) { return allocStatus; } | ||
| memory_ = std::move(region.owner); | ||
| device_memory_ = region.deviceAddr; | ||
| provider_mem_type_ = region.providerMemType; | ||
|
|
||
| if (memory_type_ == MemoryType::ASCEND_DEVICE) { | ||
| if (aclrtMemset(memory_.get(), total, 0, total) != ACL_SUCCESS) { | ||
| memory_.reset(); | ||
| device_memory_ = nullptr; | ||
| return Status::Error(StatusCode::INTERNAL_ERROR, | ||
| name_ + ": failed to zero device memory"); | ||
| } | ||
|
|
@@ -91,6 +161,7 @@ Status BufferManager::Init(std::string name, MemoryType type, std::size_t slot_s | |
| if (!regStatus.ok()) { | ||
| provider_ = nullptr; | ||
| memory_.reset(); | ||
| device_memory_ = nullptr; | ||
| return regStatus; | ||
| } | ||
| } | ||
|
|
@@ -100,11 +171,9 @@ Status BufferManager::Init(std::string name, MemoryType type, std::size_t slot_s | |
|
|
||
| Status BufferManager::RegisterMemory() | ||
| { | ||
| auto memType = (memory_type_ == MemoryType::ASCEND_DEVICE) ? TransProvider::MemType::MEM_DEVICE | ||
| : TransProvider::MemType::MEM_HOST; | ||
| std::size_t total = slot_size_ * slot_num_; | ||
| std::vector<TransProvider::RegisterMemoryDesc> descs{ | ||
| {memType, reinterpret_cast<uintptr_t>(memory_.get()), total} | ||
| {provider_mem_type_, reinterpret_cast<uintptr_t>(device_memory_), total} | ||
| }; | ||
| std::vector<TransProvider::MemHandle> memHandles; | ||
| auto regStatus = provider_->RegisterMemory(nullptr, descs, memHandles); | ||
|
|
@@ -141,8 +210,10 @@ Status BufferManager::Allocate(std::size_t size, ScatterGatherEntry& sge) | |
| if (idx == IndexPool::npos) { | ||
| return Status::Error(StatusCode::RESOURCE_BUSY, name_ + ": no free slots"); | ||
| } | ||
| void* addr = static_cast<char*>(memory_.get()) + idx * slot_size_; | ||
| sge.addr = reinterpret_cast<std::uint64_t>(addr); | ||
| const auto offset = idx * slot_size_; | ||
| sge.addr = reinterpret_cast<std::uint64_t>(static_cast<char*>(memory_.get()) + offset); | ||
| sge.device_addr = | ||
| reinterpret_cast<std::uint64_t>(static_cast<char*>(device_memory_) + offset); | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [Warning] Potential null pointer arithmetic if device_memory_ is null The device_addr calculation uses static_cast<char*>(device_memory_) + offset. If device_memory_ is nullptr (e.g., due to a failed Init that did not properly set it), this pointer arithmetic produces a garbage address value. While Allocate() checks memory_ is valid at line 201, it does not verify device_memory_ is non-null. Consider adding: if (!device_memory_) return Status::Error(...) |
||
| sge.length = static_cast<std::uint32_t>(size); | ||
| sge.tokenId = tokenId_; | ||
| sge.slot_index = idx; | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -34,7 +34,10 @@ | |
| namespace UC::ASU { | ||
|
|
||
| struct ScatterGatherEntry { | ||
| // Local address used by CPU code for SQE packing and completion polling. | ||
| std::uint64_t addr{0}; | ||
| // Device-visible address used by HCOMM/HIXL and remote RDMA operations. | ||
| std::uint64_t device_addr{0}; | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [Performance] Struct size increase may impact cache efficiency Adding device_addr increases ScatterGatherEntry from ~24 bytes to ~32 bytes (assuming 64-bit system). In high-throughput scenarios with many pending SQEs, this 33% size increase could:
This is acceptable given the feature requirement, but worth noting for performance-sensitive deployments. |
||
| std::uint32_t length{0}; | ||
| std::uint32_t tokenId{0}; | ||
| std::uint32_t slot_index{UINT32_MAX}; | ||
|
|
@@ -66,6 +69,8 @@ class BufferManager { | |
| MemoryType memory_type_{MemoryType::HOST}; | ||
|
|
||
| std::shared_ptr<void> memory_; | ||
| void* device_memory_{nullptr}; | ||
| TransProvider::MemType provider_mem_type_{TransProvider::MemType::MEM_HOST}; | ||
| IndexPool index_pool_; | ||
|
|
||
| TransProvider* provider_{nullptr}; | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -199,7 +199,7 @@ KvBatchStoreRequest BuildBatchStoreRequest( | |
| request.kv_ns_id = GetTransportConfigAttr<std::uint32_t>(attrs, "kv_ns_id"); | ||
| request.dtype = GetTransportConfigAttr<std::uint8_t>(attrs, "dtype"); | ||
| request.dspec = GetTransportConfigAttr<std::uint8_t>(attrs, "dspec"); | ||
| request.response_buffer_addr = flagBuffer.addr; | ||
| request.response_buffer_addr = flagBuffer.device_addr; | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [Suggestion] Architectural consistency for user-provided buffer addresses The response_buffer_addr now correctly uses flagBuffer.device_addr for RDMA operations. However, the entry.buffer_addr at line 212 uses entries[index].buffer.region.addr (CPU address). If user-provided KVBuffer regions are also host-pinned memory, they should similarly use device addresses for RDMA operations. This creates a semantic inconsistency: internal buffers use device_addr for RDMA, but user buffers use addr. Consider:
|
||
| request.response_mr_key = flagBuffer.tokenId; | ||
| request.lr = GetTransportConfigAttr<bool>(attrs, "lr"); | ||
| request.rflag = true; | ||
|
|
@@ -224,7 +224,7 @@ KvBatchRetrieveRequest BuildBatchRetrieveRequest( | |
| KvBatchRetrieveRequest request; | ||
| request.cid = cid; | ||
| request.kv_ns_id = GetTransportConfigAttr<std::uint32_t>(attrs, "kv_ns_id"); | ||
| request.response_buffer_addr = flagBuffer.addr; | ||
| request.response_buffer_addr = flagBuffer.device_addr; | ||
| request.response_mr_key = flagBuffer.tokenId; | ||
| request.lr = GetTransportConfigAttr<bool>(attrs, "lr"); | ||
| request.rflag = true; | ||
|
|
@@ -259,7 +259,7 @@ KvDeleteRequest BuildDeleteRequest(const BatchView<CacheKey>& keys, | |
| KvDeleteRequest request; | ||
| request.cid = cid; | ||
| request.kv_ns_id = GetTransportConfigAttr<std::uint32_t>(attrs, "kv_ns_id"); | ||
| request.response_buffer_addr = flagBuffer.addr; | ||
| request.response_buffer_addr = flagBuffer.device_addr; | ||
| request.response_mr_key = flagBuffer.tokenId; | ||
| request.rflag = true; | ||
| request.keys = CopyKeys(keys); | ||
|
|
@@ -274,7 +274,7 @@ KvExistRequest BuildExistRequest(const BatchView<CacheKey>& keys, | |
| KvExistRequest request; | ||
| request.cid = cid; | ||
| request.kv_ns_id = GetTransportConfigAttr<std::uint32_t>(attrs, "kv_ns_id"); | ||
| request.response_buffer_addr = flagBuffer.addr; | ||
| request.response_buffer_addr = flagBuffer.device_addr; | ||
| request.response_mr_key = flagBuffer.tokenId; | ||
| request.rflag = true; | ||
| request.sc = GetTransportConfigAttr<bool>(attrs, "sc"); | ||
|
|
@@ -287,7 +287,7 @@ KvKeepAliveRequest BuildKeepAliveRequest(std::uint16_t cid, const ScatterGatherE | |
| { | ||
| KvKeepAliveRequest request; | ||
| request.cid = cid; | ||
| request.response_buffer_addr = flagBuffer.addr; | ||
| request.response_buffer_addr = flagBuffer.device_addr; | ||
| request.response_mr_key = flagBuffer.tokenId; | ||
| request.rflag = true; | ||
| return request; | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What about sendSge.addr and sendSge.length?