Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
124 changes: 98 additions & 26 deletions csrc/common/managed_mem.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,11 @@
#endif
#include "driver/ascend_hal.h"
#include "driver/ascend_hal_define.h"
#include <cstring>
#include <dlfcn.h>
#include <errno.h>
#include <iostream>
#include <stdexcept>
#include <string>
#include <sys/mman.h>

Expand All @@ -31,16 +34,16 @@ HostRegisteredMemoryManager::~HostRegisteredMemoryManager() {
};

void HostRegisteredMemoryManager::unregisterAll() {
const std::unique_lock<std::shared_mutex> guard(this->mux);
const std::unique_lock<std::shared_mutex> guard(this->regMux);

// Iterate through each key-value pair in the map.
for (const auto &pair : this->allocatedMap) {
// Iterate through each key-value pair in the registeredMap.
for (const auto &pair : this->registeredMap) {
void *hostPtr = pair.first;
aclrtHostUnregister(hostPtr);
}

// After unregistering all pointers, clear the map completely.
this->allocatedMap.clear();
this->registeredMap.clear();
};

// Register a pointer through high level APIs (aclrt) return devPtr
Expand All @@ -51,11 +54,11 @@ RegisteredMemoryRecord *HostRegisteredMemoryManager::registerHostPtr(
LMCACHE_ASCEND_CHECK(
!(hostPtr == nullptr || bufferSize == 0),
"Error: hostPtr cannot be null and bufferSize must be greater than 0.");
const std::unique_lock<std::shared_mutex> guard(this->mux);
const std::unique_lock<std::shared_mutex> guard(this->regMux);

// Check if the host pointer is already registered
if (this->allocatedMap.count(hostPtr)) {
return &this->allocatedMap[hostPtr];
if (this->registeredMap.count(hostPtr)) {
return &this->registeredMap[hostPtr];
}

void *devPtr;
Expand All @@ -67,12 +70,12 @@ RegisteredMemoryRecord *HostRegisteredMemoryManager::registerHostPtr(
return nullptr;
}

this->allocatedMap.emplace(
this->registeredMap.emplace(
hostPtr, RegisteredMemoryRecord{reinterpret_cast<uintptr_t>(hostPtr),
reinterpret_cast<uintptr_t>(devPtr),
bufferSize, -1});

return &this->allocatedMap[hostPtr];
return &this->registeredMap[hostPtr];
};

// Register an existing host-device pointer mapping to the memory manager
Expand All @@ -84,19 +87,19 @@ HostRegisteredMemoryManager::registerMappedMem(void *hostPtr, void *devPtr,
!(hostPtr == nullptr || devPtr == nullptr || bufferSize == 0),
"Error: hostPtr and devPtr cannot be null and bufferSize must be greater "
"than 0.");
const std::unique_lock<std::shared_mutex> guard(this->mux);
const std::unique_lock<std::shared_mutex> guard(this->regMux);

// Check if the host pointer is already registered
LMCACHE_ASCEND_CHECK(
!(this->allocatedMap.count(hostPtr)),
!(this->registeredMap.count(hostPtr)),
"Error: hostPtr already registered to host memory manager.");

this->allocatedMap.emplace(
this->registeredMap.emplace(
hostPtr, RegisteredMemoryRecord{reinterpret_cast<uintptr_t>(hostPtr),
reinterpret_cast<uintptr_t>(devPtr),
bufferSize, -1});

return &this->allocatedMap[hostPtr];
return &this->registeredMap[hostPtr];
};

// Register a pointer through low level APIs (HAL). Allocates a new pinned host
Expand All @@ -109,7 +112,7 @@ HostRegisteredMemoryManager::halRegisterHostPtr(void *hostPtr,
// Essentially, the halHostRegister function requires a ptr given by mmap.
LMCACHE_ASCEND_CHECK((bufferSize >= 0),
"Error: bufferSize must be greater than 0.");
const std::unique_lock<std::shared_mutex> guard(this->mux);
const std::unique_lock<std::shared_mutex> guard(this->regMux);

void *devPtr;
int device = get_device();
Expand Down Expand Up @@ -140,44 +143,102 @@ HostRegisteredMemoryManager::halRegisterHostPtr(void *hostPtr,
std::to_string(lockErr))
}

this->allocatedMap.emplace(
this->registeredMap.emplace(
hostPtr,
RegisteredMemoryRecord{reinterpret_cast<uintptr_t>(hostPtr),
reinterpret_cast<uintptr_t>(devPtr), bufferSize,
static_cast<int32_t>(device)});

return &this->allocatedMap[hostPtr];
return &this->registeredMap[hostPtr];
};

int HostRegisteredMemoryManager::aclUnregisterHostPtr(void *hostPtr) {
LMCACHE_ASCEND_CHECK(hostPtr != nullptr, "Error: hostPtr cannot be null.");

// we don't actually mind if it doesn't unregister,
// at context destroy it should be unregister anyway.
const std::unique_lock<std::shared_mutex> guard(this->mux);
if (this->allocatedMap.count(hostPtr) == 0) {
const std::unique_lock<std::shared_mutex> guard(this->regMux);
if (this->registeredMap.count(hostPtr) == 0) {
// we probably did not register anyway
return 0;
}
aclError err = aclrtHostUnregister(hostPtr);
this->allocatedMap.erase(hostPtr);
this->registeredMap.erase(hostPtr);
return static_cast<int>(err);
};

int HostRegisteredMemoryManager::halUnregisterHostPtr(void *hostPtr) {
LMCACHE_ASCEND_CHECK(hostPtr != nullptr, "Error: hostPtr cannot be null.");
const std::unique_lock<std::shared_mutex> guard(this->mux);
if (this->allocatedMap.count(hostPtr) == 0) {
const std::unique_lock<std::shared_mutex> guard(this->regMux);
if (this->registeredMap.count(hostPtr) == 0) {
// we probably did not register anyway
return 0;
}
auto record = this->allocatedMap[hostPtr];
auto record = this->registeredMap[hostPtr];
auto err = halHostUnregisterEx(reinterpret_cast<void *>(hostPtr),
static_cast<UINT32>(record.device),
HOST_MEM_MAP_DEV_PCIE_TH);
return static_cast<int>(err);
}

// Track a memory allocation - allocate and lock memory
AllocatedMemoryRecord *HostRegisteredMemoryManager::allocMem(size_t size) {
LMCACHE_ASCEND_CHECK(size > 0, "Error: size must be greater than 0.");
const std::unique_lock<std::shared_mutex> guard(this->allocMux);

// Allocate pinned memory using mmap
void *ptr = mmap(nullptr, size, PROT_READ | PROT_WRITE,
MAP_PRIVATE | MAP_ANONYMOUS, -1, 0);
if (ptr == MAP_FAILED) {
throw std::runtime_error(std::string("[allocMem] mmap failed: ") +
strerror(errno));
}

memset(ptr, 0, size);

// Lock the memory to ensure it's pinned
if (mlock(ptr, size) != 0) {
std::cerr << "[allocMem] mlock failed: " << strerror(errno)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fall through while failed. because unit test always get this error, maybe due to CAP_IPC_LOCK or other reason. we get ENOMEM when exec mlock:
RuntimeError: [allocMem] mlock failed: Cannot allocate memory

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the ulimit of lock memory may change

<< " (errno=" << errno
<< "). Continuing without pinned memory.\n";
// Continue without pinning in the environments where mlock is restricted.
// This preserves allocations but degrades guaranteed pinning semantics.
}

// Check if already tracked
if (this->allocatedMap.count(ptr)) {
return &this->allocatedMap[ptr];
}

this->allocatedMap.emplace(
ptr, AllocatedMemoryRecord{reinterpret_cast<uintptr_t>(ptr), size});

return &this->allocatedMap[ptr];
}

// Free memory allocated by allocMem
void HostRegisteredMemoryManager::freeMem(void *hostPtr) {
LMCACHE_ASCEND_CHECK(hostPtr != nullptr, "Error: hostPtr cannot be null.");
const std::unique_lock<std::shared_mutex> guard(this->allocMux);

auto it = this->allocatedMap.find(hostPtr);
if (it == this->allocatedMap.end()) {
throw std::runtime_error("[freeMem] pointer not found in memory manager");
}

size_t size = it->second.buffSize;

// Unmap the memory
int err = munmap(hostPtr, size);
if (err != 0) {
throw std::runtime_error(std::string("[freeMem] munmap failed: ") +
strerror(errno));
}

// Remove from map
this->allocatedMap.erase(it);
}

/*
* For now we only do a linear search as we probably won't have a long list
* of ptrs we go through each record and check whether we are in range, if so we
Expand All @@ -188,11 +249,11 @@ void *HostRegisteredMemoryManager::getDevicePtr(void *hostPtr) {
if (hostPtr == nullptr) {
return nullptr;
}
const std::shared_lock<std::shared_mutex> guard(this->mux);
const std::shared_lock<std::shared_mutex> guard(this->regMux);

const uintptr_t hostAddrPtr = reinterpret_cast<uintptr_t>(hostPtr);

for (const auto &pair : this->allocatedMap) {
for (const auto &pair : this->registeredMap) {
const RegisteredMemoryRecord &record = pair.second;

if (hostAddrPtr >= record.ptr &&
Expand All @@ -212,11 +273,11 @@ size_t HostRegisteredMemoryManager::getRecordSize(void *hostPtr) {
if (hostPtr == nullptr) {
return 0;
}
const std::shared_lock<std::shared_mutex> guard(this->mux);
const std::shared_lock<std::shared_mutex> guard(this->regMux);

const uintptr_t hostAddrPtr = reinterpret_cast<uintptr_t>(hostPtr);

for (const auto &pair : this->allocatedMap) {
for (const auto &pair : this->registeredMap) {
const RegisteredMemoryRecord &record = pair.second;

if (hostAddrPtr >= record.ptr &&
Expand Down Expand Up @@ -394,3 +455,14 @@ void *get_device_ptr(void *ptr) {
auto &hmm = lmc::HostRegisteredMemoryManager::GetInstance();
return hmm.getDevicePtr(ptr);
};

void *alloc_mem(size_t size) {
auto &hmm = lmc::HostRegisteredMemoryManager::GetInstance();
return reinterpret_cast<void *>(hmm.allocMem(size)->ptr);
}

// Generic memory deallocation
void free_mem(void *ptr) {
auto &hmm = lmc::HostRegisteredMemoryManager::GetInstance();
hmm.freeMem(ptr);
}
19 changes: 17 additions & 2 deletions csrc/common/managed_mem.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,11 @@ struct RegisteredMemoryRecord {
int32_t device;
};

struct AllocatedMemoryRecord {
uintptr_t ptr;
size_t buffSize;
};

/*
* We are not responsible for acl init and ctx initialization,
* we assume the user responsible for ctx initialization
Expand All @@ -28,8 +33,10 @@ class HostRegisteredMemoryManager {
HostRegisteredMemoryManager &
operator=(HostRegisteredMemoryManager &&) = delete;

std::map<void *, RegisteredMemoryRecord> allocatedMap;
mutable std::shared_mutex mux;
std::map<void *, RegisteredMemoryRecord> registeredMap;
std::map<void *, AllocatedMemoryRecord> allocatedMap;
mutable std::shared_mutex regMux; // Lock for registeredMap
mutable std::shared_mutex allocMux; // Lock for allocatedMap

public:
static HostRegisteredMemoryManager &GetInstance() {
Expand Down Expand Up @@ -59,6 +66,10 @@ class HostRegisteredMemoryManager {
void *getDevicePtr(void *hostPtr);
size_t getRecordSize(void *hostPtr);
void unregisterAll();

// Track memory allocations
AllocatedMemoryRecord *allocMem(size_t size);
void freeMem(void *hostPtr);
};

std::string get_driver_version();
Expand All @@ -72,5 +83,9 @@ void *register_ptr(void *ptr, size_t size);
int unregister_ptr(void *ptr);
void *register_mapping(void *hostPtr, void *devPtr, size_t size);

// Generic memory allocation functions
void *alloc_mem(size_t size);
void free_mem(void *ptr);

// Takes in input a host pointer, returns the corresponding device pointer
void *get_device_ptr(void *ptr);
36 changes: 22 additions & 14 deletions csrc/common/mem_alloc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,7 @@
#include <sys/mman.h>

uintptr_t alloc_pinned_ptr(std::size_t size, unsigned int flags) {
void *ptr = nullptr;
// no flags
aclError err = aclrtMallocHost(&ptr, size);
if (err != ACL_SUCCESS) {
throw std::runtime_error("aclrtMallocHost failed: " + std::to_string(err));
}
void *ptr = alloc_mem(size);

const char *socVersion = aclrtGetSocName();

Expand All @@ -26,7 +21,7 @@ uintptr_t alloc_pinned_ptr(std::size_t size, unsigned int flags) {
// not 310p
auto devPtr = register_ptr(ptr, size);
if (devPtr == nullptr) {
free_pinned_ptr(reinterpret_cast<uintptr_t>(ptr));
free_mem(ptr);
throw std::runtime_error("register ptr failed");
}
}
Expand All @@ -35,11 +30,9 @@ uintptr_t alloc_pinned_ptr(std::size_t size, unsigned int flags) {
}

void free_pinned_ptr(uintptr_t ptr) {
unregister_ptr(reinterpret_cast<void *>(ptr));
aclError err = aclrtFreeHost(reinterpret_cast<void *>(ptr));
if (err != ACL_SUCCESS) {
throw std::runtime_error("aclrtFreeHost failed: " + std::to_string(err));
}
void *vptr = reinterpret_cast<void *>(ptr);
unregister_ptr(vptr);
free_mem(vptr);
}

/*
Expand All @@ -49,7 +42,8 @@ uintptr_t alloc_pinned_numa_ptr(std::size_t size, int node) {
void *ptr = mmap(nullptr, size, PROT_READ | PROT_WRITE,
MAP_PRIVATE | MAP_ANONYMOUS, -1, 0);
if (ptr == MAP_FAILED) {
throw std::runtime_error(std::string("mmap failed: ") + strerror(errno));
throw std::runtime_error(
std::string("[alloc_pinned_numa_ptr] mmap failed: ") + strerror(errno));
}

// Maximum of 64 numa nodes
Expand All @@ -59,7 +53,17 @@ uintptr_t alloc_pinned_numa_ptr(std::size_t size, int node) {
MPOL_MF_MOVE | MPOL_MF_STRICT);
if (err != 0) {
munmap(ptr, size);
throw std::runtime_error(std::string("mbind failed: ") + strerror(errno));
throw std::runtime_error(
std::string("[alloc_pinned_numa_ptr] mbind failed: ") +
strerror(errno));
}

// In kernels 5.10 and earlier, the aclrtHostRegister requires pinned memory
if (mlock(ptr, size) != 0) {
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should after memset.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed. @chloroethylene

munmap(ptr, size);
throw std::runtime_error(
std::string("[alloc_pinned_numa_ptr] mlock failed: ") +
strerror(errno));
}

memset(ptr, 0, size);
Expand All @@ -86,8 +90,12 @@ uintptr_t alloc_pinned_numa_ptr(std::size_t size, int node) {
void free_pinned_numa_ptr(uintptr_t p, std::size_t size) {
void *ptr = reinterpret_cast<void *>(p);

// Unregister the pointer
auto unRegErr = unregister_ptr(ptr);

// Unmap the memory
auto unMapErr = munmap(ptr, size);

if (unRegErr) {
throw std::runtime_error("unregister_ptr failed: " +
std::to_string(unRegErr));
Comment on lines 94 to 101
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

security-medium medium

In free_pinned_numa_ptr, munmap is called to unmap memory after attempting to unregister the pointer. If unregister_ptr fails, the function throws a std::runtime_error after munmap has already been executed. If the caller of free_pinned_numa_ptr attempts to retry the operation upon failure, it will result in a double munmap on the same pointer. Double munmap is a security vulnerability that can lead to memory corruption by unmapping memory that might have been re-allocated to another part of the process. The logic should ensure that the exception is thrown in a way that doesn't lead to unsafe retries, or that the memory is only unmapped if unregistration succeeds (if that's the intended semantics).

Expand Down
Loading