CUDA 实践:矩阵乘法 | CUDA

本文介绍了使用 CUDA 实现矩阵乘法的思路与优化方法。

概述

关于矩阵乘法,相关的文章比较多,这里推荐这个系列:

本文也是参考自上述三篇文章。

代码实现

Baseline

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
__global__ void matrixMultiplyKernel(float *A, float *B, float *C, int m, int n, int k) {
// dim3 block(SIZE, SIZE);
// dim3 grid(n / SIZE, m / SIZE);

int idx_x = blockIdx.x * blockDim.x + threadIdx.x;
int idx_y = blockIdx.y * blockDim.y + threadIdx.y;

float sum = 0.f;
if (idx_x < n && idx_y < m) {
for (int kk = 0; kk < k; ++kk) {
sum += A[idx_y * k + kk] * B[kk * n + idx_x];
}
C[idx_y * n + idx_x] = sum;
}
}

使用 shared memory

将每个 block 需要计算的数据先存放到 shared memory 中,减少对 global memory 的访存次数。

在这段代码中,每个 SIZE x SIZE 的 block 负责计算 C 中 SIZE x SIZE 的块,在外层循环的每次迭代中,首先将 A 和 B 中对应的块拷贝到共享内存,然后基于共享内存中的数据进行矩阵乘法,然后进行下一次迭代并将每次迭代的结果累加,最终得到 C 中对应的一块。在每次循环中,共享内存都会被更新。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
__global__ void matrixMultiplyKernel(float *A, float *B, float *C, int m, int n, int k) {
// dim3 block(SIZE, SIZE);
// dim3 grid(n / SIZE, m / SIZE);

__shared__ float s_a[SIZE][SIZE];
__shared__ float s_b[SIZE][SIZE];

int idx_x = blockIdx.x * blockDim.x + threadIdx.x;
int idx_y = blockIdx.y * blockDim.y + threadIdx.y;

float sum = 0.0;
for (int bk = 0; bk < k; bk += SIZE) {
s_a[threadIdx.y][threadIdx.x] = A[idx_y * k + (bk + threadIdx.x)];
s_b[threadIdx.y][threadIdx.x] = B[(bk + threadIdx.y) * n + idx_x];
__syncthreads();

for (int i = 0; i < SIZE; ++i) {
sum += s_a[threadIdx.y][i] * s_b[i][threadIdx.x];
}
__syncthreads();
}

if (idx_x < n && idx_y < m) {
C[idx_y * n + idx_x] = sum;
}
}

每个线程处理多个数据

新的分块方式

在上面的计算中,每个线程负责一个输出矩阵元素的计算,接下来的实现将使每个线程处理多个数据,这样可以大大提升计算访存比。足够大的计算访存比能提升计算单元的利用率,并能起到隐藏访存延迟的作用。本节使用了新的矩阵分块方式,分块的大小不再和 block size 相同,并且对共享内存也进行分块。

对于 bm、bn、bk、rm、rn 这几个参数,这里取 bm=128、bn=128、bk=8、rm=8、rn=8,这几个参数的选取逻辑可以参考 CUDA 矩阵乘法终极优化指南。当这几个参数选定之后先来直观地感受一下这几个参数意义:假定给了三个矩阵 A、B、C,其维度都是 2048x2048。要求 C=AxB。那么我们需要开启 (2048/128)x(2048/128)=256 个 block,每个 block 里面有 (128/8)x(128/8)=256 个线程,每个线程需要负责计算 C 矩阵中 8x8=64 个元素的结果,每个 block 负责 256×64=16384 个元素的结果。

总的来说,对于一个 block 而言,有 256 个大迭代,每个大迭代中又有 8 个小迭代。

数据搬运

一个线程负责多个元素的计算,同样需要负责多个元素的搬运(从全局内存到共享内存),数据搬运的示意图如下:

在这个例子中,一共使用了 256 个线程,在一次大迭代中需要将 128*8 个元素搬运到共享内存中。用下列几个参数说明搬运的逻辑:

  • A_TILE_THREAD_PER_ROW 代表搬运一行数据需要使用多少个线程,为了搬运 A 的一行,需要使用 2 个线程
  • A_TILE_ROW_START 代表在数据块 As 中,当前线程需要搬运的元素数据块的竖向坐标,A_TILE_COL 代表需要搬运的数据的横向坐标。以 3 号线程为例,它负责搬运坐标为 (1,1) 的数据块中的 4 个元素,所以 A_TILE_ROW_START 是 1,A_TILE_COL 是 4
  • A_TILE_ROW_STRIDE 代表在进行多次搬运时需要跨越的行。对于 256*8 的数据块 As,使用 256 个线程进行搬运,一次搬运 4 个元素数,所以要搬运两次。对于 3 号线程而言,负责搬运图中的两个绿色数据块

使用数据预取(double buffer)

代码示例

可以使用 #pragma unroll 宏展开循环,注意只对循环起止和循环步长在编译期就确定的循环有效。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
__global__ void matrixMultiplyKernel(float *A, float *B, float *C, int m, int n, int k) {
// dim3 block(THREAD_X_PER_BLOCK, THREAD_Y_PER_BLOCK);
// dim3 grid(n / BLOCK_SIZE_N, m / BLOCK_SIZE_M);

__shared__ float s_a[BLOCK_SIZE_M][BLOCK_SIZE_K];
__shared__ float s_b[BLOCK_SIZE_K][BLOCK_SIZE_N];
float r_c[THREAD_SIZE_Y][THREAD_SIZE_X] = {0};

const int tid = threadIdx.y * THREAD_X_PER_BLOCK + threadIdx.x;

// 每个线程一次只搬运一个数据

// 在 s_a/s_b 中,当前线程需要搬运的第一个数据的横纵坐标
const int A_TILE_ROW = tid / BLOCK_SIZE_K;
const int A_TILE_COL = tid % BLOCK_SIZE_K;
const int B_TILE_ROW = tid / BLOCK_SIZE_N;
const int B_TILE_COL = tid % BLOCK_SIZE_N;

// 在进行多次搬运时需要跨越的行
const int A_TILE_ROW_STRIDE = THREAD_NUM_PER_BLOCK / BLOCK_SIZE_K;
const int B_TILE_ROW_STRIDE = THREAD_NUM_PER_BLOCK / BLOCK_SIZE_N;

for (int bk = 0; bk < k; bk += BLOCK_SIZE_K) {
// load A from global memory to shared memory
#pragma unroll
for (int i = 0; i < BLOCK_SIZE_M; i += A_TILE_ROW_STRIDE) {
const int row = BLOCK_SIZE_M * blockIdx.y + i + A_TILE_ROW;
const int col = bk + A_TILE_COL;
if (blockIdx.x == gridDim.x - 1 || blockIdx.y == gridDim.y - 1) {
s_a[i + A_TILE_ROW][A_TILE_COL] = row < m && col < k ? A[row * k + col] : 0;
} else {
s_a[i + A_TILE_ROW][A_TILE_COL] = A[row * k + col];
}
}

// load B from global memory to shared memory
#pragma unroll
for (int i = 0; i < BLOCK_SIZE_K; i += B_TILE_ROW_STRIDE) {
const int row = bk + i + B_TILE_ROW;
const int col = BLOCK_SIZE_N * blockIdx.x + B_TILE_COL;
if (blockIdx.x == gridDim.x - 1 || blockIdx.y == gridDim.y - 1) {
s_b[i + B_TILE_ROW][B_TILE_COL] = row < k && col < n ? B[row * n + col] : 0;
} else {
s_b[i + B_TILE_ROW][B_TILE_COL] = B[row * n + col];
}
}

__syncthreads();

// 每个线程负责搬运的数据和接下来要计算的数据没有必然联系

// calculate C
#pragma unroll
for (int kk = 0; kk < BLOCK_SIZE_K; ++kk) {
#pragma unroll
for (int ty = 0; ty < THREAD_SIZE_Y; ++ty) {
#pragma unroll
for (int tx = 0; tx < THREAD_SIZE_X; ++tx) {
r_c[ty][tx] += s_a[THREAD_SIZE_Y * threadIdx.y + ty][kk] * s_b[kk][THREAD_SIZE_X * threadIdx.x + tx];
}
}
}

__syncthreads();
}

// store back to C
#pragma unroll
for (int ty = 0; ty < THREAD_SIZE_Y; ++ty) {
#pragma unroll
for (int tx = 0; tx < THREAD_SIZE_X; ++tx) {
const int row = BLOCK_SIZE_M * blockIdx.y + THREAD_SIZE_Y * threadIdx.y + ty;
const int col = BLOCK_SIZE_N * blockIdx.x + THREAD_SIZE_X * threadIdx.x + tx;
if (blockIdx.x == gridDim.x - 1 || blockIdx.y == gridDim.y - 1) {
if (row < m && col < n) {
C[row * n + col] += r_c[ty][tx];
}
} else {
C[row * n + col] += r_c[ty][tx];
}
}
}
}

