tensor-core学习

本文最后更新于:2 个月前

tensor core是啥

tensor core是一种专门处理矩阵乘运算的计算单元,可以加速矩阵乘运算,实现混合精度计算,见下图(不难看出这张图就是把数据按原样式搬移到指定地方,然后计算):

Volta架构的tensor core支持fp16和fp32下的混合精度矩阵乘法

  • 可以通过c++的wmma(warp-level matrix multiply accumulate) api调用tensor core,支持如D=AB+CD = AB+CC=AB+CC=AB+C的混合精度矩阵乘法

  • 也可以通过wmma ptx调用tensor core,相关指令如下: (此项废弃,因为wmma的ptx更像是Volta架构初次上新tensor core的尝试,不成熟,后面出的架构都用mma ptx)

    • wmma.load tensor core数据加载指令,将矩阵数据从gmem/smem 加载到寄存器
    • wmma.store
    • wmma.mma
  • 还可以用mma ptx指令调用tensor core:

    • ldmatrix tensor core数据加载指令,支持将矩阵数据从smem加载到寄存器
    • mma

注: 根据别人的逆向工程来看tensor core并没有独立的registerfile,也就是说用的还都是SM上的那个registerfile,跟cuda core用的是同样的寄存器

调用cuda core和tensor core的区别如下:

  • 计算层级: cuda core是thread level; tensor core是warp level;
  • 计算维度: cuda core是一维逐点计算, tensor core是二维逐tile计算(这个可以用C阵的一个元素和一个tile理解)
  • 计算依赖: wmma调tensor core要用fragment这个数据存储类,cuda core不用借助别的啥

实际上,在进行反汇编查看SASS码的时候可以发现,采用的是**LDSM(LoaD Matrix from Shared memory)和HMMA**

下表列出不同微架构下,不同代系的tensor core支持的混合精度运算类型以及单个MMA指令可以处理的矩阵乘大小:

tensor-core-generation-difference

WMMA API

C++调用WMMA的API

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
// fragment类模板
template<typename Use, int m, int n, int k, typename T, typename Layout=void>
class fragment;
// 函数声明
void load_matrix_sync(fragment<...> &a, const T* mptr, unsigned ldm);
void load_matrix_sync(fragment<...> &a, const T* mptr, unsigned ldm, layout_t layout); // layout体现在二维矩阵是行主序列主序wmma::row_major wmma::col_major
void store_matrix_sync(T* mptr, const fragment<...> &a, unsigned ldm, layout_t layout);
void fill_fragment(fragment<...> &a, const T& v);
void mma_sync(fragment<...> &d, const fragment<...> &a, const fragment<...> &b, const fragment<...> &c, bool satf=false);
// 需要包含的头文件及命名空间:
#include <mma.h>
using namespace nvcuda;
// fragment可以定义的对象如下:
wmma::fragment<wmma::matrix_a, WMMA_M, WMMA_N, WMMA_K, half, wmma::col_major> a_frag;
wmma::fragment<wmma::matrix_b, WMMA_M, WMMA_N, WMMA_K, half, wmma::col_major> b_frag;
wmma::fragment<wmma::accumulator, WMMA_M, WMMA_N, WMMA_K, float> acc_frag;
wmma::fragment<wmma::accumulator, WMMA_M, WMMA_N, WMMA_K, float> c_frag;

实验

代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
#include <cuda_runtime.h>
#include <mma.h>
#include <cstdio>
#include <cstdlib>
#include <cuda_fp16.h>
using namespace nvcuda;

#define WMMA_M 16
#define WMMA_N 16
#define WMMA_K 16
#define WARPSIZE 32

#define checkCudaError(func) { \
cudaError_t e = (func); \
if(e != cudaSuccess){ \
printf("CUDA ERROR %s %d : %s\n",__FILE__,__LINE__,cudaGetErrorString(e)); \
exit(-1); \
} \
}

#define ceilOperation(a,b) ((a+b-1) / b)


