CUDA 经典问题:前缀和 | CUDA

本文最后更新于:2022年1月31日

问题

对于数组 a,其前缀和为数组 b,a 和 b 的长度均为 n。对于任意 i < n 都满足 b[i] = a[0] + a[1] + ... + a[i]

基本思路

前缀和的思路如下:

  1. 将整个数据分成几个部分,每个部分分别计算前缀和,存入数组 output 中,然后将每个部分中最大的值存入一个数组 part 中
  2. 对上述数组 part 求前缀和
  3. 将 part 中的元素分别加到 output 中

代码实现

Baseline

__global__ void ScanPart(int *input, int *part, int *output, int n, int part_num) {
  for (int part_i = blockIdx.x; part_i < part_num; part_i += gridDim.x) {
    int part_begin = part_i * blockDim.x;
    int part_end = min((part_i + 1) * blockDim.x, n);
    if (threadIdx.x == 0) {
      int acc = 0;
      for (int i = part_begin; i < part_end; ++i) {
        acc += input[i];
        output[i] = acc;
      }
      part[part_i] = acc;
    }
  }
}

__global__ void ScanPartSum(int *part, int part_num) {
  int acc = 0;
  for (int i = 0; i < part_num; ++i) {
    acc += part[i];
    part[i] = acc;
  }
}

__global__ void AddPartSum(int *part, int *output, int n, int part_num) {
  for (int part_i = blockIdx.x; part_i < part_num; part_i += gridDim.x) {
    if (part_i == 0) {
      continue;
    }
    int tid = part_i * blockDim.x + threadIdx.x;
    if (tid < n) {
      output[tid] += part[part_i - 1];
    }
  }
}

void PrefixSum(int *input, int *part, int *output, int n) {
  int part_num = (n + BLOCK_SIZE - 1) / BLOCK_SIZE;
  int block_num = std::min<int>(part_num, 128);

  ScanPart<<<block_num, BLOCK_SIZE>>>(input, part, output, n, part_num);
  ScanPartSum<<<1, 1>>>(part, part_num);
  AddPartSum<<<block_num, BLOCK_SIZE>>>(part, output, n, part_num);
}

3482 us

使用共享内存

首先使用共享内存优化下数据的读取和写入:

__device__ void ScanBlock(int *shm) {
  if (threadIdx.x == 0) {
    int acc = 0;
    for (int i = 0; i < blockDim.x; ++i) {
      acc += shm[i];
      shm[i] = acc;
    }
  }
}

__global__ void ScanPart(int *input, int *part, int *output, int n, int part_num) {
  __shared__ int shm[BLOCK_SIZE];
  for (int part_i = blockIdx.x; part_i < part_num; part_i += gridDim.x) {
    int tid = blockDim.x * blockIdx.x + threadIdx.x;
    shm[threadIdx.x] = tid < n ? input[tid] : 0;
    __syncthreads();
    ScanBlock(shm);
    __syncthreads();
    if (tid < n) {
      output[tid] = shm[threadIdx.x];
    }
    if (threadIdx.x == blockDim.x - 1) {
      part[part_i] = shm[threadIdx.x];
    }
  }
}

3482 us -> 610 us (-82%)

拆分至线程束级别

前面的实现将整个数组的 scan 拆分成每个 block 的 scan,这里还可以进行进一步的拆分:将 block 的 scan 拆分成 warp 的 scan。

__device__ void ScanWarp(int *shm_data, int lane_id) {
  if (lane_id == 0) {
    int acc = 0;
    for (int i = 0; i < 32; ++i) {
      acc += shm_data[i];
      shm_data[i] = acc;
    }
  }
}

__device__ void ScanBlock(int *shm_data) {
  int warp_id = threadIdx.x / 32;
  int lane_id = threadIdx.x % 32;
  __shared__ int warp_sum[32]; // blockDim.x(1024) / warp_size(32) = 32

  // 每个 warp 内部做 scan
  ScanWarp(shm_data, lane_id);
  __syncthreads();

  // 将每个 warp 的和存储到 warp_sum 中
  // lane_id 为 31 的线程对应的共享内存槽位中存放的是其 warp 的和
  if (lane_id == 31) {
    warp_sum[warp_id] = *shm_data;
  }
  __syncthreads();

  // 启动一个单独的 warp 对 warp_sum 进行 scan
  if (warp_id == 0) {
    ScanWarp(warp_sum, lane_id);
  }
  __syncthreads();

  // 每个 warp 将最终结果加上上一个 warp 对应的 warp_sum
  if (warp_id > 0) {
    *shm_data += warp_sum[warp_id - 1];
  }
  __syncthreads();
}

