@@ -735,15 +735,16 @@ int TransferEnginePy::transferCheckStatus(batch_id_t batch_id) {
735735}
736736
737737int TransferEnginePy::batchRegisterMemory (
738- std::vector<uintptr_t > buffer_addresses, std::vector<size_t > capacities) {
738+ std::vector<uintptr_t > buffer_addresses, std::vector<size_t > capacities,
739+ const std::string& location) {
739740 pybind11::gil_scoped_release release;
740741 auto batch_size = buffer_addresses.size ();
741742 std::vector<BufferEntry> buffers;
742743 for (size_t i = 0 ; i < batch_size; i++) {
743744 buffers.push_back (
744745 BufferEntry{(void *)buffer_addresses[i], capacities[i]});
745746 }
746- return engine_->registerLocalMemoryBatch (buffers, kWildcardLocation );
747+ return engine_->registerLocalMemoryBatch (buffers, location );
747748}
748749
749750int TransferEnginePy::batchUnregisterMemory (
@@ -757,9 +758,10 @@ int TransferEnginePy::batchUnregisterMemory(
757758 return engine_->unregisterLocalMemoryBatch (buffers);
758759}
759760
760- int TransferEnginePy::registerMemory (uintptr_t buffer_addr, size_t capacity) {
761+ int TransferEnginePy::registerMemory (uintptr_t buffer_addr, size_t capacity,
762+ const std::string& location) {
761763 char * buffer = reinterpret_cast <char *>(buffer_addr);
762- return engine_->registerLocalMemory (buffer, capacity);
764+ return engine_->registerLocalMemory (buffer, capacity, location );
763765}
764766
765767int TransferEnginePy::unregisterMemory (uintptr_t buffer_addr) {
@@ -1100,10 +1102,14 @@ PYBIND11_MODULE(engine, m) {
11001102 .def (" write_bytes_to_buffer" , &TransferEnginePy::writeBytesToBuffer)
11011103 .def (" read_bytes_from_buffer" ,
11021104 &TransferEnginePy::readBytesFromBuffer)
1103- .def (" register_memory" , &TransferEnginePy::registerMemory)
1105+ .def (" register_memory" , &TransferEnginePy::registerMemory,
1106+ py::arg (" buffer_addr" ), py::arg (" capacity" ),
1107+ py::arg (" location" ) = kWildcardLocation )
11041108 .def (" unregister_memory" , &TransferEnginePy::unregisterMemory)
11051109 .def (" batch_register_memory" ,
1106- &TransferEnginePy::batchRegisterMemory)
1110+ &TransferEnginePy::batchRegisterMemory,
1111+ py::arg (" buffer_addresses" ), py::arg (" capacities" ),
1112+ py::arg (" location" ) = kWildcardLocation )
11071113 .def (" batch_unregister_memory" ,
11081114 &TransferEnginePy::batchUnregisterMemory)
11091115 .def (" get_local_topology" , &TransferEnginePy::getLocalTopology,
0 commit comments