/*
1. 采用wmma实现朴素的sgemm,跟朴素的sgemm实现思路差不多
不过先前的方法是一个thread处理C中的一个元素,现在是一个warp处理C中的一个tile(16x16的tile,其中A,B的tile的k也是16)
因为wmma是warp level的mma
*/
__global__ void sgemm_wmma_naive(const half* __restrict__ A, const half* __restrict__ B,
half* __restrict__ C,int M,int N,int K){
// 沿着K方向上的A,B的tile需要遍历的次数
const int kCnt = ceilOperation(K,WMMA_K);
const int bx = blockIdx.x;
const int by = blockIdx.y;
// 这里是1个block只有1个warp
const int warpRow = by * WMMA_M;
const int warpCol = bx * WMMA_N;
if(warpRow > M || warpCol > N)
return;
// 用fragment模板类给数据做包装,C,D阵用accumulator,A,B阵有对应的matrix_a和matrix_b
wmma::fragment<wmma::accumulator,WMMA_M,WMMA_N,WMMA_K,half> cFrag;

wmma::fill_fragment(cFrag,0.0);
for(int k=0;k<kCnt;k++){
wmma::fragment<wmma::matrix_a,WMMA_M,WMMA_N,WMMA_K,half,wmma::row_major> aFrag;
wmma::fragment<wmma::matrix_b,WMMA_M,WMMA_N,WMMA_K,half,wmma::col_major> bFrag;

wmma::load_matrix_sync(aFrag,A+warpRow*K+k*WMMA_K,K);
wmma::load_matrix_sync(bFrag,B+warpCol*K+k*WMMA_K,K);

wmma::mma_sync(cFrag,aFrag,bFrag,cFrag);
}
wmma::store_matrix_sync(C+warpRow*N+warpCol,cFrag,N,wmma::mem_row_major);
}

int main(int argc,char** argv){
if(argc<4){
fprintf(stderr,"Please input 4 elements,like: ./xx M N K\n");
exit(-1);
}
int gpu_id = 2;
cudaDeviceProp prop;

if(cudaGetDeviceProperties(&prop,gpu_id) == cudaSuccess){
printf("Using GPU:%d : %s\n",gpu_id,prop.name);
}

cudaSetDevice(gpu_id);

int M = atoi(argv[1]);
int N = atoi(argv[2]);
int K = atoi(argv[3]);

size_t sizeA = sizeof(half)*M*K;
size_t sizeB = sizeof(half)*K*N;
size_t sizeC = sizeof(half)*M*N;

half* hA = (half*)malloc(sizeA);
half* hB = (half*)malloc(sizeB);
half* hC = (half*)malloc(sizeC);

for(int i=0;i<M*K;i++){
hA[i] = i/256.0f; //i/6;
}

for(int i=0;i<K*N;i++){
hB[i] = (i+1.25f)/128.0f; //(i+1.25f)/2;
}
for(int i=0;i<256;i++)
printf("hA[%d]: %f; hB[%d]: %f\n",i,__half2float(hA[i]),i,__half2float(hB[i]));

half* dA,*dB,*dC;
// float* dC;
checkCudaError(cudaMalloc(&dA,sizeA));
checkCudaError(cudaMalloc(&dB,sizeB));
checkCudaError(cudaMalloc(&dC,sizeC));

checkCudaError(cudaMemcpy(dA,hA,sizeA,cudaMemcpyHostToDevice));
checkCudaError(cudaMemcpy(dB,hB,sizeB,cudaMemcpyHostToDevice));

dim3 block(WARPSIZE);
dim3 grid(ceilOperation(N,WMMA_N),ceilOperation(M,WMMA_M));

sgemm_wmma_naive<<<grid,block>>>(dA,dB,dC,M,N,K);
cudaDeviceSynchronize();
checkCudaError(cudaMemcpy(hC,dC,sizeC,cudaMemcpyDeviceToHost));

for(int i=0;i<10;i++)
printf("hC[%d]:[%.6f]\n",i,__half2float(hC[i]));

checkCudaError(cudaFree(dA));
checkCudaError(cudaFree(dB));
checkCudaError(cudaFree(dC));
free(hA);
free(hB);
free(hC);
cudaDeviceReset();
return 0;
}

注意:由于fp16是E5M10,即它的最大值和最小值是±65504,所以处理的时候需要注意一下控制范围,不然很容易INF

