Skip to content

Commit ca7e5fb

Browse files
A-LiuhaoA-Liuhao
andauthored
python api register_memory support location param (#2191)
Co-authored-by: A-Liuhao <liuhao276@hisilicon.com>
1 parent 97da393 commit ca7e5fb

2 files changed

Lines changed: 16 additions & 8 deletions

File tree

mooncake-integration/transfer_engine/transfer_engine_py.cpp

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -735,15 +735,16 @@ int TransferEnginePy::transferCheckStatus(batch_id_t batch_id) {
735735
}
736736

737737
int TransferEnginePy::batchRegisterMemory(
738-
std::vector<uintptr_t> buffer_addresses, std::vector<size_t> capacities) {
738+
std::vector<uintptr_t> buffer_addresses, std::vector<size_t> capacities,
739+
const std::string& location) {
739740
pybind11::gil_scoped_release release;
740741
auto batch_size = buffer_addresses.size();
741742
std::vector<BufferEntry> buffers;
742743
for (size_t i = 0; i < batch_size; i++) {
743744
buffers.push_back(
744745
BufferEntry{(void*)buffer_addresses[i], capacities[i]});
745746
}
746-
return engine_->registerLocalMemoryBatch(buffers, kWildcardLocation);
747+
return engine_->registerLocalMemoryBatch(buffers, location);
747748
}
748749

749750
int TransferEnginePy::batchUnregisterMemory(
@@ -757,9 +758,10 @@ int TransferEnginePy::batchUnregisterMemory(
757758
return engine_->unregisterLocalMemoryBatch(buffers);
758759
}
759760

760-
int TransferEnginePy::registerMemory(uintptr_t buffer_addr, size_t capacity) {
761+
int TransferEnginePy::registerMemory(uintptr_t buffer_addr, size_t capacity,
762+
const std::string& location) {
761763
char* buffer = reinterpret_cast<char*>(buffer_addr);
762-
return engine_->registerLocalMemory(buffer, capacity);
764+
return engine_->registerLocalMemory(buffer, capacity, location);
763765
}
764766

765767
int TransferEnginePy::unregisterMemory(uintptr_t buffer_addr) {
@@ -1100,10 +1102,14 @@ PYBIND11_MODULE(engine, m) {
11001102
.def("write_bytes_to_buffer", &TransferEnginePy::writeBytesToBuffer)
11011103
.def("read_bytes_from_buffer",
11021104
&TransferEnginePy::readBytesFromBuffer)
1103-
.def("register_memory", &TransferEnginePy::registerMemory)
1105+
.def("register_memory", &TransferEnginePy::registerMemory,
1106+
py::arg("buffer_addr"), py::arg("capacity"),
1107+
py::arg("location") = kWildcardLocation)
11041108
.def("unregister_memory", &TransferEnginePy::unregisterMemory)
11051109
.def("batch_register_memory",
1106-
&TransferEnginePy::batchRegisterMemory)
1110+
&TransferEnginePy::batchRegisterMemory,
1111+
py::arg("buffer_addresses"), py::arg("capacities"),
1112+
py::arg("location") = kWildcardLocation)
11071113
.def("batch_unregister_memory",
11081114
&TransferEnginePy::batchUnregisterMemory)
11091115
.def("get_local_topology", &TransferEnginePy::getLocalTopology,

mooncake-integration/transfer_engine/transfer_engine_py.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -169,13 +169,15 @@ class TransferEnginePy {
169169
}
170170

171171
// FOR EXPERIMENT ONLY
172-
int registerMemory(uintptr_t buffer_addr, size_t capacity);
172+
int registerMemory(uintptr_t buffer_addr, size_t capacity,
173+
const std::string &location = kWildcardLocation);
173174

174175
// must be called before TransferEnginePy::~TransferEnginePy()
175176
int unregisterMemory(uintptr_t buffer_addr);
176177

177178
int batchRegisterMemory(std::vector<uintptr_t> buffer_addresses,
178-
std::vector<size_t> capacities);
179+
std::vector<size_t> capacities,
180+
const std::string &location = kWildcardLocation);
179181

180182
int batchUnregisterMemory(std::vector<uintptr_t> buffer_addresses);
181183

0 commit comments

Comments
 (0)