Skip to content

Commit

Permalink
address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
Binyang2014 committed Jul 27, 2023
1 parent 4b9e709 commit 59e15c8
Show file tree
Hide file tree
Showing 11 changed files with 34 additions and 71 deletions.
11 changes: 4 additions & 7 deletions include/mscclpp/core.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -331,11 +331,6 @@ class RegisteredMemory {
/// @return The size of the memory block.
size_t size();

/// Get the pitch of the memory block.
///
/// @return The pitch of the memory block.
size_t pitch();

/// Get the rank of the process that owns the memory block.
///
/// @return The rank of the process that owns the memory block.
Expand Down Expand Up @@ -384,12 +379,14 @@ class Connection {
///
/// @param dst The destination @ref RegisteredMemory.
/// @param dstOffset The offset in bytes from the start of the destination @ref RegisteredMemory.
/// @param dstPitch The pitch of the destination @ref RegisteredMemory in bytes.
/// @param src The source @ref RegisteredMemory.
/// @param srcOffset The offset in bytes from the start of the source @ref RegisteredMemory.
/// @param srcPitch The pitch of the source @ref RegisteredMemory in bytes.
/// @param width The width of the 2D region to write in bytes.
/// @param height The height of the 2D region.
virtual void write2D(RegisteredMemory dst, uint64_t dstOffset, RegisteredMemory src, uint64_t srcOffset,
uint64_t width, uint64_t height) = 0;
virtual void write2D(RegisteredMemory dst, uint64_t dstOffset, uint64_t dstPitch, RegisteredMemory src,
uint64_t srcOffset, uint64_t srcPitch, uint64_t width, uint64_t height) = 0;
/// Update an 8-byte value in a destination @ref RegisteredMemory and synchronize the change with the remote process.
///
/// @param dst The destination @ref RegisteredMemory.
Expand Down
4 changes: 4 additions & 0 deletions include/mscclpp/proxy_channel.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include <mscclpp/fifo.hpp>
#include <mscclpp/proxy.hpp>
#include <mscclpp/semaphore.hpp>
#include <unordered_map>

namespace mscclpp {

Expand Down Expand Up @@ -40,6 +41,8 @@ class ProxyService : public BaseProxyService {
/// @return The ID of the semaphore.
SemaphoreId addSemaphore(std::shared_ptr<Connection> connection);

/// Record the pitch pair used for 2D writes that are served for a given semaphore ID.
/// Calling this again with the same @p id replaces the previously stored pair.
/// @param id The ID of the semaphore, as returned by addSemaphore().
/// @param pitch The pitches in bytes: `first` is forwarded to Connection::write2D as the destination pitch
/// and `second` as the source pitch.
void addPitch(SemaphoreId id, std::pair<uint64_t, uint64_t> pitch);

/// Register a memory region with the proxy service.
/// @param memory The memory region to register.
/// @return The ID of the memory region.
Expand All @@ -65,6 +68,7 @@ class ProxyService : public BaseProxyService {
Communicator& communicator_;
std::vector<std::shared_ptr<Host2DeviceSemaphore>> semaphores_;
std::vector<RegisteredMemory> memories_;
std::unordered_map<SemaphoreId, std::pair<uint64_t, uint64_t>> pitches_;
Proxy proxy_;
int deviceNumaNode;

Expand Down
6 changes: 0 additions & 6 deletions src/communicator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -57,12 +57,6 @@ MSCCLPP_API_CPP RegisteredMemory Communicator::registerMemory(void* ptr, size_t
std::make_shared<RegisteredMemory::Impl>(ptr, size, pimpl->bootstrap_->getRank(), transports, *pimpl));
}

MSCCLPP_API_CPP RegisteredMemory Communicator::registerMemory(void* ptr, size_t size, size_t pitchSize,
TransportFlags transports) {
return RegisteredMemory(
std::make_shared<RegisteredMemory::Impl>(ptr, size, pitchSize, pimpl->bootstrap_->getRank(), transports, *pimpl));
}

struct MemorySender : public Setuppable {
MemorySender(RegisteredMemory memory, int remoteRank, int tag)
: memory_(memory), remoteRank_(remoteRank), tag_(tag) {}
Expand Down
11 changes: 6 additions & 5 deletions src/connection.cc
Original file line number Diff line number Diff line change
Expand Up @@ -58,18 +58,18 @@ void CudaIpcConnection::write(RegisteredMemory dst, uint64_t dstOffset, Register
// npkitCollectEntryEvent(conn, NPKIT_EVENT_DMA_SEND_DATA_ENTRY, (uint32_t)size);
}

void CudaIpcConnection::write2D(RegisteredMemory dst, uint64_t dstOffset, RegisteredMemory src, uint64_t srcOffset,
uint64_t width, uint64_t height) {
void CudaIpcConnection::write2D(RegisteredMemory dst, uint64_t dstOffset, uint64_t dstPitch, RegisteredMemory src,
uint64_t srcOffset, uint64_t srcPitch, uint64_t width, uint64_t height) {
validateTransport(dst, remoteTransport());
validateTransport(src, transport());

char* dstPtr = (char*)dst.data();
char* srcPtr = (char*)src.data();

MSCCLPP_CUDATHROW(cudaMemcpy2DAsync(dstPtr + dstOffset, dst.pitch(), srcPtr + srcOffset, src.pitch(), width, height,
MSCCLPP_CUDATHROW(cudaMemcpy2DAsync(dstPtr + dstOffset, dstPitch, srcPtr + srcOffset, srcPitch, width, height,
cudaMemcpyDeviceToDevice, stream_));
INFO(MSCCLPP_P2P, "CudaIpcConnection write: from %p to %p, width %lu height %lu dstPitch %lu srcPitch %lu",
srcPtr + srcOffset, dstPtr + dstOffset, width, height, dst.pitch(), src.pitch());
srcPtr + srcOffset, dstPtr + dstOffset, width, height, dstPitch, srcPitch);
}

void CudaIpcConnection::updateAndSync(RegisteredMemory dst, uint64_t dstOffset, uint64_t* src, uint64_t newValue) {
Expand Down Expand Up @@ -141,7 +141,8 @@ void IBConnection::write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMem
// npkitCollectEntryEvent(conn, NPKIT_EVENT_IB_SEND_DATA_ENTRY, (uint32_t)size);
}

void IBConnection::write2D(RegisteredMemory, uint64_t, RegisteredMemory, uint64_t, uint64_t, uint64_t) {
void IBConnection::write2D(RegisteredMemory, uint64_t, uint64_t, RegisteredMemory, uint64_t, uint64_t, uint64_t,
uint64_t) {
throw Error("write2D is not supported", ErrorCode::InvalidUsage);
}

Expand Down
8 changes: 4 additions & 4 deletions src/include/connection.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,8 @@ class CudaIpcConnection : public ConnectionBase {

void write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMemory src, uint64_t srcOffset,
uint64_t size) override;
void write2D(RegisteredMemory dst, uint64_t dstOffset, RegisteredMemory src, uint64_t srcOffset, uint64_t width,
uint64_t height) override;
void write2D(RegisteredMemory dst, uint64_t dstOffset, uint64_t dstPitch, RegisteredMemory src, uint64_t srcOffset,
uint64_t srcPitch, uint64_t width, uint64_t height) override;
void updateAndSync(RegisteredMemory dst, uint64_t dstOffset, uint64_t* src, uint64_t newValue) override;

void flush() override;
Expand All @@ -69,8 +69,8 @@ class IBConnection : public ConnectionBase {

void write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMemory src, uint64_t srcOffset,
uint64_t size) override;
void write2D(RegisteredMemory dst, uint64_t dstOffset, RegisteredMemory src, uint64_t srcOffset, uint64_t width,
uint64_t height) override;
void write2D(RegisteredMemory dst, uint64_t dstOffset, uint64_t dstPitch, RegisteredMemory src, uint64_t srcOffset,
uint64_t srcPitch, uint64_t width, uint64_t height) override;
void updateAndSync(RegisteredMemory dst, uint64_t dstOffset, uint64_t* src, uint64_t newValue) override;

void flush() override;
Expand Down
2 changes: 0 additions & 2 deletions src/include/registered_memory.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,11 @@ struct TransportInfo {
struct RegisteredMemory::Impl {
void* data;
size_t size;
size_t pitch; // for 2D
int rank;
uint64_t hostHash;
TransportFlags transports;
std::vector<TransportInfo> transportInfos;

Impl(void* data, size_t size, size_t pitch, int rank, TransportFlags transports, Communicator::Impl& commImpl);
Impl(void* data, size_t size, int rank, TransportFlags transports, Communicator::Impl& commImpl);
Impl(const std::vector<char>& data);

Expand Down
9 changes: 7 additions & 2 deletions src/proxy_channel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@ MSCCLPP_API_CPP SemaphoreId ProxyService::addSemaphore(std::shared_ptr<Connectio
return semaphores_.size() - 1;
}

/// Store the (destination, source) pitch pair for the given semaphore ID, replacing any pair stored earlier.
/// handleTrigger() looks the pair up by the trigger's chanId when it serves a 2D write.
MSCCLPP_API_CPP void ProxyService::addPitch(SemaphoreId id, std::pair<uint64_t, uint64_t> pitch) {
  // operator[] creates the entry on first use, so the same statement covers insert and update.
  auto& slot = pitches_[id];
  slot = pitch;
}

MSCCLPP_API_CPP MemoryId ProxyService::addMemory(RegisteredMemory memory) {
memories_.push_back(memory);
return memories_.size() - 1;
Expand Down Expand Up @@ -63,8 +67,9 @@ ProxyHandlerResult ProxyService::handleTrigger(ProxyTrigger triggerRaw) {
RegisteredMemory& dst = memories_[trigger->fields.dstMemoryId];
RegisteredMemory& src = memories_[trigger->fields.srcMemoryId];
if (trigger->fields2D.multiDimensionFlag) {
semaphore->connection()->write2D(dst, trigger->fields.dstOffset, src, trigger->fields.srcOffset,
trigger->fields2D.width, trigger->fields2D.height);
std::pair<uint64_t, uint64_t>& pitch = pitches_[trigger->fields.chanId];
semaphore->connection()->write2D(dst, trigger->fields.dstOffset, pitch.first, src, trigger->fields.srcOffset,
pitch.second, trigger->fields2D.width, trigger->fields2D.height);
} else {
semaphore->connection()->write(dst, trigger->fields.dstOffset, src, trigger->fields.srcOffset,
trigger->fields.size);
Expand Down
16 changes: 1 addition & 15 deletions src/registered_memory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,7 @@
namespace mscclpp {

RegisteredMemory::Impl::Impl(void* data, size_t size, int rank, TransportFlags transports, Communicator::Impl& commImpl)
: Impl(data, size, size, rank, transports, commImpl) {}

RegisteredMemory::Impl::Impl(void* data, size_t size, size_t pitch, int rank, TransportFlags transports,
Communicator::Impl& commImpl)
: data(data),
size(size),
pitch(pitch),
rank(rank),
hostHash(commImpl.rankToHash_.at(rank)),
transports(transports) {
: data(data), size(size), rank(rank), hostHash(commImpl.rankToHash_.at(rank)), transports(transports) {
if (transports.has(Transport::CudaIpc)) {
TransportInfo transportInfo;
transportInfo.transport = Transport::CudaIpc;
Expand Down Expand Up @@ -69,16 +60,13 @@ MSCCLPP_API_CPP void* RegisteredMemory::data() { return pimpl->data; }

/// Return the size of the memory block, as stored in the implementation object.
MSCCLPP_API_CPP size_t RegisteredMemory::size() { return pimpl->size; }

MSCCLPP_API_CPP size_t RegisteredMemory::pitch() { return pimpl->pitch; }

/// Return the rank of the process that owns the memory block, as stored in the implementation object.
MSCCLPP_API_CPP int RegisteredMemory::rank() { return pimpl->rank; }

/// Return the transport flags stored for this memory block.
MSCCLPP_API_CPP TransportFlags RegisteredMemory::transports() { return pimpl->transports; }

MSCCLPP_API_CPP std::vector<char> RegisteredMemory::serialize() {
std::vector<char> result;
std::copy_n(reinterpret_cast<char*>(&pimpl->size), sizeof(pimpl->size), std::back_inserter(result));
std::copy_n(reinterpret_cast<char*>(&pimpl->pitch), sizeof(pimpl->pitch), std::back_inserter(result));
std::copy_n(reinterpret_cast<char*>(&pimpl->rank), sizeof(pimpl->rank), std::back_inserter(result));
std::copy_n(reinterpret_cast<char*>(&pimpl->hostHash), sizeof(pimpl->hostHash), std::back_inserter(result));
std::copy_n(reinterpret_cast<char*>(&pimpl->transports), sizeof(pimpl->transports), std::back_inserter(result));
Expand Down Expand Up @@ -111,8 +99,6 @@ RegisteredMemory::Impl::Impl(const std::vector<char>& serialization) {
auto it = serialization.begin();
std::copy_n(it, sizeof(this->size), reinterpret_cast<char*>(&this->size));
it += sizeof(this->size);
std::copy_n(it, sizeof(this->pitch), reinterpret_cast<char*>(&this->pitch));
it += sizeof(this->pitch);
std::copy_n(it, sizeof(this->rank), reinterpret_cast<char*>(&this->rank));
it += sizeof(this->rank);
std::copy_n(it, sizeof(this->hostHash), reinterpret_cast<char*>(&this->hostHash));
Expand Down
25 changes: 4 additions & 21 deletions test/mp_unit/communicator_tests.cu
Original file line number Diff line number Diff line change
Expand Up @@ -61,16 +61,7 @@ void CommunicatorTestBase::registerMemoryPairs(void* buff, size_t buffSize, mscc
const std::vector<int>& remoteRanks,
mscclpp::RegisteredMemory& localMemory,
std::unordered_map<int, mscclpp::RegisteredMemory>& remoteMemories) {
registerMemoryPairs(buff, buffSize, buffSize, transport, tag, remoteRanks, localMemory, remoteMemories);
}

// Register a local memory with pitch and receive corresponding remote memories
void CommunicatorTestBase::registerMemoryPairs(void* buff, size_t buffSize, size_t pitchSize,
mscclpp::TransportFlags transport, int tag,
const std::vector<int>& remoteRanks,
mscclpp::RegisteredMemory& localMemory,
std::unordered_map<int, mscclpp::RegisteredMemory>& remoteMemories) {
localMemory = communicator->registerMemory(buff, buffSize, pitchSize, transport);
localMemory = communicator->registerMemory(buff, buffSize, transport);
std::unordered_map<int, mscclpp::NonblockingFuture<mscclpp::RegisteredMemory>> futureRemoteMemories;
for (int remoteRank : remoteRanks) {
if (remoteRank != communicator->bootstrap()->getRank()) {
Expand Down Expand Up @@ -105,9 +96,7 @@ void CommunicatorTest::SetUp() {

devicePtr.resize(numBuffers);
localMemory.resize(numBuffers);
local2DMemory.resize(numBuffers);
remoteMemory.resize(numBuffers);
remote2DMemory.resize(numBuffers);

std::vector<int> remoteRanks;
for (int i = 0; i < gEnv->worldSize; i++) {
Expand All @@ -121,18 +110,11 @@ void CommunicatorTest::SetUp() {
registerMemoryPairs(devicePtr[n].get(), deviceBufferSize, mscclpp::Transport::CudaIpc | ibTransport, 0, remoteRanks,
localMemory[n], remoteMemory[n]);
}

for (size_t n = 0; n < numBuffers; n++) {
registerMemoryPairs(devicePtr[n].get(), deviceBufferSize, deviceBufferPitchSize, mscclpp::Transport::CudaIpc, 0,
remoteRanks, local2DMemory[n], remote2DMemory[n]);
}
}

void CommunicatorTest::TearDown() {
remoteMemory.clear();
remote2DMemory.clear();
localMemory.clear();
local2DMemory.clear();
devicePtr.clear();
CommunicatorTestBase::TearDown();
}
Expand Down Expand Up @@ -168,8 +150,9 @@ void CommunicatorTest::writeTileToRemote(size_t rowIndex, size_t colIndex, size_
for (int i = 0; i < gEnv->worldSize; i++) {
if (i != gEnv->rank) {
auto& conn = connections.at(i);
auto& peerMemory = remote2DMemory[n].at(i);
conn->write2D(peerMemory, offset, local2DMemory[n], offset, width * sizeof(int), height);
auto& peerMemory = remoteMemory[n].at(i);
conn->write2D(peerMemory, offset, deviceBufferPitchSize, localMemory[n], offset, deviceBufferPitchSize,
width * sizeof(int), height);
conn->flush();
}
}
Expand Down
6 changes: 0 additions & 6 deletions test/mp_unit/mp_unit_tests.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,10 +99,6 @@ class CommunicatorTestBase : public MultiProcessTest {
void registerMemoryPairs(void* buff, size_t buffSize, mscclpp::TransportFlags transport, int tag,
const std::vector<int>& remoteRanks, mscclpp::RegisteredMemory& localMemory,
std::unordered_map<int, mscclpp::RegisteredMemory>& remoteMemories);
// Register a local memory with pitch and receive corresponding remote memories
void registerMemoryPairs(void* buff, size_t buffSize, size_t pitch, mscclpp::TransportFlags transport, int tag,
const std::vector<int>& remoteRanks, mscclpp::RegisteredMemory& localMemory,
std::unordered_map<int, mscclpp::RegisteredMemory>& remoteMemories);
// Register a local memory and receive one corresponding remote memory
void registerMemoryPair(void* buff, size_t buffSize, mscclpp::TransportFlags transport, int tag, int remoteRank,
mscclpp::RegisteredMemory& localMemory, mscclpp::RegisteredMemory& remoteMemory);
Expand All @@ -128,9 +124,7 @@ class CommunicatorTest : public CommunicatorTestBase {
const int deviceBufferPitchSize = 512;
std::vector<std::shared_ptr<int>> devicePtr;
std::vector<mscclpp::RegisteredMemory> localMemory;
std::vector<mscclpp::RegisteredMemory> local2DMemory;
std::vector<std::unordered_map<int, mscclpp::RegisteredMemory>> remoteMemory;
std::vector<std::unordered_map<int, mscclpp::RegisteredMemory>> remote2DMemory;
};

class ProxyChannelOneToOneTest : public CommunicatorTestBase {
Expand Down
7 changes: 4 additions & 3 deletions test/mp_unit/proxy_channel_tests.cu
Original file line number Diff line number Diff line change
Expand Up @@ -25,16 +25,16 @@ void ProxyChannelOneToOneTest::setupMeshConnections(

void ProxyChannelOneToOneTest::setupMeshConnections(
std::vector<DeviceHandle<mscclpp::SimpleProxyChannel>>& proxyChannels, bool useIbOnly, void* sendBuff,
size_t sendBuffBytes, size_t pitchSize, void* recvBuff, size_t recvBuffBytes) {
size_t sendBuffBytes, size_t pitch, void* recvBuff, size_t recvBuffBytes) {
const int rank = communicator->bootstrap()->getRank();
const int worldSize = communicator->bootstrap()->getNranks();
const bool isInPlace = (recvBuff == nullptr);
mscclpp::TransportFlags transport = (useIbOnly) ? ibTransport : (mscclpp::Transport::CudaIpc | ibTransport);

mscclpp::RegisteredMemory sendBufRegMem = communicator->registerMemory(sendBuff, sendBuffBytes, pitchSize, transport);
mscclpp::RegisteredMemory sendBufRegMem = communicator->registerMemory(sendBuff, sendBuffBytes, transport);
mscclpp::RegisteredMemory recvBufRegMem;
if (!isInPlace) {
recvBufRegMem = communicator->registerMemory(recvBuff, recvBuffBytes, pitchSize, transport);
recvBufRegMem = communicator->registerMemory(recvBuff, recvBuffBytes, transport);
}

for (int r = 0; r < worldSize; r++) {
Expand All @@ -59,6 +59,7 @@ void ProxyChannelOneToOneTest::setupMeshConnections(
communicator->setup();

mscclpp::SemaphoreId cid = channelService->addSemaphore(conn);
channelService->addPitch(cid, std::pair<size_t, size_t>(pitch, pitch));
communicator->setup();

proxyChannels.emplace_back(mscclpp::deviceHandle(
Expand Down

0 comments on commit 59e15c8

Please sign in to comment.