这里的代码等价于1个thread处理1个C阵的元素(即block(n,m)且只开1个block),此处采用的是1个warp处理C阵对应的1个tile(1个block只开1个warp,block的数目与N,M,WMMA_M,WMMA_N相关)

这是利用tensor core进行矩阵乘的最朴素代码

如果采用wmma api进一步做矩阵乘的优化,优化步骤其实也跟cuda core做gemm的优化步骤很像:

  1. 分块.1个block计算某个固定的大小,block内的线程数目开在128,256这些数上(warp多好切换遮掩访存).分块还需要利用上各级存储结构,减少对全局内存的访问次数,即将数据移动到smem,移动到reg上,如果GPU架构允许,还可以使用上异步拷贝;

  2. 软流水.实际上就是双缓冲(数据预取/存算重叠),使得计算不用等长时间的取数,计算访存相互遮掩;

  3. L2Cache的局部性.当矩阵很大时,会启动非常多的block,这种情况下SM调度的顺序是沿着执行配置参数中的grid的Z->Y->X这三个维度的循环进行的.以RTX3090为例,共82个SM,如果按照M,N,K = (16384,16384,16384)分块维度BM=128B_M=128,BN=256B_N=256,BK=32B_K=32来看,共有128x64个block,且每个SM只能调度1个block,如果按照grid(64,128)来排block的话,会得到一个非常差劲的性能,如果填补上Z维,如grid(16,128,4)则会好很多.见下图,显然右侧的block调度的方式更优(图文可能没那么符合)

    wmma-gemm-l2cache

下面给出了利用wmma api做gemm的分块步骤,相较于之前的不同在于warp level以及warp内部需要分成多个fragment用于存储对应的数据

wmma-gemm

本部分优化的详细代码可参考:OperatorDev/wmma

那bank conflict呢?这里可以用padding来处理,更优的做法下面会讨论到👀

padding对于smem bank conflict的解决是对给定的smem的数组加"元素",一般这个加的元素数目与相关数据类型的向量化访问大小有关,不然你用上向量化访存,在读取某处时会发生misaligned address的错误;那padding这个东西其实影响大不大?个人感觉与smem所负责的分块大小有关,因为加的元素影响的是smem,如果smem负责的分块较大(怎么感觉在说废话,要移到smem上的肯定充分用满,那不移到smem上的算子,肯定数据移动量相对较小,移动到smem都是负优化),那么padding所浪费的smem其实是较大的一块,会影响到一个SM上并发的block数目,影响到occupancy,那么有些计算资源也没给用上.那这就很浪费了.

MMA PTX

ldmatrix指令

ldmatrix指令用于将数据从smem搬运到reg上,由于在搬运的时候我们为了减少访存指令数,一般都会采用向量化访存即float4,此时则以[T0,T7]为一个memory transaction,当然这里暂时不讨论bank conflict。而一个[T0,T7]中每个线程负责一个16B的数据,即搬运了一整个8x8的小矩阵。搬运得到的小矩阵是由warp内每个线程的R0寄存器保存矩阵的一部分,整个小矩阵由整个warp共同表达,这与朴素的LDS.128指令不同,不是一个线程各做各的,各自搬到各自私有的寄存器上。ldmatrix打破了这种私有的限制,由一整个warp协作式加载这个8x8小矩阵

ldmatrix指令解释如下:

1
2
3
4
5
6
7
8
9
10
ldmatrix.sync.aligned.shape.num{.trans}{.ss}.type r, [p];

.shape = {.m8n8};
.num = {.x1, .x2, .x4};
.ss = {.shared{::cta}};
.type = {.b16};

// trans表示转置,ss表示状态空间
// shape表示1次加载的小矩阵块是8x8的 最多支持x4,即4个8x8的小矩阵块加载
// type表示每个数据元素是16bit

具体的数据搬移过程见下面两个图:

how-to-ldmatrix how-to-ldmatrix2

通过上图不难看出,我们如果想要获得一个16x16的矩阵,我们可以让T0-T31分别对应smem的某一个起始地址,然后通过它们一次性读入16B数据,即8个元素,然后再分配给同一warp其他线程的寄存器.寄存器是32位的,1个寄存器可以放下2个fp16的数据元素.即可得到右图所示的数据排放

实验

代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
#include <cstdio>
#include <cstdint>
#include <cuda_runtime.h>
#include <cuda_fp16.h>

