Skip to content

Commit

Permalink
adding ops
Browse files Browse the repository at this point in the history
  • Loading branch information
Saeed Maleki committed Oct 23, 2023
1 parent c3d9bce commit 8e02b03
Show file tree
Hide file tree
Showing 4 changed files with 79 additions and 7 deletions.
6 changes: 1 addition & 5 deletions python/benchmark/allreduce_bench.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,19 +16,15 @@ def human_readable_size(size, decimal_places=1):
return f"{size:.{decimal_places}f} {unit}"

def check_correctness(memory, func):
print("here0", flush=True)
rand_gen = cp.random.default_rng(seed=MPI.COMM_WORLD.rank)
memory[:] = rand_gen.random(memory.shape)
print("here1", flush=True)
cp.cuda.runtime.deviceSynchronize()
func(0)
cp.cuda.runtime.deviceSynchronize()
print("here2", flush=True)
expected = cp.zeros_like(memory)
for i in range(MPI.COMM_WORLD.size):
rand_gen = cp.random.default_rng(seed=i)
expected += rand_gen.random(memory.shape)
print("here3", flush=True)
return cp.allclose(memory, expected)

def bench_time(niter: int, func):
Expand Down Expand Up @@ -99,7 +95,7 @@ def run_benchmark(mscclpp_op: MscclppOp, nccl_op: NcclOp, table: PrettyTable, ni
if MPI.COMM_WORLD.rank == 0:
# Set table headers
table = PrettyTable()
table.field_names = ["Size", "Time (us)", "AlgBW (GB/s)", "NCCL Time (us)", "Correctness", "NCCL AlgBW (GB/s)", "Correctness", "Speed Up"]
table.field_names = ["Size", "Time (us)", "AlgBW (GB/s)", "Correctness", "NCCL Time (us)", "NCCL AlgBW (GB/s)", "NCCL Correctness", "Speed Up"]

for i in range(10,28):
run_benchmark(mscclpp_op, nccl_op, table, 100, 2**i)
Expand Down
55 changes: 55 additions & 0 deletions python/benchmark/mscclpp_op.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import os
import cupy as cp
from test.mscclpp_group import MscclppGroup
from test.utils import KernelBuilder, pack
from mscclpp import Transport
from mpi4py import MPI
import netifaces as ni

class MscclppOp():
def __init__(self):
comm = MPI.COMM_WORLD
rank = comm.Get_rank()
size = comm.Get_size()
network_interface = "eth0"
my_ip = ni.ifaddresses(network_interface)[ni.AF_INET][0]["addr"]
root_ip = comm.bcast(my_ip, root=0)
ifIpPortTrio = network_interface + ":" + root_ip + ":50000" # some random port

self.group = MscclppGroup(interfaceIpPortTrio=ifIpPortTrio, rank=rank, size=size)
self.group.barrier()


def make_callback1(self, memory):
self.memory = memory
remote_nghrs = list(range(self.group.nranks))
remote_nghrs.remove(self.group.my_rank)

self.group.barrier()
# create a connection for each remote neighbor
self.connections = self.group.make_connection(remote_nghrs, Transport.CudaIpc)
type_str = ""
if memory.dtype == cp.float16:
type_str = "__half"
elif memory.dtype == cp.float32:
type_str = "float"
elif memory.dtype == cp.int32:
type_str = "int"
else:
raise RuntimeError("Unknown data type")

# create a sm_channel for each remote neighbor
self.sm_channels = self.group.make_sm_channels(self.memory, self.connections)
file_dir = os.path.dirname(os.path.abspath(__file__))
self.kernel = KernelBuilder(file="allreduce1.cu", kernel_name="allreduce1", file_dir=file_dir, macro_dict={"TYPE": type_str}).get_compiled_kernel()
self.params = b""
self.device_handles = []
for rank in range(self.group.nranks):
if rank != self.group.my_rank:
self.device_handles.append(self.sm_channels[rank].device_handle().raw)
self.params += pack(cp.asarray(memoryview(b"".join(self.device_handles)), dtype=cp.uint8), self.memory, self.group.my_rank, self.group.nranks, self.memory.size)

def _make_call(stream_ptr):
self.kernel.launch_kernel(self.params, 24, 1024, 0, stream_ptr)

return _make_call
21 changes: 21 additions & 0 deletions python/benchmark/nccl_op.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import cupy.cuda.nccl as nccl
from mpi4py import MPI

class NcclOp:
def __init__(self):
comm = MPI.COMM_WORLD
rank = comm.Get_rank()
size = comm.Get_size()

# Create a NCCL unique ID and communicator
if rank == 0:
uid = nccl.get_unique_id()
else:
uid = None
uid = comm.bcast(uid, root=0)
self.nccl_comm = nccl.NcclCommunicator(size, uid, rank)

def make_callback(self, memory):
def _make_callback(stream_ptr):
self.nccl_comm.allReduce(memory.data.ptr, memory.data.ptr, memory.size, nccl.NCCL_FLOAT32, nccl.NCCL_SUM, stream_ptr)
return _make_callback
4 changes: 2 additions & 2 deletions python/test/mscclpp_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,8 @@

class MscclppGroup:
def __init__(self, mpi_group: MpiGroup = None, interfaceIpPortTrio : str = "", rank : int = None, size : int = None):
print("QQQQQQ", interfaceIpPortTrio)
if interfaceIpPortTrio == "":
self.bootstrap = TcpBootstrap.create(rank, size)
self.bootstrap = TcpBootstrap.create(mpi_group.comm.rank, mpi_group.comm.size)
uniq_id = None
if mpi_group.comm.rank == 0:
# similar to NCCL's unique id
Expand All @@ -43,6 +42,7 @@ def __init__(self, mpi_group: MpiGroup = None, interfaceIpPortTrio : str = "", r
self.bootstrap = TcpBootstrap.create(mpi_group.comm.rank, mpi_group.comm.size)
self.bootstrap.initialize(interfaceIpPortTrio)
elif not interfaceIpPortTrio == "":
assert rank >= 0 and size >= 1
self.bootstrap = TcpBootstrap.create(rank, size)
self.bootstrap.initialize(interfaceIpPortTrio)
else:
Expand Down

0 comments on commit 8e02b03

Please sign in to comment.