Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

signal/poll optimization #366

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions include/mscclpp/proxy_channel_device.hpp
Original file line number Diff line number Diff line change
@@ -114,6 +114,11 @@ struct ProxyChannelDeviceHandle {

/// Push a @ref TriggerFlag to the FIFO.
MSCCLPP_DEVICE_INLINE void signal() { fifo_.push(ChannelTrigger(TriggerFlag, 0, 0, 0, 0, 1, semaphoreId_).value); }

MSCCLPP_DEVICE_INLINE void signal(const uint64_t count) {
for (uint64_t i = 0; i < count; ++i)
fifo_.push(ChannelTrigger(TriggerFlag, 0, 0, 0, 0, 1, semaphoreId_).value);
}

/// Push a @ref TriggerData and a @ref TriggerFlag at the same time to the FIFO.
/// @param dst The destination memory region.
@@ -165,8 +170,9 @@ struct ProxyChannelDeviceHandle {
}

/// Check if the proxy channel has been signaled.
/// @param max_poll The max number of signals to poll.
/// @return true if the proxy channel has been signaled.
MSCCLPP_DEVICE_INLINE bool poll() { return semaphore_.poll(); }
MSCCLPP_DEVICE_INLINE uint64_t poll(const int64_t max_poll = 1) { return semaphore_.poll(max_poll); }

/// Wait for the proxy channel to be signaled.
/// @param maxSpinCount The maximum number of spin counts before asserting. Never assert if negative.
@@ -195,7 +201,7 @@ struct SimpleProxyChannelDeviceHandle {
MSCCLPP_DEVICE_INLINE void put(uint64_t offset, uint64_t size) { put(offset, offset, size); }

/// Push a @ref TriggerFlag to the FIFO.
MSCCLPP_DEVICE_INLINE void signal() { proxyChan_.signal(); }
MSCCLPP_DEVICE_INLINE void signal(const uint64_t count = 1) { proxyChan_.signal(count); }

/// Push a @ref TriggerData and a @ref TriggerFlag at the same time to the FIFO.
/// @param dstOffset The offset into the destination memory region.
@@ -229,8 +235,9 @@ struct SimpleProxyChannelDeviceHandle {
MSCCLPP_DEVICE_INLINE void flush() { proxyChan_.flush(); }

/// Check if the proxy channel has been signaled.
/// @param max_poll The max number of signals to poll.
/// @return true if the proxy channel has been signaled.
MSCCLPP_DEVICE_INLINE bool poll() { return proxyChan_.poll(); }
MSCCLPP_DEVICE_INLINE uint64_t poll(const int64_t max_poll = 1) { return proxyChan_.poll(max_poll); }

/// Wait for the proxy channel to be signaled.
/// @param maxSpinCount The maximum number of spin counts before asserting. Never assert if negative.
40 changes: 27 additions & 13 deletions include/mscclpp/semaphore_device.hpp
Original file line number Diff line number Diff line change
@@ -17,11 +17,18 @@ namespace mscclpp {
struct Host2DeviceSemaphoreDeviceHandle {
#if defined(MSCCLPP_DEVICE_COMPILE)
/// Poll if the host has signaled.
/// @return true if the host has signaled.
MSCCLPP_DEVICE_INLINE bool poll() {
bool signaled = (atomicLoad(inboundSemaphoreId, memoryOrderAcquire) > (*expectedInboundSemaphoreId));
if (signaled) (*expectedInboundSemaphoreId) += 1;
return signaled;
/// @param max_poll The max number of signals to poll.
/// @return number of signals up to max_poll that the remote device has signaled.
MSCCLPP_DEVICE_INLINE uint64_t poll(const int64_t max_poll = 1) {
if (max_poll <= 0) return 0;
uint64_t count = (atomicLoad(inboundSemaphoreId, memoryOrderAcquire) - (*expectedInboundSemaphoreId));
if (count <= 0) {
return 0;
} else {
if (max_poll < count) count = max_poll;
*expectedInboundSemaphoreId += count;
return count;
}
}

/// Wait for the host to signal.
@@ -40,11 +47,18 @@ struct Host2DeviceSemaphoreDeviceHandle {
struct SmDevice2DeviceSemaphoreDeviceHandle {
#if defined(MSCCLPP_DEVICE_COMPILE)
/// Poll if the remote device has signaled.
/// @return true if the remote device has signaled.
MSCCLPP_DEVICE_INLINE bool poll() {
bool signaled = (atomicLoad(inboundSemaphoreId, memoryOrderAcquire) > (*expectedInboundSemaphoreId));
if (signaled) (*expectedInboundSemaphoreId) += 1;
return signaled;
/// @param max_poll The max number of signals to poll.
/// @return number of signals up to max_poll that the remote device has signaled.
MSCCLPP_DEVICE_INLINE uint64_t poll(const int64_t max_poll = 1) {
if (max_poll <= 0) return 0;
uint64_t count = (atomicLoad(inboundSemaphoreId, memoryOrderAcquire) - (*expectedInboundSemaphoreId));
if (count <= 0) {
return 0;
} else {
if (max_poll < count) count = max_poll;
*expectedInboundSemaphoreId += count;
return count;
}
}

/// Wait for the remote device to signal.
@@ -59,10 +73,10 @@ struct SmDevice2DeviceSemaphoreDeviceHandle {
/// This function guarantees that all the memory operation before this function is completed before the remote
/// semaphore is signaled.
///
MSCCLPP_DEVICE_INLINE void signal() {
MSCCLPP_DEVICE_INLINE void signal(const uint64_t count = 1) {
// This fence ensures that preceding writes are visible on the peer GPU before the incremented
// `outboundSemaphoreId` is visible.
semaphoreIncrement();
semaphoreIncrement(count);
atomicStore(remoteInboundSemaphoreId, semaphoreGetLocal(), memoryOrderSeqCst);
}

@@ -90,7 +104,7 @@ struct SmDevice2DeviceSemaphoreDeviceHandle {
}

/// Increase the counter of the local semaphore.
MSCCLPP_DEVICE_INLINE void semaphoreIncrement() { *outboundSemaphoreId += 1; }
MSCCLPP_DEVICE_INLINE void semaphoreIncrement(const uint64_t count = 1) { *outboundSemaphoreId += count; }

/// Get the value of the local semaphore.
MSCCLPP_DEVICE_INLINE uint64_t semaphoreGetLocal() const { return *outboundSemaphoreId; }
5 changes: 3 additions & 2 deletions include/mscclpp/sm_channel_device.hpp
Original file line number Diff line number Diff line change
@@ -243,7 +243,7 @@ struct SmChannelDeviceHandle {
/// This function guarantees that all the memory operation before this function is completed before the remote
/// semaphore is signaled.
///
MSCCLPP_DEVICE_INLINE void signal() { semaphore_.signal(); }
MSCCLPP_DEVICE_INLINE void signal(uint64_t count = 1) { semaphore_.signal(count); }

/// Signal the remote semaphore.
///
@@ -267,8 +267,9 @@ struct SmChannelDeviceHandle {
MSCCLPP_DEVICE_INLINE uint64_t semaphoreGetLocal() const { return semaphore_.semaphoreGetLocal(); }

/// Check if the remote semaphore has signaled.
/// @param max_poll The max number of signals to poll.
/// @return true if the remote semaphore has signaled.
MSCCLPP_DEVICE_INLINE bool poll() { return semaphore_.poll(); }
MSCCLPP_DEVICE_INLINE uint64_t poll(const int64_t max_poll = 1) { return semaphore_.poll(max_poll); }

/// Wait for the remote semaphore to send a signal.
/// @param maxSpinCount The maximum number of spins before asserting. Never assert if negative.