From 63efccc2d9983401049f3486120f9b9814d38de1 Mon Sep 17 00:00:00 2001 From: Caio Date: Tue, 17 Dec 2024 07:43:03 +0000 Subject: [PATCH 1/8] flushing automatically when reaching the inflight request limit --- include/mscclpp/proxy_channel.hpp | 3 +++ src/proxy_channel.cc | 11 ++++++++++- 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/include/mscclpp/proxy_channel.hpp b/include/mscclpp/proxy_channel.hpp index 91c67a2dd..420d31284 100644 --- a/include/mscclpp/proxy_channel.hpp +++ b/include/mscclpp/proxy_channel.hpp @@ -11,6 +11,8 @@ namespace mscclpp { +constexpr int MAX_INFLIGHT_REQUEST = 500; + struct BaseProxyChannel; struct ProxyChannel; @@ -72,6 +74,7 @@ class ProxyService : public BaseProxyService { std::vector memories_; std::shared_ptr proxy_; int deviceNumaNode; + int inflightRequests; void bindThread(); diff --git a/src/proxy_channel.cc b/src/proxy_channel.cc index f231e73a1..5bdcff4c5 100644 --- a/src/proxy_channel.cc +++ b/src/proxy_channel.cc @@ -24,6 +24,7 @@ MSCCLPP_API_CPP ProxyService::ProxyService(size_t fifoSize) int cudaDevice; MSCCLPP_CUDATHROW(cudaGetDevice(&cudaDevice)); deviceNumaNode = getDeviceNumaNode(cudaDevice); + inflightRequests = 0; } MSCCLPP_API_CPP SemaphoreId ProxyService::buildAndAddSemaphore(Communicator& communicator, @@ -70,6 +71,13 @@ ProxyHandlerResult ProxyService::handleTrigger(ProxyTrigger triggerRaw) { std::shared_ptr semaphore = semaphores_[trigger->fields.chanId]; auto result = ProxyHandlerResult::Continue; + + inflightRequests++; + if(!(trigger->fields.type & TriggerSync) && inflightRequests > MAX_INFLIGHT_REQUEST){ + semaphore->connection()->flush(); + result = ProxyHandlerResult::FlushFifoTailAndContinue; + inflightRequests = 0; + } if (trigger->fields.type & TriggerData) { RegisteredMemory& dst = memories_[trigger->fields.dstMemoryId]; @@ -84,7 +92,8 @@ ProxyHandlerResult ProxyService::handleTrigger(ProxyTrigger triggerRaw) { if (trigger->fields.type & TriggerSync) { semaphore->connection()->flush(); - 
result = ProxyHandlerResult::FlushFifoTailAndContinue; + result = ProxyHandlerResult::FlushFifoTailAndContinue;\ + inflightRequests = 0; } return result; From 56faf69e868c8ba25b9d482a9cb6055ae0000dd1 Mon Sep 17 00:00:00 2001 From: Caio Date: Tue, 17 Dec 2024 07:54:02 +0000 Subject: [PATCH 2/8] formatting code --- src/proxy_channel.cc | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/proxy_channel.cc b/src/proxy_channel.cc index 5bdcff4c5..1ff398378 100644 --- a/src/proxy_channel.cc +++ b/src/proxy_channel.cc @@ -71,9 +71,9 @@ ProxyHandlerResult ProxyService::handleTrigger(ProxyTrigger triggerRaw) { std::shared_ptr semaphore = semaphores_[trigger->fields.chanId]; auto result = ProxyHandlerResult::Continue; - + inflightRequests++; - if(!(trigger->fields.type & TriggerSync) && inflightRequests > MAX_INFLIGHT_REQUEST){ + if (!(trigger->fields.type & TriggerSync) && inflightRequests > MAX_INFLIGHT_REQUEST) { semaphore->connection()->flush(); result = ProxyHandlerResult::FlushFifoTailAndContinue; inflightRequests = 0; @@ -92,7 +92,7 @@ ProxyHandlerResult ProxyService::handleTrigger(ProxyTrigger triggerRaw) { if (trigger->fields.type & TriggerSync) { semaphore->connection()->flush(); - result = ProxyHandlerResult::FlushFifoTailAndContinue;\ + result = ProxyHandlerResult::FlushFifoTailAndContinue; inflightRequests = 0; } From 8d932cfac3b90519b252116cdc65e99ad451fe5f Mon Sep 17 00:00:00 2001 From: Caio Date: Fri, 3 Jan 2025 10:07:21 +0000 Subject: [PATCH 3/8] making the maximum inflight requests configurable by the user --- include/mscclpp/core.hpp | 21 ++++++++++++++++----- include/mscclpp/proxy_channel.hpp | 2 +- src/connection.cc | 13 ++++++++++--- src/endpoint.cc | 4 +++- src/include/endpoint.hpp | 1 + src/proxy_channel.cc | 15 +++++---------- 6 files changed, 36 insertions(+), 20 deletions(-) diff --git a/include/mscclpp/core.hpp b/include/mscclpp/core.hpp index b2758d9e8..10497649e 100644 --- a/include/mscclpp/core.hpp +++ 
b/include/mscclpp/core.hpp @@ -33,7 +33,7 @@ std::string version(); /// Base class for bootstraps. class Bootstrap { public: - Bootstrap(){}; + Bootstrap() {}; virtual ~Bootstrap() = default; virtual int getRank() = 0; virtual int getNranks() = 0; @@ -388,6 +388,11 @@ class Endpoint { /// @return The transport used. Transport transport(); + /// Get the max inflight requests. + /// + /// @return max inflight requests. + int maxInflightRequests(); + /// Serialize the Endpoint object to a vector of characters. /// /// @return A vector of characters representing the serialized Endpoint object. @@ -416,6 +421,10 @@ class Endpoint { /// Represents a connection between two processes. class Connection { public: + /// Constructor. + /// @param maxInflightRequests The maximum number of inflight requests. + Connection(int maxInflightRequests) : maxInflightRequests(maxInflightRequests) {}; + virtual ~Connection() = default; /// Write data from a source @ref RegisteredMemory to a destination @ref RegisteredMemory. @@ -454,10 +463,13 @@ class Connection { /// @return name of @ref transport() -> @ref remoteTransport() std::string getTransportName(); + int getMaxInflightRequest(); + protected: // Internal methods for getting implementation pointers. static std::shared_ptr getImpl(RegisteredMemory& memory); static std::shared_ptr getImpl(Endpoint& memory); + int maxInflightRequests; }; /// Used to configure an endpoint. @@ -472,14 +484,13 @@ struct EndpointConfig { int ibMaxCqPollNum = DefaultMaxCqPollNum; int ibMaxSendWr = DefaultMaxSendWr; int ibMaxWrPerSend = DefaultMaxWrPerSend; - - /// Default constructor. Sets transport to Transport::Unknown. - EndpointConfig() : transport(Transport::Unknown) {} + int maxInflightRequests; /// Constructor that takes a transport and sets the other fields to their default values. /// /// @param transport The transport to use. 
- EndpointConfig(Transport transport) : transport(transport) {} + EndpointConfig(Transport transport = Transport::Unknown, int maxInflightRequests = -1) + : transport(transport), maxInflightRequests(maxInflightRequests) {} }; /// Represents a context for communication. This provides a low-level interface for forming connections in use-cases diff --git a/include/mscclpp/proxy_channel.hpp b/include/mscclpp/proxy_channel.hpp index 420d31284..e2ad2528c 100644 --- a/include/mscclpp/proxy_channel.hpp +++ b/include/mscclpp/proxy_channel.hpp @@ -74,7 +74,7 @@ class ProxyService : public BaseProxyService { std::vector memories_; std::shared_ptr proxy_; int deviceNumaNode; - int inflightRequests; + std::unordered_map, int> inflightRequests; void bindThread(); diff --git a/src/connection.cc b/src/connection.cc index 6a5b554d5..73dd4d8fa 100644 --- a/src/connection.cc +++ b/src/connection.cc @@ -36,10 +36,13 @@ std::string Connection::getTransportName() { TransportNames[static_cast(this->remoteTransport())]; } +int Connection::getMaxInflightRequest() { return maxInflightRequests; } + // CudaIpcConnection CudaIpcConnection::CudaIpcConnection(Endpoint localEndpoint, Endpoint remoteEndpoint, cudaStream_t stream) - : stream_(stream) { + : stream_(stream), + Connection(localEndpoint.maxInflightRequests() != -1 ? localEndpoint.maxInflightRequests() : INT_MAX) { if (localEndpoint.transport() != Transport::CudaIpc) { throw mscclpp::Error("Cuda IPC connection can only be made from a Cuda IPC endpoint", ErrorCode::InvalidUsage); } @@ -121,7 +124,8 @@ void CudaIpcConnection::flush(int64_t timeoutUsec) { IBConnection::IBConnection(Endpoint localEndpoint, Endpoint remoteEndpoint, Context& context) : transport_(localEndpoint.transport()), remoteTransport_(remoteEndpoint.transport()), - dummyAtomicSource_(std::make_unique(0)) { + dummyAtomicSource_(std::make_unique(0)), + Connection(localEndpoint.maxInflightRequests() != -1 ? 
localEndpoint.maxInflightRequests() : 1024) { qp = getImpl(localEndpoint)->ibQp_; qp->rtr(getImpl(remoteEndpoint)->ibQpInfo_); qp->rts(); @@ -231,7 +235,10 @@ void IBConnection::flush(int64_t timeoutUsec) { EthernetConnection::EthernetConnection(Endpoint localEndpoint, Endpoint remoteEndpoint, uint64_t sendBufferSize, uint64_t recvBufferSize) - : abortFlag_(0), sendBufferSize_(sendBufferSize), recvBufferSize_(recvBufferSize) { + : abortFlag_(0), + sendBufferSize_(sendBufferSize), + recvBufferSize_(recvBufferSize), + Connection(localEndpoint.maxInflightRequests() != -1 ? localEndpoint.maxInflightRequests() : INT_MAX) { // Validating Transport Protocol if (localEndpoint.transport() != Transport::Ethernet || remoteEndpoint.transport() != Transport::Ethernet) { throw mscclpp::Error("Ethernet connection can only be made from Ethernet endpoints", ErrorCode::InvalidUsage); diff --git a/src/endpoint.cc b/src/endpoint.cc index 015d51a60..b011c95bf 100644 --- a/src/endpoint.cc +++ b/src/endpoint.cc @@ -13,7 +13,7 @@ namespace mscclpp { Endpoint::Impl::Impl(EndpointConfig config, Context::Impl& contextImpl) - : transport_(config.transport), hostHash_(getHostHash()) { + : transport_(config.transport), hostHash_(getHostHash()), maxInflightRequests_(config.maxInflightRequests) { if (AllIBTransports.has(transport_)) { ibLocal_ = true; ibQp_ = contextImpl.getIbContext(transport_) @@ -34,6 +34,8 @@ Endpoint::Impl::Impl(EndpointConfig config, Context::Impl& contextImpl) MSCCLPP_API_CPP Transport Endpoint::transport() { return pimpl_->transport_; } +MSCCLPP_API_CPP int Endpoint::maxInflightRequests() { return pimpl_->maxInflightRequests_; } + MSCCLPP_API_CPP std::vector Endpoint::serialize() { std::vector data; std::copy_n(reinterpret_cast(&pimpl_->transport_), sizeof(pimpl_->transport_), std::back_inserter(data)); diff --git a/src/include/endpoint.hpp b/src/include/endpoint.hpp index 734a6c1bd..063114b61 100644 --- a/src/include/endpoint.hpp +++ b/src/include/endpoint.hpp @@ -20,6 
+20,7 @@ struct Endpoint::Impl { Transport transport_; uint64_t hostHash_; + int maxInflightRequests_; // The following are only used for IB and are undefined for other transports. bool ibLocal_; diff --git a/src/proxy_channel.cc b/src/proxy_channel.cc index 1ff398378..a251a339b 100644 --- a/src/proxy_channel.cc +++ b/src/proxy_channel.cc @@ -24,7 +24,6 @@ MSCCLPP_API_CPP ProxyService::ProxyService(size_t fifoSize) int cudaDevice; MSCCLPP_CUDATHROW(cudaGetDevice(&cudaDevice)); deviceNumaNode = getDeviceNumaNode(cudaDevice); - inflightRequests = 0; } MSCCLPP_API_CPP SemaphoreId ProxyService::buildAndAddSemaphore(Communicator& communicator, @@ -72,28 +71,24 @@ ProxyHandlerResult ProxyService::handleTrigger(ProxyTrigger triggerRaw) { auto result = ProxyHandlerResult::Continue; - inflightRequests++; - if (!(trigger->fields.type & TriggerSync) && inflightRequests > MAX_INFLIGHT_REQUEST) { - semaphore->connection()->flush(); - result = ProxyHandlerResult::FlushFifoTailAndContinue; - inflightRequests = 0; - } - if (trigger->fields.type & TriggerData) { RegisteredMemory& dst = memories_[trigger->fields.dstMemoryId]; RegisteredMemory& src = memories_[trigger->fields.srcMemoryId]; semaphore->connection()->write(dst, trigger->fields.dstOffset, src, trigger->fields.srcOffset, trigger->fields.size); + inflightRequests[semaphore->connection()]++; } if (trigger->fields.type & TriggerFlag) { semaphore->signal(); + inflightRequests[semaphore->connection()]++; } - if (trigger->fields.type & TriggerSync) { + if (trigger->fields.type & TriggerSync || + inflightRequests[semaphore->connection()] > semaphore->connection()->getMaxInflightRequest()) { semaphore->connection()->flush(); result = ProxyHandlerResult::FlushFifoTailAndContinue; - inflightRequests = 0; + inflightRequests[semaphore->connection()] = 0; } return result; From 5749555ed4445d9f68cb847f6ce0b678f9542bf3 Mon Sep 17 00:00:00 2001 From: Caio Date: Fri, 3 Jan 2025 13:11:39 +0000 Subject: [PATCH 4/8] wip --- 
include/mscclpp/proxy_channel.hpp | 2 -- src/connection.cc | 17 +++++++++-------- 2 files changed, 9 insertions(+), 10 deletions(-) diff --git a/include/mscclpp/proxy_channel.hpp b/include/mscclpp/proxy_channel.hpp index e2ad2528c..4f2978f75 100644 --- a/include/mscclpp/proxy_channel.hpp +++ b/include/mscclpp/proxy_channel.hpp @@ -11,8 +11,6 @@ namespace mscclpp { -constexpr int MAX_INFLIGHT_REQUEST = 500; - struct BaseProxyChannel; struct ProxyChannel; diff --git a/src/connection.cc b/src/connection.cc index 73dd4d8fa..e27ba7830 100644 --- a/src/connection.cc +++ b/src/connection.cc @@ -41,8 +41,8 @@ int Connection::getMaxInflightRequest() { return maxInflightRequests; } // CudaIpcConnection CudaIpcConnection::CudaIpcConnection(Endpoint localEndpoint, Endpoint remoteEndpoint, cudaStream_t stream) - : stream_(stream), - Connection(localEndpoint.maxInflightRequests() != -1 ? localEndpoint.maxInflightRequests() : INT_MAX) { + : Connection(localEndpoint.maxInflightRequests() != -1 ? localEndpoint.maxInflightRequests() : INT_MAX), + stream_(stream) { if (localEndpoint.transport() != Transport::CudaIpc) { throw mscclpp::Error("Cuda IPC connection can only be made from a Cuda IPC endpoint", ErrorCode::InvalidUsage); } @@ -122,10 +122,11 @@ void CudaIpcConnection::flush(int64_t timeoutUsec) { // IBConnection IBConnection::IBConnection(Endpoint localEndpoint, Endpoint remoteEndpoint, Context& context) - : transport_(localEndpoint.transport()), + : Connection(localEndpoint.maxInflightRequests() != -1 ? localEndpoint.maxInflightRequests() + : EndpointConfig::DefaultMaxCqPollNum), + transport_(localEndpoint.transport()), remoteTransport_(remoteEndpoint.transport()), - dummyAtomicSource_(std::make_unique(0)), - Connection(localEndpoint.maxInflightRequests() != -1 ? 
localEndpoint.maxInflightRequests() : 1024) { + dummyAtomicSource_(std::make_unique(0)) { qp = getImpl(localEndpoint)->ibQp_; qp->rtr(getImpl(remoteEndpoint)->ibQpInfo_); qp->rts(); @@ -235,10 +236,10 @@ void IBConnection::flush(int64_t timeoutUsec) { EthernetConnection::EthernetConnection(Endpoint localEndpoint, Endpoint remoteEndpoint, uint64_t sendBufferSize, uint64_t recvBufferSize) - : abortFlag_(0), + : Connection(localEndpoint.maxInflightRequests() != -1 ? localEndpoint.maxInflightRequests() : INT_MAX), + abortFlag_(0), sendBufferSize_(sendBufferSize), - recvBufferSize_(recvBufferSize), - Connection(localEndpoint.maxInflightRequests() != -1 ? localEndpoint.maxInflightRequests() : INT_MAX) { + recvBufferSize_(recvBufferSize) { // Validating Transport Protocol if (localEndpoint.transport() != Transport::Ethernet || remoteEndpoint.transport() != Transport::Ethernet) { throw mscclpp::Error("Ethernet connection can only be made from Ethernet endpoints", ErrorCode::InvalidUsage); From 7f595c54b70da1b656e4ec9e20107d699048bf9f Mon Sep 17 00:00:00 2001 From: Caio Date: Sat, 4 Jan 2025 00:07:56 +0000 Subject: [PATCH 5/8] design adjustments --- include/mscclpp/core.hpp | 41 +++++++++++++++++++++++++++------------- src/connection.cc | 10 +++++----- src/endpoint.cc | 4 ++-- src/include/endpoint.hpp | 2 +- src/proxy_channel.cc | 3 ++- 5 files changed, 38 insertions(+), 22 deletions(-) diff --git a/include/mscclpp/core.hpp b/include/mscclpp/core.hpp index adbfac31f..e0979cf19 100644 --- a/include/mscclpp/core.hpp +++ b/include/mscclpp/core.hpp @@ -391,7 +391,7 @@ class Endpoint { /// Get the max inflight requests. /// /// @return max inflight requests. - int maxInflightRequests(); + int maxWriteQueueSize(); /// Serialize the Endpoint object to a vector of characters. /// @@ -422,8 +422,8 @@ class Endpoint { class Connection { public: /// Constructor. - /// @param maxInflightRequests The maximum number of inflight requests. 
- Connection(int maxInflightRequests) : maxInflightRequests(maxInflightRequests) {}; + /// @param maxWriteQueueSize The maximum number of write requests that can be queued. + Connection(int maxWriteQueueSize) : maxWriteQueueSize(maxWriteQueueSize) {}; virtual ~Connection() = default; @@ -463,13 +463,16 @@ class Connection { /// @return name of @ref transport() -> @ref remoteTransport() std::string getTransportName(); - int getMaxInflightRequest(); + /// Get the maximum write queue size + /// + /// @return The maximum number of write requests that can be queued. + int getMaxWriteQueueSize(); protected: // Internal methods for getting implementation pointers. static std::shared_ptr getImpl(RegisteredMemory& memory); static std::shared_ptr getImpl(Endpoint& memory); - int maxInflightRequests; + int maxWriteQueueSize; }; /// Used to configure an endpoint. @@ -480,17 +483,29 @@ struct EndpointConfig { static const int DefaultMaxWrPerSend = 64; Transport transport; - int ibMaxCqSize = DefaultMaxCqSize; - int ibMaxCqPollNum = DefaultMaxCqPollNum; - int ibMaxSendWr = DefaultMaxSendWr; - int ibMaxWrPerSend = DefaultMaxWrPerSend; - int maxInflightRequests; + int ibMaxCqSize; + int ibMaxCqPollNum; + int ibMaxSendWr; + int ibMaxWrPerSend; + int maxWriteQueueSize; - /// Constructor that takes a transport and sets the other fields to their default values. + /// Constructor that takes a transport and sets the other fields to their default values or provided values. /// /// @param transport The transport to use. - EndpointConfig(Transport transport = Transport::Unknown, int maxInflightRequests = -1) - : transport(transport), maxInflightRequests(maxInflightRequests) {} + /// @param ibMaxCqSize The maximum completion queue size. + /// @param ibMaxCqPollNum The maximum completion queue poll number. + /// @param ibMaxSendWr The maximum send work requests. + /// @param ibMaxWrPerSend The maximum work requests per send. + /// @param maxWriteQueueSize The maximum write queue size. 
+ EndpointConfig(Transport transport = Transport::Unknown, int ibMaxCqSize = DefaultMaxCqSize, + int ibMaxCqPollNum = DefaultMaxCqPollNum, int ibMaxSendWr = DefaultMaxSendWr, + int ibMaxWrPerSend = DefaultMaxWrPerSend, int maxWriteQueueSize = -1) + : transport(transport), + ibMaxCqSize(ibMaxCqSize), + ibMaxCqPollNum(ibMaxCqPollNum), + ibMaxSendWr(ibMaxSendWr), + ibMaxWrPerSend(ibMaxWrPerSend), + maxWriteQueueSize(maxWriteQueueSize) {} }; /// Represents a context for communication. This provides a low-level interface for forming connections in use-cases diff --git a/src/connection.cc b/src/connection.cc index e27ba7830..2f3c200ef 100644 --- a/src/connection.cc +++ b/src/connection.cc @@ -36,12 +36,12 @@ std::string Connection::getTransportName() { TransportNames[static_cast(this->remoteTransport())]; } -int Connection::getMaxInflightRequest() { return maxInflightRequests; } +int Connection::getMaxWriteQueueSize() { return maxWriteQueueSize; } // CudaIpcConnection CudaIpcConnection::CudaIpcConnection(Endpoint localEndpoint, Endpoint remoteEndpoint, cudaStream_t stream) - : Connection(localEndpoint.maxInflightRequests() != -1 ? localEndpoint.maxInflightRequests() : INT_MAX), + : Connection(localEndpoint.maxWriteQueueSize()), stream_(stream) { if (localEndpoint.transport() != Transport::CudaIpc) { throw mscclpp::Error("Cuda IPC connection can only be made from a Cuda IPC endpoint", ErrorCode::InvalidUsage); @@ -122,8 +122,8 @@ void CudaIpcConnection::flush(int64_t timeoutUsec) { // IBConnection IBConnection::IBConnection(Endpoint localEndpoint, Endpoint remoteEndpoint, Context& context) - : Connection(localEndpoint.maxInflightRequests() != -1 ? localEndpoint.maxInflightRequests() - : EndpointConfig::DefaultMaxCqPollNum), + : Connection(localEndpoint.maxWriteQueueSize() != -1 ? 
localEndpoint.maxWriteQueueSize() + : EndpointConfig::DefaultMaxCqSize), transport_(localEndpoint.transport()), remoteTransport_(remoteEndpoint.transport()), dummyAtomicSource_(std::make_unique(0)) { @@ -236,7 +236,7 @@ void IBConnection::flush(int64_t timeoutUsec) { EthernetConnection::EthernetConnection(Endpoint localEndpoint, Endpoint remoteEndpoint, uint64_t sendBufferSize, uint64_t recvBufferSize) - : Connection(localEndpoint.maxInflightRequests() != -1 ? localEndpoint.maxInflightRequests() : INT_MAX), + : Connection(localEndpoint.maxWriteQueueSize()), abortFlag_(0), sendBufferSize_(sendBufferSize), recvBufferSize_(recvBufferSize) { diff --git a/src/endpoint.cc b/src/endpoint.cc index b011c95bf..5df06ff01 100644 --- a/src/endpoint.cc +++ b/src/endpoint.cc @@ -13,7 +13,7 @@ namespace mscclpp { Endpoint::Impl::Impl(EndpointConfig config, Context::Impl& contextImpl) - : transport_(config.transport), hostHash_(getHostHash()), maxInflightRequests_(config.maxInflightRequests) { + : transport_(config.transport), hostHash_(getHostHash()), maxWriteQueueSize_(config.maxWriteQueueSize) { if (AllIBTransports.has(transport_)) { ibLocal_ = true; ibQp_ = contextImpl.getIbContext(transport_) @@ -34,7 +34,7 @@ Endpoint::Impl::Impl(EndpointConfig config, Context::Impl& contextImpl) MSCCLPP_API_CPP Transport Endpoint::transport() { return pimpl_->transport_; } -MSCCLPP_API_CPP int Endpoint::maxInflightRequests() { return pimpl_->maxInflightRequests_; } +MSCCLPP_API_CPP int Endpoint::maxWriteQueueSize() { return pimpl_->maxWriteQueueSize_; } MSCCLPP_API_CPP std::vector Endpoint::serialize() { std::vector data; diff --git a/src/include/endpoint.hpp b/src/include/endpoint.hpp index 063114b61..a91330ffb 100644 --- a/src/include/endpoint.hpp +++ b/src/include/endpoint.hpp @@ -20,7 +20,7 @@ struct Endpoint::Impl { Transport transport_; uint64_t hostHash_; - int maxInflightRequests_; + int maxWriteQueueSize_; // The following are only used for IB and are undefined for other transports. 
bool ibLocal_; diff --git a/src/proxy_channel.cc b/src/proxy_channel.cc index a251a339b..f2ca00674 100644 --- a/src/proxy_channel.cc +++ b/src/proxy_channel.cc @@ -70,6 +70,7 @@ ProxyHandlerResult ProxyService::handleTrigger(ProxyTrigger triggerRaw) { std::shared_ptr semaphore = semaphores_[trigger->fields.chanId]; auto result = ProxyHandlerResult::Continue; + int maxWriteQueueSize = semaphore->connection()->getMaxWriteQueueSize(); if (trigger->fields.type & TriggerData) { RegisteredMemory& dst = memories_[trigger->fields.dstMemoryId]; @@ -85,7 +86,7 @@ ProxyHandlerResult ProxyService::handleTrigger(ProxyTrigger triggerRaw) { } if (trigger->fields.type & TriggerSync || - inflightRequests[semaphore->connection()] > semaphore->connection()->getMaxInflightRequest()) { + (maxWriteQueueSize != -1 && inflightRequests[semaphore->connection()] > maxWriteQueueSize)) { semaphore->connection()->flush(); result = ProxyHandlerResult::FlushFifoTailAndContinue; inflightRequests[semaphore->connection()] = 0; From be82957430ee93b4badebc445f56266ecdf670b4 Mon Sep 17 00:00:00 2001 From: Caio Date: Sat, 4 Jan 2025 00:17:40 +0000 Subject: [PATCH 6/8] small adjustments --- include/mscclpp/core.hpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/include/mscclpp/core.hpp b/include/mscclpp/core.hpp index e0979cf19..73554cb46 100644 --- a/include/mscclpp/core.hpp +++ b/include/mscclpp/core.hpp @@ -388,9 +388,9 @@ class Endpoint { /// @return The transport used. Transport transport(); - /// Get the max inflight requests. + /// Get the maximum write queue size. /// - /// @return max inflight requests. + /// @return The maximum number of write requests that can be queued. int maxWriteQueueSize(); /// Serialize the Endpoint object to a vector of characters. 
@@ -463,7 +463,7 @@ class Connection { /// @return name of @ref transport() -> @ref remoteTransport() std::string getTransportName(); - /// Get the maximum write queue size + /// Get the maximum write queue size /// /// @return The maximum number of write requests that can be queued. int getMaxWriteQueueSize(); @@ -489,7 +489,7 @@ struct EndpointConfig { int ibMaxWrPerSend; int maxWriteQueueSize; - /// Constructor that takes a transport and sets the other fields to their default values or provided values. + /// Constructor that takes a transport and sets the other fields to their default values. /// /// @param transport The transport to use. /// @param ibMaxCqSize The maximum completion queue size. From 86bdb84728093786ae92db20731ecc0330cc3adb Mon Sep 17 00:00:00 2001 From: Caio Date: Mon, 6 Jan 2025 03:51:25 +0000 Subject: [PATCH 7/8] adjusting formatting --- include/mscclpp/core.hpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/include/mscclpp/core.hpp b/include/mscclpp/core.hpp index 73554cb46..cfc8ccc87 100644 --- a/include/mscclpp/core.hpp +++ b/include/mscclpp/core.hpp @@ -33,7 +33,7 @@ std::string version(); /// Base class for bootstraps. class Bootstrap { public: - Bootstrap() {}; + Bootstrap(){}; virtual ~Bootstrap() = default; virtual int getRank() = 0; virtual int getNranks() = 0; @@ -423,7 +423,7 @@ class Connection { public: /// Constructor. /// @param maxWriteQueueSize The maximum number of write requests that can be queued. 
- Connection(int maxWriteQueueSize) : maxWriteQueueSize(maxWriteQueueSize) {}; + Connection(int maxWriteQueueSize) : maxWriteQueueSize(maxWriteQueueSize){}; virtual ~Connection() = default; From 501ca8c41eb6e1ab0560d76f3a60d8b1aca3c6ed Mon Sep 17 00:00:00 2001 From: Caio Date: Mon, 6 Jan 2025 03:55:21 +0000 Subject: [PATCH 8/8] adjusting formatting --- src/connection.cc | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/connection.cc b/src/connection.cc index 2f3c200ef..a078a30da 100644 --- a/src/connection.cc +++ b/src/connection.cc @@ -41,8 +41,7 @@ int Connection::getMaxWriteQueueSize() { return maxWriteQueueSize; } // CudaIpcConnection CudaIpcConnection::CudaIpcConnection(Endpoint localEndpoint, Endpoint remoteEndpoint, cudaStream_t stream) - : Connection(localEndpoint.maxWriteQueueSize()), - stream_(stream) { + : Connection(localEndpoint.maxWriteQueueSize()), stream_(stream) { if (localEndpoint.transport() != Transport::CudaIpc) { throw mscclpp::Error("Cuda IPC connection can only be made from a Cuda IPC endpoint", ErrorCode::InvalidUsage); } @@ -123,7 +122,7 @@ void CudaIpcConnection::flush(int64_t timeoutUsec) { IBConnection::IBConnection(Endpoint localEndpoint, Endpoint remoteEndpoint, Context& context) : Connection(localEndpoint.maxWriteQueueSize() != -1 ? localEndpoint.maxWriteQueueSize() - : EndpointConfig::DefaultMaxCqSize), + : EndpointConfig::DefaultMaxCqSize), transport_(localEndpoint.transport()), remoteTransport_(remoteEndpoint.transport()), dummyAtomicSource_(std::make_unique(0)) {