From 148681b4bcfb7d91b2e2de788f24b9417f1d1b5b Mon Sep 17 00:00:00 2001 From: Saeed Maleki <30272783+saeedmaleki@users.noreply.github.com> Date: Fri, 13 Oct 2023 01:39:43 -0700 Subject: [PATCH] Fix a pytest bug (#196) --- python/test/test_mscclpp.py | 6 +++++- src/connection.cc | 4 ++-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/python/test/test_mscclpp.py b/python/test/test_mscclpp.py index 6674f4ea0..3af1580a4 100644 --- a/python/test/test_mscclpp.py +++ b/python/test/test_mscclpp.py @@ -123,7 +123,7 @@ def test_connection_write(mpi_group: MpiGroup, transport: Transport, nelem: int) @parametrize_mpi_groups(2, 4, 8, 16) @pytest.mark.parametrize("transport", ["IB", "NVLink"]) -@pytest.mark.parametrize("nelem", [2**i for i in [10, 15, 20]]) +@pytest.mark.parametrize("nelem", [2**i for i in [10, 15, 20, 27]]) @pytest.mark.parametrize("device", ["cuda", "cpu"]) def test_connection_write_and_signal(mpi_group: MpiGroup, transport: Transport, nelem: int, device: str): # this test starts with a random tensor on rank 0 and rotates it all the way through all ranks @@ -139,6 +139,8 @@ def test_connection_write_and_signal(mpi_group: MpiGroup, transport: Transport, memory_expected = memory.copy() else: memory = xp.zeros(nelem, dtype=xp.float32) + if device == "cuda": + cp.cuda.runtime.deviceSynchronize() signal_memory = xp.zeros(1, dtype=xp.int64) all_reg_memories = group.register_tensor_with_connections(memory, connections) @@ -156,6 +158,8 @@ def test_connection_write_and_signal(mpi_group: MpiGroup, transport: Transport, connections[next_rank].flush() if group.my_rank == 0: memory[:] = 0 + if device == "cuda": + cp.cuda.runtime.deviceSynchronize() connections[next_rank].update_and_sync( all_signal_memories[next_rank], 0, dummy_memory_on_cpu.ctypes.data, signal_val ) diff --git a/src/connection.cc b/src/connection.cc index ae9e760fc..284e542ef 100644 --- a/src/connection.cc +++ b/src/connection.cc @@ -163,8 +163,8 @@ void IBConnection::flush(int64_t timeoutUsec) { } auto elapsed = timer.elapsed(); - if ((timeoutUsec >= 0) && (elapsed * 1e3 > timeoutUsec)) { - throw Error("pollCq is stuck: waited for " + std::to_string(elapsed / 1e3) + " seconds. Expected " + + if ((timeoutUsec >= 0) && (elapsed > timeoutUsec)) { + throw Error("pollCq is stuck: waited for " + std::to_string(elapsed / 1e6) + " seconds. Expected " + std::to_string(numSignaledSends) + " signals", ErrorCode::InternalError); }