#define SHAPE 8*8*4

#define checkCudaError(func) { \
cudaError_t e = (func); \
if(e!=cudaSuccess){ \
printf("CUDA ERROR %s %d : %s\n",__FILE__,__LINE__,cudaGetErrorString(e)); \
} \
}

#define FLOAT4(pointer) (reinterpret_cast<float4*>(&(pointer))[0])

__global__ void test_ldmatrix(half* dIn,half* dOut){
__shared__ half smemTmp[SHAPE];
int idx = threadIdx.x * 8;
// 把数据从gmem -> smem,一个线程取一行数据,即128bit共8个fp16,也即是4个fp32,可以用float4一次取完
FLOAT4(smemTmp[idx]) = FLOAT4(dIn[idx]);
// printf("[%d]:[%f]\n",threadIdx.x,__half2float(smemTmp[idx])); // smem数据是按要求摆放的

uint32_t reg[4];
// 使用ldmatrix, smem -> reg
__asm__ volatile(
"ldmatrix.sync.aligned.m8n8.x4.b16 {%0,%1,%2,%3},[%4]; \n"
: "=r"(reg[0]),
"=r"(reg[1]),
"=r"(reg[2]),
"=r"(reg[3])
: "l"(&smemTmp[idx])
);

// 把每个线程寄存器得到的东西给它移动到我们的dOut中 reg -> gmem
*(reinterpret_cast<float*>(&dOut[idx])) = *(reinterpret_cast<float*>(&reg[0]));
*(reinterpret_cast<float*>(&dOut[idx+2])) = *(reinterpret_cast<float*>(&reg[1]));
*(reinterpret_cast<float*>(&dOut[idx+4])) = *(reinterpret_cast<float*>(&reg[2]));
*(reinterpret_cast<float*>(&dOut[idx+6])) = *(reinterpret_cast<float*>(&reg[3]));
}


int main(){
// 数组是 fp16,4个8*8的阵,用以查看ldmatrix中.num为.x4的数据排布
half hLoad[SHAPE];
half hStore[SHAPE];
// 初始化data,为[0,255]的整型
for(int i=0;i<SHAPE;i++){
hLoad[i]=i;
}

size_t sizeData = sizeof(half) * (SHAPE);
half* dLoad,*dStore;

checkCudaError(cudaMalloc(&dLoad,sizeData));
checkCudaError(cudaMalloc(&dStore,sizeData));

checkCudaError(cudaMemcpy(dLoad,hLoad,sizeData,cudaMemcpyHostToDevice));
test_ldmatrix<<<1,32>>>(dLoad,dStore);
checkCudaError(cudaMemcpy(hStore,dStore,sizeData,cudaMemcpyDeviceToHost));

for(int i=0;i<32;i++){
printf("T%02d ",i);
for(int j=0;j<8;j+=2){
printf("r%d %12f %12f ",j/2,__half2float(hStore[i*8+j]),__half2float(hStore[i*8+j+1]));
}
printf("\n");
}

checkCudaError(cudaFree(dLoad));
checkCudaError(cudaFree(dStore));
return 0;
}
ldmatrix_no_trans ldmatrix_with_trans

以上实验是对ldmatrix指令的整体认知,主要体现在整个warp协作式加载数据,使得整个矩阵由整个warp的每一个线程的部分寄存器来进行表示.其中加了.trans的个人建议理解先转置,后按照无.trans的情况处理,则可以与mma指令所需的matrixB的数据排布对应上

mma指令

看完上面的ldmatrix指令,你会好奇,为啥要搞得这么复杂呢?

实际上ldmatrix是为了和mma指令搭配使用的,看下mma指令在M16N8K16情况下的1个warp涉及到的C tile和A tile以及B tile的数据在线程上的排布情况,你就知道为啥这么复杂了:

hmma_16816_layoutA hmma_16816_layoutB hmma_16816_layourCorD

很明显,mmaldmatrix指令读取后数据分布到线程的情况是一样的,因此可以很好的搭配起来使用.那么如何结合起来使用呢?

需要注意到,ptx中mma的表示如下(这里贴的是半精度的):

