-
Notifications
You must be signed in to change notification settings - Fork 22
Expand file tree
/
Copy pathmanaged_mem.h
More file actions
91 lines (77 loc) · 3.04 KB
/
managed_mem.h
File metadata and controls
91 lines (77 loc) · 3.04 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
#pragma once
#include <cstddef>
#include <cstdint>
#include <map>
#include <shared_mutex>
#include <string>
namespace lmc {
// Bookkeeping entry for one registered host buffer.
//
// Every field has a zero default so that a default-constructed record never
// carries indeterminate values. Aggregate initialization
// (`RegisteredMemoryRecord{ptr, devptr, size, device}`) keeps working.
struct RegisteredMemoryRecord {
  // Address of the buffer, stored as an integer.
  // NOTE(review): presumably the host-side address - confirm against users.
  uintptr_t ptr{0};
  // NOTE(review): presumably the device-side address matching `ptr` - confirm.
  uintptr_t devptr{0};
  // Size of the buffer (presumably in bytes - confirm).
  size_t buffSize{0};
  // NOTE(review): presumably the id of the device the buffer is registered
  // on - confirm against the implementation.
  int32_t device{0};
};
// Bookkeeping entry for one tracked memory allocation.
//
// Both fields default to zero so that a default-constructed record never
// carries indeterminate values. Aggregate initialization
// (`AllocatedMemoryRecord{ptr, size}`) keeps working.
struct AllocatedMemoryRecord {
  // Address of the allocation, stored as an integer.
  uintptr_t ptr{0};
  // Size of the allocation (presumably in bytes - confirm).
  size_t buffSize{0};
};
/*
 * Singleton that keeps track of host memory registered with the device and of
 * memory allocations made through it.
 *
 * This class is not responsible for ACL initialization or context (ctx)
 * initialization; the user is assumed to have initialized the context.
 *
 * Obtain the instance through GetInstance(). The type can be neither copied
 * nor moved.
 */
class HostRegisteredMemoryManager {
public:
  // Returns the one shared instance; it is constructed on the first call.
  static HostRegisteredMemoryManager &GetInstance() {
    static HostRegisteredMemoryManager singleton;
    return singleton;
  }

  ~HostRegisteredMemoryManager();

  // Copying and moving are disabled.
  HostRegisteredMemoryManager(const HostRegisteredMemoryManager &) = delete;
  HostRegisteredMemoryManager &
  operator=(const HostRegisteredMemoryManager &) = delete;
  HostRegisteredMemoryManager(HostRegisteredMemoryManager &&) = delete;
  HostRegisteredMemoryManager &
  operator=(HostRegisteredMemoryManager &&) = delete;

  // Registers a host pointer through the high level APIs (aclrt).
  // Inputs:
  //   - hostPtr: host pointer of the allocated memory area to register on
  //     the device
  //   - bufferSize: size of the allocated memory area to register on the
  //     device
  // Returns the already existing RegisteredMemoryRecord for the pointer, or
  // the newly created one.
  RegisteredMemoryRecord *registerHostPtr(void *hostPtr, size_t bufferSize);

  // Registers a host pointer through the low level APIs (hal).
  // Intended for driver versions on which aclrtHostRegister() cannot be
  // relied upon.
  // Inputs:
  //   - hostPtr: host pointer of the allocated memory area to register on
  //     the device
  //   - bufferSize: size of the allocated memory area to register on the
  //     device
  // Returns the created RegisteredMemoryRecord.
  RegisteredMemoryRecord *halRegisterHostPtr(void *hostPtr, size_t bufferSize);

  // NOTE(review): presumably records an already established hostPtr/devPtr
  // mapping of bufferSize - confirm against the implementation.
  RegisteredMemoryRecord *registerMappedMem(void *hostPtr, void *devPtr,
                                            size_t bufferSize);

  // Unregistration counterparts of the registration calls.
  // NOTE(review): the meaning of the int result is not visible from this
  // header - confirm against the implementation.
  int aclUnregisterHostPtr(void *hostPtr);
  int halUnregisterHostPtr(void *hostPtr);

  // Lookups keyed by host pointer.
  // NOTE(review): behavior for a pointer that was never registered is not
  // visible from this header - confirm.
  void *getDevicePtr(void *hostPtr);
  size_t getRecordSize(void *hostPtr);

  void unregisterAll();

  // Track memory allocations.
  AllocatedMemoryRecord *allocMem(size_t size);
  void freeMem(void *hostPtr);

private:
  // Construction is reserved for GetInstance().
  HostRegisteredMemoryManager();

  std::map<void *, RegisteredMemoryRecord> registeredMap;
  std::map<void *, AllocatedMemoryRecord> allocatedMap;
  mutable std::shared_mutex regMux;   // Lock for registeredMap
  mutable std::shared_mutex allocMux; // Lock for allocatedMap
};
// Returns the driver version as a string.
// NOTE(review): the exact format of the string is not visible from this
// header - confirm against the implementation.
std::string get_driver_version();
// Checks a version string against version 25.
// NOTE(review): presumably true when the version parsed from `version_str` is
// 25 or newer; parsing rules are not visible here - confirm.
bool is_version_at_least_25(const std::string &version_str);
// Unregisters the malloc'ed host pointer `ptr`.
void hal_host_unregister_ptr(void *ptr);
} // namespace lmc
// Free-function entry points declared outside the lmc namespace.
// NOTE(review): presumably thin wrappers over
// lmc::HostRegisteredMemoryManager - confirm against the implementation.

// NOTE(review): presumably registers `size` bytes starting at host pointer
// `ptr` and returns the matching device pointer - confirm.
void *register_ptr(void *ptr, size_t size);
// NOTE(review): presumably undoes register_ptr(); the meaning of the int
// result is not visible from this header - confirm.
int unregister_ptr(void *ptr);
// NOTE(review): presumably records an existing hostPtr/devPtr mapping of
// `size`; the meaning of the returned pointer is not visible here - confirm.
void *register_mapping(void *hostPtr, void *devPtr, size_t size);
// Generic memory allocation functions
void *alloc_mem(size_t size);
void free_mem(void *ptr);
// Takes a host pointer as input and returns the corresponding device pointer.
void *get_device_ptr(void *ptr);