__global__ void ScanPart(int *input, int *part, int *output, int n, int part_num) {
  // 这里额外申请了 32 个 int 的空间是给 ScanBlock 中的 warp_sum 预留的
  __shared__ int shm[32 + BLOCK_SIZE];
  for (int part_i = blockIdx.x; part_i < part_num; part_i += gridDim.x) {
    int tid = blockDim.x * blockIdx.x + threadIdx.x;
    shm[32 + threadIdx.x] = tid < n ? input[tid] : 0;
    __syncthreads();
    ScanBlock(32 + shm + threadIdx.x);
    __syncthreads();
    if (tid < n) {
      output[tid] = shm[32 + threadIdx.x];
    }
    if (threadIdx.x == blockDim.x - 1) {
      part[part_i] = shm[32 + threadIdx.x];
    }
  }
}

610 us -> 351 us (-42%)

优化线程束级别 scan

为了方便解释算法,这里假设对 16 个数做 scan,如下所示:

横向的 16 个点代表 16 个数,时间轴从上往下,每个入度为 2 的节点会做加法,并将结果广播到其输出节点,对于 32 个数的代码如下:

__device__ void ScanWarp(int *shm_data, int lane_id) {
  volatile int *vshm_data = shm_data;
  if (lane_id >= 1) {
    vshm_data[0] += vshm_data[-1];
  }
  __syncwarp();
  if (lane_id >= 2) {
    vshm_data[0] += vshm_data[-2];
  }
  __syncwarp();
  if (lane_id >= 4) {
    vshm_data[0] += vshm_data[-4];
  }
  __syncwarp();
  if (lane_id >= 8) {
    vshm_data[0] += vshm_data[-8];
  }
  __syncwarp();
  if (lane_id >= 16) {
    vshm_data[0] += vshm_data[-16];
  }
  __syncwarp();
}

__device__ void ScanBlock(int *shm_data) {
  int warp_id = threadIdx.x / 32;
  int lane_id = threadIdx.x % 32;
  __shared__ int warp_sum[32];

  ScanWarp(shm_data, lane_id);
  __syncthreads();

  if (lane_id == 31) {
    warp_sum[warp_id] = *shm_data;
  }
  __syncthreads();

  // 这里与上一节不同,原因在于 ScanWarp 的用法与原来不同了,ScanWarp 要求每个线程输入的 shm_data 必须是不同的
  // 因为每个线程在 ScanWarp 中都是用 vshm_data[0] 来访问各自的数据
  // 前面那次 ScanWarp 的调用相较上一节没有改动是因为每个线程的 ScanBlock 的入参本身就是不同的
  if (warp_id == 0) {
    ScanWarp(warp_sum + lane_id, lane_id);
  }
  __syncthreads();

  if (warp_id > 0) {
    *shm_data += warp_sum[warp_id - 1];
  }
  __syncthreads();
}

351 us -> 193 us (-45%)

Zero Padding

如果要更进一步消除 ScanWarp 中的条件分,warp 中所有线程都执行同样的操作,这就意味着之前不符合条件的线程会访问越界,需要做 zero padding 使其不越界:每个 warp 需要一个 16 大小的 zero padding 才能避免 ScanWarp 在没有分支的情况下不越界。

之前需要申请共享内存的大小为 BLOCK_SIZE + 32,这里多出的 32 前面也有解释(用于存放每个 warp 的和),所以之前申请共享内存的大小也可以表示为 (warp_num + 1) * 32。由于每个 warp 额外需要 16 大小的共享内存,所以最终需要申请的共享内存大小为 (warp_num + 1) * (16 + 32)。

这里需要做两件事情:

  • 申请共享内存时多申请 zero padding 的部分:
  • 补 0 以消除 ScanWarp 中的条件分支
__device__ void ScanWarp(int *shm_data) {
  volatile int *vshm_data = shm_data;
  vshm_data[0] += vshm_data[-1];
  vshm_data[0] += vshm_data[-2];
  vshm_data[0] += vshm_data[-4];
  vshm_data[0] += vshm_data[-8];
  vshm_data[0] += vshm_data[-16];
}

__device__ void ScanBlock(int *shm_data) {
  int warp_id = threadIdx.x / 32;
  int lane_id = threadIdx.x % 32;
  extern __shared__ int warp_sum[];

  ScanWarp(shm_data);
  __syncthreads();

  if (lane_id == 31) {
    warp_sum[16 + warp_id] = *shm_data;
  }
  __syncthreads();

  if (warp_id == 0) {
    ScanWarp(16 + warp_sum + lane_id);
  }
  __syncthreads();

  if (warp_id > 0) {
    *shm_data += warp_sum[16 + warp_id - 1];
  }
  __syncthreads();
}

