diff --git a/include/mscclpp/core.hpp b/include/mscclpp/core.hpp index db68e40f5..696892421 100644 --- a/include/mscclpp/core.hpp +++ b/include/mscclpp/core.hpp @@ -600,6 +600,12 @@ class Context : public std::enable_shared_from_this { /// @return The newly created endpoint. Endpoint createEndpoint(EndpointConfig config); + std::shared_ptr get(std::string name) {return nullptr;} + void set(std::string name, std::shared_ptr value) {} + + private: + Context(); + /// Establish a connection between two endpoints. While this method immediately returns a connection object, the /// connection is only safe to use after the corresponding connection on the remote endpoint has been established. /// This method must be called on both endpoints to establish a connection. @@ -609,14 +615,12 @@ class Context : public std::enable_shared_from_this { /// @return A shared pointer to the connection. std::shared_ptr connect(Endpoint localEndpoint, Endpoint remoteEndpoint); - private: - Context(); - struct Impl; std::unique_ptr pimpl_; friend class RegisteredMemory; friend class Endpoint; + friend class Communicator; }; /// SemaphoreStub object only used for constructing Semaphore, not for direct use by the user. @@ -848,7 +852,8 @@ class Communicator { /// @param tag The tag to use for identifying the send and receive. /// @return A future of shared pointer to the connection. /// - std::shared_future> connect(EndpointConfig localConfig, int remoteRank, int tag = 0); + std::shared_future> connect(EndpointConfig localConfig, int remoteRank, int tag = 0, + std::string connName = "core"); [[deprecated("Use connect(localConfig, remoteRank, tag) instead. This will be removed in a future release.")]] std:: shared_future> diff --git a/python/mscclpp/core_py.cpp b/python/mscclpp/core_py.cpp index 964c2d92b..42fd2d046 100644 --- a/python/mscclpp/core_py.cpp +++ b/python/mscclpp/core_py.cpp @@ -180,8 +180,7 @@ void register_core(nb::module_& m) { return self->registerMemory((void*)ptr, size, transports); }, nb::arg("ptr"), nb::arg("size"), nb::arg("transports")) - .def("create_endpoint", &Context::createEndpoint, nb::arg("config")) - .def("connect", &Context::connect, nb::arg("local_endpoint"), nb::arg("remote_endpoint")); + .def("create_endpoint", &Context::createEndpoint, nb::arg("config")); nb::class_(m, "SemaphoreStub") .def(nb::init>(), nb::arg("connection")) @@ -213,9 +212,9 @@ void register_core(nb::module_& m) { .def("send_memory", &Communicator::sendMemory, nb::arg("memory"), nb::arg("remoteRank"), nb::arg("tag") = 0) .def("recv_memory", &Communicator::recvMemory, nb::arg("remoteRank"), nb::arg("tag") = 0) .def("connect", - static_cast> (Communicator::*)(EndpointConfig, int, int)>( - &Communicator::connect), - nb::arg("localConfig"), nb::arg("remoteRank"), nb::arg("tag") = 0) + static_cast> (Communicator::*)( + EndpointConfig, int, int, std::string)>(&Communicator::connect), + nb::arg("localConfig"), nb::arg("remoteRank"), nb::arg("tag") = 0, nb::arg("connName") = "core") .def( "connect", [](Communicator* self, int remoteRank, int tag, EndpointConfig localConfig) { diff --git a/src/communicator.cc b/src/communicator.cc index 305087a55..b87728a6a 100644 --- a/src/communicator.cc +++ b/src/communicator.cc @@ -4,6 +4,7 @@ #include "communicator.hpp" #include "api.h" +#include "connection.hpp" #include "debug.h" namespace mscclpp { @@ -15,6 +16,10 @@ Communicator::Impl::Impl(std::shared_ptr bootstrap, std::shared_ptr context, Endpoint localEndpoint, Endpoint remoteEndpoint) { + return context->connect(localEndpoint, remoteEndpoint); + }); } void Communicator::Impl::setLastRecvItem(int remoteRank, int tag, std::shared_ptr item) { @@ -100,7 +105,8 @@ MSCCLPP_API_CPP std::shared_future Communicator::recvMemory(in } MSCCLPP_API_CPP std::shared_future> Communicator::connect(EndpointConfig localConfig, - int remoteRank, int tag) { + int remoteRank, int tag, + std::string connName) { auto localEndpoint = context()->createEndpoint(localConfig); if (remoteRank == bootstrap()->getRank()) { @@ -115,9 +121,9 @@ MSCCLPP_API_CPP std::shared_future> Communicator::co bootstrap()->send(localEndpoint.serialize(), remoteRank, tag); - auto future = - std::async(std::launch::deferred, [this, remoteRank, tag, lastRecvItem = pimpl_->getLastRecvItem(remoteRank, tag), - localEndpoint = std::move(localEndpoint)]() mutable { + auto future = std::async( + std::launch::deferred, [this, remoteRank, tag, connName, lastRecvItem = pimpl_->getLastRecvItem(remoteRank, tag), + localEndpoint = std::move(localEndpoint)]() mutable { if (lastRecvItem) { // Recursive call to the previous receive items lastRecvItem->wait(); @@ -125,7 +131,7 @@ MSCCLPP_API_CPP std::shared_future> Communicator::co std::vector data; bootstrap()->recv(data, remoteRank, tag); auto remoteEndpoint = Endpoint::deserialize(data); - auto connection = context()->connect(localEndpoint, remoteEndpoint); + auto connection = ConnectionFactory::createConnection(connName, context(), localEndpoint, remoteEndpoint); pimpl_->connectionInfos_[connection.get()] = {remoteRank, tag}; return connection; }); diff --git a/src/ext/connection/connection.cc b/src/ext/connection/connection.cc new file mode 100644 index 000000000..98dda02b2 --- /dev/null +++ b/src/ext/connection/connection.cc @@ -0,0 +1,25 @@ +#include "ext/connection/connection.hpp" + +#include "connection.hpp" + +namespace mscclpp { + + +void IndirectConnection::write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMemory src, uint64_t srcOffset, + uint64_t size) { + if (dstOffset + size > dst.size() || srcOffset + size > src.size()) { + throw Error("IndirectionConnection::write out of bounds", ErrorCode::InvalidUsage); + } + scheduler_ptr_->sched(dst, dstOffset, src, srcOffset, size); +} + +void IndirectConnection::flush(int64_t timeoutUsec) { + if (timeoutUsec != -1) { + throw std::runtime_error("IndirectConnection does not support timeout in flush"); + } + scheduler_ptr_->sync(); +} +Transport IndirectConnection::transport() const { return Transport::CudaIpc; } +Transport IndirectConnection::remoteTransport() const { return Transport::CudaIpc; } + +} // namespace mscclpp \ No newline at end of file diff --git a/src/ext/connection/example.cc b/src/ext/connection/example.cc new file mode 100644 index 000000000..3f2bea984 --- /dev/null +++ b/src/ext/connection/example.cc @@ -0,0 +1,17 @@ +#include "connection.hpp" +#include "ext/connection/connection.hpp" + +void test() { + auto context = mscclpp::Context::create(); + auto localEndpoint = context->createEndpoint({mscclpp::Transport::CudaIpc}); + auto remoteEndpoint = context->createEndpoint({mscclpp::Transport::CudaIpc}); + mscclpp::Device fwd(mscclpp::DeviceType::GPU, 2); + std::shared_ptr scheduler = std::make_shared(context, fwd); + context->set("scheduler", scheduler); + mscclpp::ConnectionFactory::registerConnection( + "indirect", [context](std::shared_ptr ctx, mscclpp::Endpoint local, mscclpp::Endpoint remote) { + std::shared_ptr scheduler = std::static_pointer_cast(context->get("scheduler")); + return std::make_shared(ctx, local, scheduler); + }); + auto connection = mscclpp::ConnectionFactory::createConnection("indirect", context, localEndpoint, remoteEndpoint); +} \ No newline at end of file diff --git a/src/include/connection.hpp b/src/include/connection.hpp index 5539479ea..87736a9d0 100644 --- a/src/include/connection.hpp +++ b/src/include/connection.hpp @@ -15,6 +15,30 @@ namespace mscclpp { +class ConnectionFactory { + private: + using ConnectionCreator = std::function(std::shared_ptr, Endpoint, Endpoint)>; + static std::unordered_map& getRegistry() { + static std::unordered_map registry; + return registry; + } + + public: + static void registerConnection(const std::string& connName, ConnectionCreator creator) { + getRegistry()[connName] = creator; + } + + static std::shared_ptr createConnection(const std::string& connName, std::shared_ptr context, + Endpoint localEndpoint, Endpoint remoteEndpoint) { + auto& registry = getRegistry(); + auto it = registry.find(connName); + if (it != registry.end()) { + return it->second(context, localEndpoint, remoteEndpoint); + } + throw std::runtime_error("Unknown connection type: " + connName); + } +}; + class CudaIpcConnection : public Connection { private: std::shared_ptr stream_; diff --git a/src/include/ext/connection/connection.hpp b/src/include/ext/connection/connection.hpp new file mode 100644 index 000000000..67e046cd3 --- /dev/null +++ b/src/include/ext/connection/connection.hpp @@ -0,0 +1,52 @@ +#include "mscclpp/core.hpp" +#include "mscclpp/gpu_utils.hpp" + +namespace mscclpp { + +class ConnectionScheduler { + public: + virtual void sched(RegisteredMemory dst, uint64_t dstOffset, RegisteredMemory src, uint64_t srcOffset, + uint64_t size) = 0; + virtual void sync() = 0; +}; + +class DefaultConnectionScheduler : public ConnectionScheduler { + public: + DefaultConnectionScheduler(std::shared_ptr context, Device device) : context_(context), device_(device) {} + + void sched(RegisteredMemory dst, uint64_t dstOffset, RegisteredMemory src, uint64_t srcOffset, + uint64_t size) override { + // Implementation for scheduling tasks + } + + void sync() override { + // Implementation for synchronizing tasks + } + + private: + std::shared_ptr context_; + Device device_; +}; + +class IndirectConnection : public Connection { + std::shared_ptr scheduler_ptr_; + + public: + IndirectConnection(std::shared_ptr context, Endpoint localEndpoint, + std::shared_ptr scheduler) + : Connection(context, localEndpoint), scheduler_ptr_(scheduler) { + if (scheduler_ptr_ == nullptr) { + throw std::runtime_error("Scheduler not set in context"); + } + } + void write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMemory src, uint64_t srcOffset, + uint64_t size) override; + void flush(int64_t timeoutUsec = -1) override; + Transport transport() const override; + Transport remoteTransport() const override; + + virtual void updateAndSync(RegisteredMemory dst, uint64_t dstOffset, uint64_t *src, uint64_t newValue) override { + throw std::runtime_error("IndirectConnection does not support updateAndSync"); + } +}; +} // namespace mscclpp \ No newline at end of file