Skip to content

Commit

Permalink
Merge branch 'main' into saemal/debugging_info
Browse files Browse the repository at this point in the history
  • Loading branch information
chhwang authored Oct 13, 2023
2 parents 8abf822 + 148681b commit 706f91c
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 3 deletions.
6 changes: 5 additions & 1 deletion python/test/test_mscclpp.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def test_connection_write(mpi_group: MpiGroup, transport: Transport, nelem: int)

@parametrize_mpi_groups(2, 4, 8, 16)
@pytest.mark.parametrize("transport", ["IB", "NVLink"])
@pytest.mark.parametrize("nelem", [2**i for i in [10, 15, 20]])
@pytest.mark.parametrize("nelem", [2**i for i in [10, 15, 20, 27]])
@pytest.mark.parametrize("device", ["cuda", "cpu"])
def test_connection_write_and_signal(mpi_group: MpiGroup, transport: Transport, nelem: int, device: str):
# this test starts with a random tensor on rank 0 and rotates it all the way through all ranks
Expand All @@ -139,6 +139,8 @@ def test_connection_write_and_signal(mpi_group: MpiGroup, transport: Transport,
memory_expected = memory.copy()
else:
memory = xp.zeros(nelem, dtype=xp.float32)
if device == "cuda":
cp.cuda.runtime.deviceSynchronize()

signal_memory = xp.zeros(1, dtype=xp.int64)
all_reg_memories = group.register_tensor_with_connections(memory, connections)
Expand All @@ -156,6 +158,8 @@ def test_connection_write_and_signal(mpi_group: MpiGroup, transport: Transport,
connections[next_rank].flush()
if group.my_rank == 0:
memory[:] = 0
if device == "cuda":
cp.cuda.runtime.deviceSynchronize()
connections[next_rank].update_and_sync(
all_signal_memories[next_rank], 0, dummy_memory_on_cpu.ctypes.data, signal_val
)
Expand Down
4 changes: 2 additions & 2 deletions src/connection.cc
Original file line number Diff line number Diff line change
Expand Up @@ -168,8 +168,8 @@ void IBConnection::flush(int64_t timeoutUsec) {
}

auto elapsed = timer.elapsed();
if ((timeoutUsec >= 0) && (elapsed * 1e3 > timeoutUsec)) {
throw Error("pollCq is stuck: waited for " + std::to_string(elapsed / 1e3) + " seconds. Expected " +
if ((timeoutUsec >= 0) && (elapsed > timeoutUsec)) {
throw Error("pollCq is stuck: waited for " + std::to_string(elapsed / 1e6) + " seconds. Expected " +
std::to_string(numSignaledSends) + " signals",
ErrorCode::InternalError);
}
Expand Down

0 comments on commit 706f91c

Please sign in to comment.