__global__ void ScanPart(int *input, int *part, int *output, int n, int part_num) {
  int warp_id = threadIdx.x / 32;
  int lane_id = threadIdx.x % 32;
  extern __shared__ int shm[];

  if (threadIdx.x < 16) {
    shm[threadIdx.x] = 0;
  }
  if (lane_id < 16) {
    shm[(warp_id + 1) * (16 + 32) + lane_id] = 0;
  }
  __syncthreads();

  for (int part_i = blockIdx.x; part_i < part_num; part_i += gridDim.x) {
    int tid = blockDim.x * blockIdx.x + threadIdx.x;
    int *myshm = shm + (warp_id + 1) * (16 + 32) + (16 + lane_id);
    *myshm = tid < n ? input[tid] : 0;
    __syncthreads();
    ScanBlock(myshm);
    __syncthreads();
    if (tid < n) {
      output[tid] = *myshm;
    }
    if (threadIdx.x == blockDim.x - 1) {
      part[part_i] = *myshm;
    }
  }
}

void PrefixSum(int *input, int *part, int *output, int n) {
  int part_num = (n + BLOCK_SIZE - 1) / BLOCK_SIZE;
  int block_num = std::min<int>(part_num, 128);
  int warp_num = BLOCK_SIZE / 32;
   
  int shm_size = (warp_num + 1) * (16 + 32) * sizeof(int);
  ScanPart<<<block_num, BLOCK_SIZE, shm_size>>>(input, part, output, n, part_num);
  ScanPartSum<<<1, 1>>>(part, part_num);
  AddPartSum<<<block_num, BLOCK_SIZE>>>(part, output, n, part_num);
}

193 us -> 192 us (-1%)

这一节的优化看似不大,主要是被瓶颈掩盖了。

递归

当前瓶颈在于,ScanPartSum 是由一个线程去做的,这块可以递归地做:


void PrefixSum(int *input, int *part, int *output, int n) {
  int part_num = (n + BLOCK_SIZE - 1) / BLOCK_SIZE;
  int block_num = std::min<int>(part_num, 128);
  int warp_num = BLOCK_SIZE / 32;
    
  int shm_size = (warp_num + 1) * (16 + 32) * sizeof(int);
  ScanPart<<<block_num, BLOCK_SIZE, shm_size>>>(input, part, output, n, part_num);
  if (part_num >= 2) {
    PrefixSum(part, part + part_num, part, part_num);
    AddPartSum<<<block_num, BLOCK_SIZE>>>(part, output, n, part_num);
  }
}

192 us -> 162 us (-16%)

使用 Warp Shuffle

__device__ int ScanWarp(int val) {
  int lane_id = threadIdx.x % 32;
  int tmp = __shfl_up_sync(0xffffffff, val, 1);
  if (lane_id >= 1) {
    val += tmp;
  }
  tmp = __shfl_up_sync(0xffffffff, val, 2);
  if (lane_id >= 2) {
    val += tmp;
  }
  tmp = __shfl_up_sync(0xffffffff, val, 4);
  if (lane_id >= 4) {
    val += tmp;
  }
  tmp = __shfl_up_sync(0xffffffff, val, 8);
  if (lane_id >= 8) {
    val += tmp;
  }
  tmp = __shfl_up_sync(0xffffffff, val, 16);
  if (lane_id >= 16) {
    val += tmp;
  }
  return val;
}

__device__ __forceinline__ int ScanBlock(int val) {
  int warp_id = threadIdx.x / 32;
  int lane_id = threadIdx.x % 32;
  extern __shared__ int warp_sum[];

  val = ScanWarp(val);
  __syncthreads();

  if (lane_id == 31) {
    warp_sum[warp_id] = val;
  }
  __syncthreads();

  if (warp_id == 0) {
    warp_sum[lane_id] = ScanWarp(warp_sum[lane_id]);
  }
  __syncthreads();

  if (warp_id > 0) {
    val += warp_sum[warp_id - 1];
  }
  __syncthreads();
  return val;
}

__global__ void ScanPart(int *input, int *part, int *output, int n, int part_num) {
  for (size_t part_i = blockIdx.x; part_i < part_num; part_i += gridDim.x) {
    size_t tid = part_i * blockDim.x + threadIdx.x;
    int32_t val = tid < n ? input[tid] : 0;
    val = ScanBlock(val);
    __syncthreads();
    if (tid < n) {
      output[tid] = val;
    }
    if (threadIdx.x == blockDim.x - 1) {
      part[part_i] = val;
    }
  }
}

void PrefixSum(int *input, int *part, int *output, int n) {
  int part_num = (n + BLOCK_SIZE - 1) / BLOCK_SIZE;
  int block_num = std::min<int>(part_num, 128);

  int shm_size = 32 * sizeof(int);
  ScanPart<<<block_num, BLOCK_SIZE, shm_size>>>(input, part, output, n, part_num);
  if (part_num >= 2) {
    PrefixSum(part, part + part_num, part, part_num);
    AddPartSum<<<block_num, BLOCK_SIZE>>>(part, output, n, part_num);
  }
}

162 us -> 120 us (-26%)

小结

以上便是前缀和的一种解题思路。除此之外还有另一种:先计算每个 part 的和,在最后一步做 scan。具体内容见参考链接。

参考

CUDA 高性能计算经典问题(二)—— 前缀和(Prefix Sum)