利用向量化指令

可以尝试引导编译器使用 LDG.128 和 STG.128 指令来加速数据 IO。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
// 通过这种方式可以引导编译器使用 LDG.128 指令
#define FETCH_FLOAT4(p) (reinterpret_cast<float4*>(&(p))[0])

__global__ void matrixMultiplyKernel(float *A, float *B, float *C, int m, int n, int k) {
// dim3 block(THREAD_X_PER_BLOCK, THREAD_Y_PER_BLOCK);
// dim3 grid(n / BLOCK_SIZE_N, m / BLOCK_SIZE_M);

__shared__ float s_a[BLOCK_SIZE_M][BLOCK_SIZE_K];
__shared__ float s_b[BLOCK_SIZE_K][BLOCK_SIZE_N];
float r_c[THREAD_SIZE_Y][THREAD_SIZE_X] = {0};
float frag_a[THREAD_SIZE_Y];
float frag_b[THREAD_SIZE_X];

const int tid = threadIdx.y * THREAD_X_PER_BLOCK + threadIdx.x;

// 每个线程一次搬运四个数据

// 在 s_a/s_b 中,当前线程搬运一行数据需要的线程数
const int A_TILE_THREAD_PER_ROW = BLOCK_SIZE_K / 4;
const int B_TILE_THREAD_PER_ROW = BLOCK_SIZE_N / 4;

// 在 s_a/s_b 中,当前线程需要搬运的第一个数据组中第一个数据(即四个数据的第一个)的的横纵坐标
const int A_TILE_ROW_START = tid / A_TILE_THREAD_PER_ROW;
const int A_TILE_COL = tid % A_TILE_THREAD_PER_ROW * 4;
const int B_TILE_ROW_START = tid / B_TILE_THREAD_PER_ROW;
const int B_TILE_COL = tid % B_TILE_THREAD_PER_ROW * 4;

// 在进行多次搬运时需要跨越的行
const int A_TILE_ROW_STRIDE = THREAD_NUM_PER_BLOCK / A_TILE_THREAD_PER_ROW;
const int B_TILE_ROW_STRIDE = THREAD_NUM_PER_BLOCK / B_TILE_THREAD_PER_ROW;

for (int bk = 0; bk < k; bk += BLOCK_SIZE_K) {
// load A from global memory to shared memory
#pragma unroll
for (int i = 0; i < BLOCK_SIZE_M; i += A_TILE_ROW_STRIDE) {
const int row = BLOCK_SIZE_M * blockIdx.y + i + A_TILE_ROW_START;
const int col = bk + A_TILE_COL;
FETCH_FLOAT4(s_a[i + A_TILE_ROW_START][A_TILE_COL]) = FETCH_FLOAT4(A[row * k + col]);
}

// load B from global memory to shared memory
#pragma unroll
for (int i = 0; i < BLOCK_SIZE_K; i += B_TILE_ROW_STRIDE) {
const int row = bk + i + B_TILE_ROW_START;
const int col = BLOCK_SIZE_N * blockIdx.x + B_TILE_COL;
FETCH_FLOAT4(s_b[i + B_TILE_ROW_START][B_TILE_COL]) = FETCH_FLOAT4(B[row * n + col]);
}

__syncthreads();

// 每个线程负责搬运的数据和接下来要计算的数据没有必然联系

// calculate C
#pragma unroll
for (int kk = 0; kk < BLOCK_SIZE_K; ++kk) {
// load A from shared memory to register
#pragma unroll
for (int ty = 0; ty < THREAD_SIZE_Y; ++ty) {
frag_a[ty] = s_a[THREAD_SIZE_Y * threadIdx.y + ty][kk];
}

// load B from shared memory to register
#pragma unroll
for (int tx = 0; tx < THREAD_SIZE_X; tx += 4) {
FETCH_FLOAT4(frag_b[tx]) = FETCH_FLOAT4(s_b[kk][THREAD_SIZE_X * threadIdx.x + tx]);
}

#pragma unroll
for (int ty = 0; ty < THREAD_SIZE_Y; ++ty) {
#pragma unroll
for (int tx = 0; tx < THREAD_SIZE_X; ++tx) {
r_c[ty][tx] += frag_a[ty] * frag_b[tx];
}
}
}
}

// store back to C
#pragma unroll
for (int ty = 0; ty < THREAD_SIZE_Y; ++ty) {
#pragma unroll
for (int tx = 0; tx < THREAD_SIZE_X; tx += 4) {
const int row = BLOCK_SIZE_M * blockIdx.y + THREAD_SIZE_Y * threadIdx.y + ty;
const int col = BLOCK_SIZE_N * blockIdx.x + THREAD_SIZE_X * threadIdx.x + tx;
FETCH_FLOAT4(C[row * n + col]) = FETCH_FLOAT4(r_c[ty][tx]);
}
}
}

