Skip to content

Commit

Permalink
minor updates
Browse files: browse the repository at this point in the history
  • Loading branch information
chhwang authored Oct 16, 2024
1 parent 7a9d657 commit b826c21
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 12 deletions.
17 changes: 7 additions & 10 deletions apps/nccl/src/allreduce.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@ template <>
__forceinline__ __device__ __half2 clip(__half2 val) {
val.x = __hmax(val.x, bit_cast<__half, unsigned short>(0xfbff));
val.x = __hmin(val.x, bit_cast<__half, unsigned short>(0x7bff));

val.y = __hmax(val.y, bit_cast<__half, unsigned short>(0xfbff));
val.y = __hmin(val.y, bit_cast<__half, unsigned short>(0x7bff));
return val;
Expand Down Expand Up @@ -242,19 +241,19 @@ __global__ void __launch_bounds__(1024, 1)
size_t nelems, uint32_t flag) {
// This version of allreduce only works for single nodes
if (worldSize != nRanksPerNode) return;

if (sizeof(T) == 2)
nelems = (nelems * sizeof(T) + sizeof(T)) / sizeof(int);
nelems = (nelems * sizeof(T) + sizeof(T)) / sizeof(int);
else
nelems = nelems / (sizeof(int) / sizeof(T));
nelems = nelems / (sizeof(int) / sizeof(T));

const int nPeers = nRanksPerNode - 1;
const size_t nPkts = nelems/2;
const size_t nPkts = nelems / 2;

int nelemsPerRank = nelems / worldSize;
if ((nelemsPerRank % 2)) nelemsPerRank = (nelemsPerRank * sizeof(T) + sizeof(T)) / sizeof(T);

const int nPktsPerRank = nelemsPerRank/2;
const int nPktsPerRank = nelemsPerRank / 2;
// thread block & channel info
const int nBlocksPerPeer = gridDim.x / nPeers;
const int localBlockIdx = blockIdx.x % nBlocksPerPeer;
Expand Down Expand Up @@ -286,9 +285,7 @@ __global__ void __launch_bounds__(1024, 1)
for (int index = 0; index < NPEERS; index++) {
const int remoteRank = index < rank ? index : index + 1;
mscclpp::LLPacket* dstPkt = (mscclpp::LLPacket*)scratchBuff + remoteRank * nPktsPerRank;
//uint32_t val = dstPkt[idx].read(flag, -1);
uint2 val = dstPkt[idx].read(flag);
//data = add_vectors<T>(val, data);
data.x = add_vectors<T>(val.x, data.x);
data.y = add_vectors<T>(val.y, data.y);
}
Expand Down
2 changes: 0 additions & 2 deletions include/mscclpp/packet_device.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,6 @@ union alignas(16) LL16Packet {
#else // !defined(MSCCLPP_DEVICE_CUDA)
uint4 reg = make_uint4(val1, flag, val2, flag);
ulonglong2* p = reinterpret_cast<ulonglong2*>(&reg);
/*atomicStore(&(raw_.x), p->x, memoryOrderRelaxed);
atomicStore(&(raw_.y), p->y, memoryOrderRelaxed);*/
__builtin_nontemporal_store(p->x, &(raw_.x));
__builtin_nontemporal_store(p->y, &(raw_.y));
#endif
Expand Down

0 comments on commit b826c21

Please sign in to comment.