[Bug] Allgather with proxy channel hangs at H2D cudaMemcpyAsync #440

Open
cubele opened this issue Jan 5, 2025 · 0 comments

cubele commented Jan 5, 2025

We implemented a simple intra-node allgather algorithm using mscclpp-lang with proxy channels. However, when we run the generated JSON algorithm file through the NCCL interface by setting MSCCL_NCCL_PLAN_DIR, mscclpp reports the following error:

/include/mscclpp/semaphore_device.hpp:30: void mscclpp::Host2DeviceSemaphoreDeviceHandle::wait(signed long): block: [0,0,0], thread: [1,0,0] Assertion (atomicLoad(inboundSemaphoreId, memoryOrderAcquire) < (*expectedInboundSemaphoreId)) failed.

We suspect this is the same error as #394 and #285, where the H2D cudaMemcpyAsync hangs and the semaphore is never signaled; however, neither issue ended with a working solution. Do you have any thoughts on this?
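For context, here is a minimal, self-contained sketch (illustrative only, not mscclpp's actual code; waitOnFlag and the stream setup are ours) of the host-to-device handshake we believe is stalling: the device side spins on a counter until the host raises it with an H2D cudaMemcpyAsync, which is essentially what Host2DeviceSemaphoreDeviceHandle::wait does with its acquire load. If the copy issued by the proxy thread never completes, the kernel spins forever, matching the assertion and hang we observe:

// Illustrative sketch only -- not mscclpp code. A kernel spin-waits on a device
// flag that the host raises via an H2D cudaMemcpyAsync. If that copy never
// lands (our suspicion for the hang), the kernel never returns.
#include <cstdint>
#include <cstdio>
#include <cuda_runtime.h>

__global__ void waitOnFlag(volatile uint64_t* flag, uint64_t expected) {
  // Device-side analogue of the semaphore wait: spin until the host-written
  // value reaches the expected count.
  while (*flag < expected) {
  }
}

int main() {
  uint64_t* devFlag;
  cudaMalloc(&devFlag, sizeof(uint64_t));
  cudaMemset(devFlag, 0, sizeof(uint64_t));

  uint64_t* hostFlag;
  cudaMallocHost(&hostFlag, sizeof(uint64_t));  // pinned, so the copy is truly asynchronous
  *hostFlag = 1;

  cudaStream_t kernelStream, copyStream;
  cudaStreamCreate(&kernelStream);
  cudaStreamCreate(&copyStream);

  // The kernel starts waiting before the host signals.
  waitOnFlag<<<1, 1, 0, kernelStream>>>(devFlag, 1);

  // Host-side "signal": the H2D copy the proxy thread would issue. If this copy
  // is blocked (e.g. serialized behind the spinning kernel), the program hangs
  // exactly like our allgather.
  cudaMemcpyAsync(devFlag, hostFlag, sizeof(uint64_t), cudaMemcpyHostToDevice, copyStream);

  cudaStreamSynchronize(kernelStream);
  printf("kernel observed the signal\n");

  cudaFreeHost(hostFlag);
  cudaFree(devFlag);
  cudaStreamDestroy(kernelStream);
  cudaStreamDestroy(copyStream);
  return 0;
}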

The algorithm JSON file is generated with the following mscclpp-lang program:

import argparse
from msccl.language import *
from msccl.topologies import *
from msccl.language.collectives import AllGather

def ring_allgather(gpus, instances, inplace=False):
    size = gpus
    topology = fully_connected(size)
    collective = AllGather(size, 1, inplace)

    with MSCCLPPProgram(
        f"allgather_ring_proxy_n={size}_i={instances}_inp={inplace}",
        topology,
        collective,
        instances,
        protocol="Simple",
        replication_policy=ReplicationPolicy.interleaved,
    ):
        # Chunk i
        for i in range(size):
            for step in range(size - 1):
                send_rank = (i + step) % size
                recv_rank = (i + step + 1) % size

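                # Step 0 sends the rank's own chunk from the input buffer;
                # later steps forward the chunk already received into the output buffer.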
                c = chunk(send_rank, Buffer.input, 0) if step == 0 else chunk(send_rank, Buffer.output, i)
                c.put(
                    recv_rank,
                    Buffer.output,
                    i,
                    sendtb=0,
                    chan_type=ChannelType.proxy,
                )
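                # Signal the receiver over the same proxy channel and flush the
                # sender-side proxy queue; the receiver blocks in wait() until
                # the signal arrives.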
                c.signal(recv_rank, Buffer.output, i, sendtb=0, chan_type=ChannelType.proxy)
                c.flush(recv_rank, Buffer.output, i, sendtb=0, chan_type=ChannelType.proxy)
                cr = chunk(recv_rank, Buffer.output, i)
                cr.wait(send_rank, Buffer.input, 0, recvtb=0, chan_type=ChannelType.proxy)

        if not inplace:
            for i in range(size):
                c = chunk(i, Buffer.input, 0)
                c.copy(i, Buffer.output, i, sendtb=0)

        Json()
        Check()

parser = argparse.ArgumentParser()
parser.add_argument("num_gpus", type=int, help="number of gpus")
parser.add_argument("instances", type=int, help="number of instances")
parser.add_argument("--inplace", action="store_true", help="inplace reducescatter")
args = parser.parse_args()

ring_allgather(args.num_gpus, args.instances, args.inplace)

AllGather is then called on a non-default stream, using the following example code adapted from NCCL:

#include <stdio.h>
#include "cuda_runtime.h"
#include "nccl.h"
#include "mpi.h"
#include <unistd.h>
#include <stdint.h>
#include <stdlib.h>

#define MPICHECK(cmd) do {                          \
  int e = cmd;                                      \
  if( e != MPI_SUCCESS ) {                          \
    printf("Failed: MPI error %s:%d '%d'\n",        \
        __FILE__,__LINE__, e);   \
    exit(EXIT_FAILURE);                             \
  }                                                 \
} while(0)

#define CUDACHECK(cmd) do {                         \
  cudaError_t e = cmd;                              \
  if( e != cudaSuccess ) {                          \
    printf("Failed: Cuda error %s:%d '%s'\n",             \
        __FILE__,__LINE__,cudaGetErrorString(e));   \
    exit(EXIT_FAILURE);                             \
  }                                                 \
} while(0)

#define NCCLCHECK(cmd) do {                         \
  ncclResult_t r = cmd;                             \
  if (r!= ncclSuccess) {                            \
    printf("Failed, NCCL error %s:%d '%s'\n",             \
        __FILE__,__LINE__,ncclGetErrorString(r));   \
    exit(EXIT_FAILURE);                             \
  }                                                 \
} while(0)

static uint64_t getHostHash(const char* string) {
  // Based on DJB2a, result = result * 33 ^ char
  uint64_t result = 5381;
  for (int c = 0; string[c] != '\0'; c++){
    result = ((result << 5) + result) ^ string[c];
  }
  return result;
}

static void getHostName(char* hostname, int maxlen) {
  gethostname(hostname, maxlen);
  for (int i=0; i< maxlen; i++) {
    if (hostname[i] == '.') {
        hostname[i] = '\0';
        return;
    }
  }
}

int main(int argc, char* argv[])
{
  int size = 32*1024*1024; // Size of the buffer each rank will send
  int myRank, nRanks, localRank = 0;

  //initializing MPI
  MPICHECK(MPI_Init(&argc, &argv));
  MPICHECK(MPI_Comm_rank(MPI_COMM_WORLD, &myRank));
  MPICHECK(MPI_Comm_size(MPI_COMM_WORLD, &nRanks));

  //calculating localRank based on hostname which is used in selecting a GPU
  uint64_t hostHashs[nRanks];
  char hostname[1024];
  getHostName(hostname, 1024);
  hostHashs[myRank] = getHostHash(hostname);
  MPICHECK(MPI_Allgather(MPI_IN_PLACE, 0, MPI_DATATYPE_NULL, hostHashs, sizeof(uint64_t), MPI_BYTE, MPI_COMM_WORLD));
  for (int p=0; p<nRanks; p++) {
     if (p == myRank) break;
     if (hostHashs[p] == hostHashs[myRank]) localRank++;
  }

  ncclUniqueId id;
  ncclComm_t comm;
  float *sendbuff, *recvbuff;
  cudaStream_t s;

  //get NCCL unique ID at rank 0 and broadcast it to all others
  if (myRank == 0) ncclGetUniqueId(&id);
  MPICHECK(MPI_Bcast((void *)&id, sizeof(id), MPI_BYTE, 0, MPI_COMM_WORLD));

  //picking a GPU based on localRank, allocate device buffers
  CUDACHECK(cudaSetDevice(localRank));
  CUDACHECK(cudaMalloc(&sendbuff, size * sizeof(float)));
  CUDACHECK(cudaMalloc(&recvbuff, size * nRanks * sizeof(float))); // Allocate for allgather
  CUDACHECK(cudaStreamCreate(&s));

  //initializing NCCL
  NCCLCHECK(ncclCommInitRank(&comm, nRanks, id, myRank));

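  // Each rank contributes `size` floats; recvbuff must hold size * nRanks floats.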
  NCCLCHECK(ncclAllGather((const void*)sendbuff, (void*)recvbuff, size, ncclFloat, comm, s));

  CUDACHECK(cudaStreamSynchronize(s));

  //free device buffers
  CUDACHECK(cudaFree(sendbuff));
  CUDACHECK(cudaFree(recvbuff));

  //finalizing NCCL
  ncclCommDestroy(comm);

  //finalizing MPI
  MPICHECK(MPI_Finalize());

  printf("[MPI Rank %d] Success \n", myRank);
  return 0;
}