1
2
3
4
5
6
7
8
mma.sync.aligned.m8n8k4.alayout.blayout.dtype.f16.f16.ctype  d, a, b, c;
mma.sync.aligned.m16n8k8.row.col.dtype.f16.f16.ctype d, a, b, c;
mma.sync.aligned.m16n8k16.row.col.dtype.f16.f16.ctype d, a, b, c;

.alayout = {.row, .col};
.blayout = {.row, .col};
.ctype = {.f16, .f32};
.dtype = {.f16, .f32};

现在一般都是matrix_a是row-major,matrix_b是col-major,这也可以从上面的A,B,C阵的layout看出来.以16816为例我们知道对于matrix_a而言它需要先加载一个16x16大小的矩阵到寄存器上,这里以.x4ldmatrix执行操作即可,其中T0-T15取列为0时的行首地址,T16-T31取列为8时的行首地址;对于matrix_b而言它在warp中的线程上的数据排布稍显奇怪,T0-T3中的寄存器RB0们的元素呈现col-major,回看上面ldmatrix的实验可知,就是将matrix_b s2r的时候加了个.trans限定符即可,其中T0-T15取列为0时的行首地址(.x2就行)

以下是一个mma的例子,代码看test_mma:

hmma_matrix_a

hmma_matrix_b

hmma_accumulator

上面这个例子中的MNK取了16816,恰好使用一条mma指令即可完成,源码细节在链接中查看

ldmatrix之bank conflict free

回答上面的问题,那我们这样取数据(LDSM,就是ldmatrix)会有bank conflict嘛?

以下图16x16,16x32,16x64的这么些个阵,给它从gmem放到smem上,然后分别由1,2,4个warp处理,它们分别会发生不同的n-way bank conflict(图中是以LDS.128来描述此时的bank排布).显然我们不能将数据从gmem朴素的移到smem上而不做任何位置上的调动,这样会导致大量的smem bank conflict.

bank conflict-from lds128 to see ldsm

本质上,我们只需要对同一行的新元素们(以fp16来说这个新元素就是由8个连续的fp16元素组成的一个块)做shuffle就可以实现bank conflict free,注意,这只是改变行内的新元素排列的顺序.而这个行内的界定又取决于连续访问新元素的能力,即此时的新bank数目为所谓的一行.

以cutlass来说,所用的手段就是icol = irow ^ icol,我们的每个线程将数据从gmem搬到smem时,每个线程是负责搬运一定的大小,有对应的行列,进行一定的操作(含异或),便可以实现类似下图的搬运(g2s);然后为了利用上LDSM,根据warpId,laneId以及ldmatrix的取数顺序,即laneId % 16做行, laneId / 16散开成列的0,1,这里的0,1是以新元素的大小来描述的,实际要乘上个8,还有根据warpId对不同warp搬运的位置进行定位,即可实现smem bank conflict free的LDSM

bank conflict with swizzle(xor)

以上实验代码详见:OperatorDev/test_swizzle.cu,实验效果分析如下:

朴素情况下,对16x64的4warp调用ldsm,会发生7(单个memory transaction的访问冲突数目)*4(1个warp的memory transaction数)*4(warp数目)*1(block数目)=112次访问冲突

naive_ldsm

经过手动swizzle后,变为了0次:

swizzle_ldsm

ncu中的两个指标

  • wavefront类似于cycle的表述,它代表了处理请求所需的时钟周期数,比如你无smem bank conflict,则我就是1个cycle完成访问,否则因为并行转为串行,n-way bank conflict就会使得你需要n个cycle,而这n个cycle乘上1个warp里的memory transaction数目,再乘上warp数,再乘上block数即是ncu给出的值
  • bank conflict之于ncu,比如说4-way bank conflict,那就是一个memory transaction有3次冲突,单个memory transaction的冲突数乘上一个warp内有的memory transaction数目,再乘上warp数和block数,即是ncu输出的smem bank conflict数目

这个知识点目前存在的一个问题,如何封装成函数模板?想按照cutlass那样封,但自己没想到咋计算对应的关系,只会对给出的用例计算how to achieve bank conflict free,无法更高的抽象化,后续继续学习

参考文件:


本博客所有文章除特别声明外,均采用 CC BY-SA 4.0 协议 ,转载请注明出处!