Skip to content

Commit

Permalink
More bindings and better test parametrization
Browse files Browse the repository at this point in the history
  • Loading branch information
olsaarik committed Oct 18, 2023
1 parent e90cace commit b6108df
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 8 deletions.
3 changes: 3 additions & 0 deletions python/mscclpp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@
Transport,
TransportFlags,
version,
get_ib_device_count,
get_ib_device_name,
get_ib_transport_by_device_name,
)

__version__ = version()
Expand Down
4 changes: 4 additions & 0 deletions python/mscclpp/core_py.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,10 @@ void register_core(nb::module_& m) {
.def("connect", &Communicator::connect, nb::arg("remoteRank"), nb::arg("tag"), nb::arg("localConfig"))
.def("remote_rank_of", &Communicator::remoteRankOf)
.def("tag_of", &Communicator::tagOf);

m.def("get_ib_device_count", &getIBDeviceCount);
m.def("get_ib_device_name", &getIBDeviceName, nb::arg("ib_transport"));
m.def("get_ib_transport_by_device_name", &getIBTransportByDeviceName, nb::arg("ib_device_name"));
}

NB_MODULE(_mscclpp, m) {
Expand Down
30 changes: 22 additions & 8 deletions python/test/test_mscclpp.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,27 @@
import netifaces as ni
import pytest

from mscclpp import Fifo, Host2DeviceSemaphore, Host2HostSemaphore, ProxyService, SmDevice2DeviceSemaphore, Transport
from mscclpp import Fifo, Host2DeviceSemaphore, Host2HostSemaphore, ProxyService, SmDevice2DeviceSemaphore, Transport, get_ib_device_count
from ._cpp import _ext
from .mscclpp_group import MscclppGroup
from .mscclpp_mpi import MpiGroup, parametrize_mpi_groups, mpi_group
from .utils import KernelBuilder, pack

ethernet_interface_name = "eth0"

# Reusable skip mark for tests that require an IB device. The device count is
# queried once, when this module is imported, and the mark skips the test if
# that count is zero.
skipif_ib = pytest.mark.skipif(get_ib_device_count() == 0, reason="no IB device")

def parametrize_transport(*transports: str):
    """Build a decorator that parametrizes a test over transport names.

    The decorated test receives each given name through its ``transport``
    argument. The ``"IB"`` entry is wrapped in ``pytest.param`` carrying the
    module-level ``skipif_ib`` mark, so only that parameter case is skipped
    when the mark's condition holds; all other names are passed through
    unchanged.

    Args:
        *transports: Transport names to run the test with (e.g. ``"IB"``,
            ``"NVLink"``).

    Returns:
        A decorator that applies ``pytest.mark.parametrize("transport", ...)``
        to the test function it is given.
    """

    def decorator(func):
        # Only the IB case needs the conditional skip; everything else is
        # handed to parametrize as a plain value.
        params = [
            pytest.param(transport, marks=skipif_ib) if transport == "IB" else transport
            for transport in transports
        ]
        return pytest.mark.parametrize("transport", params)(func)

    return decorator


def all_ranks_on_the_same_node(mpi_group: MpiGroup):
if (ethernet_interface_name in ni.interfaces()) is False:
Expand Down Expand Up @@ -81,13 +94,13 @@ def create_and_connect(mpi_group: MpiGroup, transport: str):


@parametrize_mpi_groups(2, 4, 8, 16)
@pytest.mark.parametrize("transport", ["IB", "NVLink"])
@parametrize_transport("IB", "NVLink")
def test_group_with_connections(mpi_group: MpiGroup, transport: str):
create_and_connect(mpi_group, transport)


@parametrize_mpi_groups(2, 4, 8, 16)
@pytest.mark.parametrize("transport", ["IB", "NVLink"])
@parametrize_transport("IB", "NVLink")
@pytest.mark.parametrize("nelem", [2**i for i in [10, 15, 20]])
def test_connection_write(mpi_group: MpiGroup, transport: Transport, nelem: int):
group, connections = create_and_connect(mpi_group, transport)
Expand Down Expand Up @@ -122,7 +135,7 @@ def test_connection_write(mpi_group: MpiGroup, transport: Transport, nelem: int)


@parametrize_mpi_groups(2, 4, 8, 16)
@pytest.mark.parametrize("transport", ["IB", "NVLink"])
@parametrize_transport("IB", "NVLink")
@pytest.mark.parametrize("nelem", [2**i for i in [10, 15, 20, 27]])
@pytest.mark.parametrize("device", ["cuda", "cpu"])
def test_connection_write_and_signal(mpi_group: MpiGroup, transport: Transport, nelem: int, device: str):
Expand Down Expand Up @@ -174,6 +187,7 @@ def test_connection_write_and_signal(mpi_group: MpiGroup, transport: Transport,


@parametrize_mpi_groups(2, 4, 8, 16)
@skipif_ib
def test_h2h_semaphores(mpi_group: MpiGroup):
group, connections = create_and_connect(mpi_group, "IB")

Expand Down Expand Up @@ -262,7 +276,7 @@ def __call__(self):


@parametrize_mpi_groups(2, 4, 8, 16)
@pytest.mark.parametrize("transport", ["NVLink", "IB"])
@parametrize_transport("NVLink", "IB")
def test_h2d_semaphores(mpi_group: MpiGroup, transport: str):
def signal(semaphores):
for rank in semaphores:
Expand Down Expand Up @@ -295,7 +309,7 @@ def test_d2d_semaphores(mpi_group: MpiGroup):


@parametrize_mpi_groups(2, 4, 8, 16)
@pytest.mark.parametrize("nelem", [2**i for i in [10, 15, 20]])
@pytest.mark.parametrize("nelem", [2**i for i in [10]])
@pytest.mark.parametrize("use_packet", [False, True])
def test_sm_channels(mpi_group: MpiGroup, nelem: int, use_packet: bool):
group, connections = create_and_connect(mpi_group, "NVLink")
Expand Down Expand Up @@ -344,7 +358,7 @@ def test_fifo(

@parametrize_mpi_groups(2, 4, 8, 16)
@pytest.mark.parametrize("nelem", [2**i for i in [10, 15, 20]])
@pytest.mark.parametrize("transport", ["IB", "NVLink"])
@parametrize_transport("IB", "NVLink")
def test_proxy(mpi_group: MpiGroup, nelem: int, transport: str):
group, connections = create_and_connect(mpi_group, transport)

Expand Down Expand Up @@ -393,7 +407,7 @@ def test_proxy(mpi_group: MpiGroup, nelem: int, transport: str):

@parametrize_mpi_groups(2, 4, 8, 16)
@pytest.mark.parametrize("nelem", [2**i for i in [10, 15, 20]])
@pytest.mark.parametrize("transport", ["NVLink", "IB"])
@parametrize_transport("NVLink", "IB")
@pytest.mark.parametrize("use_packet", [False, True])
def test_simple_proxy_channel(mpi_group: MpiGroup, nelem: int, transport: str, use_packet: bool):
group, connections = create_and_connect(mpi_group, transport)
Expand Down

0 comments on commit b6108df

Please sign in to comment.