使用 double buffer

使用 double buffer 可以进一步掩盖访存延迟。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
__global__ void matrixMultiplyKernel(float * A, float * B, float * C, int m, int n, int k) {
// dim3 block(THREAD_X_PER_BLOCK, THREAD_Y_PER_BLOCK);
// dim3 grid(n / BLOCK_SIZE_N, m / BLOCK_SIZE_M);

__shared__ float s_a[2][BLOCK_SIZE_K][BLOCK_SIZE_M];
__shared__ float s_b[2][BLOCK_SIZE_K][BLOCK_SIZE_N];
float r_c[THREAD_SIZE_Y][THREAD_SIZE_X] = {0};
float frag_a[2][THREAD_SIZE_Y];
float frag_b[2][THREAD_SIZE_X];

// 为了存储 BLOCK_SIZE_M * BLOCK_SIZE_K 的数据块,每个线程需要额外开启 ldg_a_reg 个寄存器进行存储
float ldg_a_reg[BLOCK_SIZE_M * BLOCK_SIZE_K / THREAD_NUM_PER_BLOCK];
float ldg_b_reg[BLOCK_SIZE_K * BLOCK_SIZE_N / THREAD_NUM_PER_BLOCK];

const int tid = threadIdx.y * THREAD_X_PER_BLOCK + threadIdx.x;

// 每个线程一次搬运四个数据

// 在 s_a/s_b 中,当前线程搬运一行数据需要的线程数
const int A_TILE_THREAD_PER_ROW = BLOCK_SIZE_K / 4;
const int B_TILE_THREAD_PER_ROW = BLOCK_SIZE_N / 4;

// 在 s_a/s_b 中,当前线程需要搬运的第一个数据组中第一个数据(即四个数据的第一个)的的横纵坐标
const int A_TILE_ROW_START = tid / A_TILE_THREAD_PER_ROW;
const int A_TILE_COL = tid % A_TILE_THREAD_PER_ROW * 4;
const int B_TILE_ROW_START = tid / B_TILE_THREAD_PER_ROW;
const int B_TILE_COL = tid % B_TILE_THREAD_PER_ROW * 4;

// 在进行多次搬运时需要跨越的行
const int A_TILE_ROW_STRIDE = THREAD_NUM_PER_BLOCK / A_TILE_THREAD_PER_ROW;
const int B_TILE_ROW_STRIDE = THREAD_NUM_PER_BLOCK / B_TILE_THREAD_PER_ROW;

// preload A from global memory to shared memory
#pragma unroll
for (int i = 0; i < BLOCK_SIZE_M; i += A_TILE_ROW_STRIDE) {
int ldg_index = i / A_TILE_ROW_STRIDE * 4;
const int row = BLOCK_SIZE_M * blockIdx.y + i + A_TILE_ROW_START;
const int col = A_TILE_COL;
FETCH_FLOAT4(ldg_a_reg[ldg_index]) = FETCH_FLOAT4(A[row * k + col]);
s_a[0][A_TILE_COL + 0][i + A_TILE_ROW_START] = ldg_a_reg[ldg_index + 0];
s_a[0][A_TILE_COL + 1][i + A_TILE_ROW_START] = ldg_a_reg[ldg_index + 1];
s_a[0][A_TILE_COL + 2][i + A_TILE_ROW_START] = ldg_a_reg[ldg_index + 2];
s_a[0][A_TILE_COL + 3][i + A_TILE_ROW_START] = ldg_a_reg[ldg_index + 3];
}

// preload B from global memory to shared memory
#pragma unroll
for (int i = 0; i < BLOCK_SIZE_K; i += B_TILE_ROW_STRIDE) {
const int row = i + B_TILE_ROW_START;
const int col = BLOCK_SIZE_N * blockIdx.x + B_TILE_COL;
FETCH_FLOAT4(s_b[0][i + B_TILE_ROW_START][B_TILE_COL]) = FETCH_FLOAT4(B[row * n + col]);
}

__syncthreads();

// preload A from shared memory to register
#pragma unroll
for (int ty = 0; ty < THREAD_SIZE_Y; ty += 4) {
FETCH_FLOAT4(frag_a[0][ty]) = FETCH_FLOAT4(s_a[0][0][THREAD_SIZE_Y * threadIdx.y + ty]);
}

// preload B from shared memory to register
#pragma unroll
for (int tx = 0; tx < THREAD_SIZE_X; tx += 4) {
FETCH_FLOAT4(frag_b[0][tx]) = FETCH_FLOAT4(s_b[0][0][THREAD_SIZE_X * threadIdx.x + tx]);
}

int write_stage_idx = 1;
int bk = 0;
do {
bk += BLOCK_SIZE_K;

if (bk < k) {
// preload A from global memory to register
#pragma unroll
for (int i = 0; i < BLOCK_SIZE_M; i += A_TILE_ROW_STRIDE) {
int ldg_index = i / A_TILE_ROW_STRIDE * 4;
const int row = BLOCK_SIZE_M * blockIdx.y + i + A_TILE_ROW_START;
const int col = bk + A_TILE_COL;
FETCH_FLOAT4(ldg_a_reg[ldg_index]) = FETCH_FLOAT4(A[row * k + col]);
}

// preload B from global memory to register
#pragma unroll
for (int i = 0; i < BLOCK_SIZE_K; i += B_TILE_ROW_STRIDE) {
int ldg_index = i / B_TILE_ROW_STRIDE * 4;
const int row = bk + i + B_TILE_ROW_START;
const int col = BLOCK_SIZE_N * blockIdx.x + B_TILE_COL;
FETCH_FLOAT4(ldg_b_reg[ldg_index]) = FETCH_FLOAT4(B[row * n + col]);
}
}

// 每个线程负责搬运的数据和接下来要计算的数据没有必然联系

int load_stage_idx = write_stage_idx ^ 1;

// calculate C
#pragma unroll
for (int kk = 0; kk < BLOCK_SIZE_K - 1; ++kk) {
// preload A from shared memory to register
#pragma unroll
for (int ty = 0; ty < THREAD_SIZE_Y; ty += 4) {
FETCH_FLOAT4(frag_a[(kk + 1) % 2][ty]) = FETCH_FLOAT4(s_a[load_stage_idx][kk + 1][THREAD_SIZE_Y * threadIdx.y + ty]);
}

// preload B from shared memory to register
#pragma unroll
for (int tx = 0; tx < THREAD_SIZE_X; tx += 4) {
FETCH_FLOAT4(frag_b[(kk + 1) % 2][tx]) = FETCH_FLOAT4(s_b[load_stage_idx][kk + 1][THREAD_SIZE_X * threadIdx.x + tx]);
}

// calculate C (this tile)
#pragma unroll
for (int ty = 0; ty < THREAD_SIZE_Y; ++ty) {
#pragma unroll
for (int tx = 0; tx < THREAD_SIZE_X; ++tx) {
r_c[ty][tx] += frag_a[kk % 2][ty] * frag_b[kk % 2][tx];
}
}
}

if (bk < k) {
// preload A from register to shared memory
#pragma unroll
for (int i = 0; i < BLOCK_SIZE_M; i += A_TILE_ROW_STRIDE) {
int ldg_index = i / A_TILE_ROW_STRIDE * 4;
s_a[write_stage_idx][A_TILE_COL + 0][i + A_TILE_ROW_START] = ldg_a_reg[ldg_index + 0];
s_a[write_stage_idx][A_TILE_COL + 1][i + A_TILE_ROW_START] = ldg_a_reg[ldg_index + 1];
s_a[write_stage_idx][A_TILE_COL + 2][i + A_TILE_ROW_START] = ldg_a_reg[ldg_index + 2];
s_a[write_stage_idx][A_TILE_COL + 3][i + A_TILE_ROW_START] = ldg_a_reg[ldg_index + 3];
}

// preload B from register to shared memory
#pragma unroll
for (int i = 0; i < BLOCK_SIZE_K; i += B_TILE_ROW_STRIDE) {
int ldg_index = i / B_TILE_ROW_STRIDE * 4;
FETCH_FLOAT4(s_b[write_stage_idx][B_TILE_ROW_START + i][B_TILE_COL]) = FETCH_FLOAT4(ldg_b_reg[ldg_index]);
}

__syncthreads();
write_stage_idx ^= 1;
}

// preload A from shared memory to register
#pragma unroll
for (int ty = 0; ty < THREAD_SIZE_Y; ty += 4) {
FETCH_FLOAT4(frag_a[0][ty]) = FETCH_FLOAT4(s_a[load_stage_idx ^ 1][0][THREAD_SIZE_Y * threadIdx.y + ty]);
}

// preload B from shared memory to register
#pragma unroll
for (int tx = 0; tx < THREAD_SIZE_X; tx += 4) {
FETCH_FLOAT4(frag_b[0][tx]) = FETCH_FLOAT4(s_b[load_stage_idx ^ 1][0][THREAD_SIZE_X * threadIdx.x + tx]);
}

// compute last tile matmul THREAD_SIZE_X * THREAD_SIZE_Y
#pragma unroll
for (int ty = 0; ty < THREAD_SIZE_Y; ++ty) {
#pragma unroll
for (int tx = 0; tx < THREAD_SIZE_X; ++tx) {
r_c[ty][tx] += frag_a[1][ty] * frag_b[1][tx];
}
}

} while(bk < k);

// store back to C
#pragma unroll
for (int ty = 0; ty < THREAD_SIZE_Y; ++ty) {
#pragma unroll
for (int tx = 0; tx < THREAD_SIZE_X; tx += 4) {
const int row = BLOCK_SIZE_M * blockIdx.y + THREAD_SIZE_Y * threadIdx.y + ty;
const int col = BLOCK_SIZE_N * blockIdx.x + THREAD_SIZE_X * threadIdx.x + tx;
FETCH_FLOAT4(C[row * n + col]) = FETCH_FLOAT4(r_c[ty][tx]);
}
}
}

性能对比

设备信息:NVIDIA Tesla T4, CUDA 11.1

对于 m = n = k = 1024,性能数据如下:

  • baseline:4.05 ms
  • 使用 shared memory:2.40 ms
  • 每个线程处理多个数据
    • without unroll:0.82 ms
    • with unroll:0.65 ms
  • 利用向量化指令:0.57 ms
  • 使用 double buffer:0.54 ms

更多矩阵尺寸性能对比:

完整代码在 zh0ngtian/cuda_learning

TODO

  • 解决 bank conflict
  • 使用 tensor core 优化

参考

深入浅出GPU优化系列:GEMM优化(一)

深入浅出GPU优化系列:GEMM优化(二)

CUDA 实践:矩阵乘法 | CUDA

http://www.zh0ngtian.tech/posts/975c867a.html

作者

zhongtian

发布于

2022-04-01

更新于

2024-01-05

许可协议

评论