本文最后更新于:5 个月前
从softmax变形说起
safe softmax(3-pass)
x∈RN即x是一个长度为N的向量,对它进行softmax即有如下公式:
softmax({x1,...,xN})={∑j=1Nexjexi}i=1N
我们知道指数函数它的值随着x的增长是爆炸性增长的,当xi大于某个值就会爆float,那么为了将它约束在fp32,即得到的值是安全的,没有上溢,可以减去一个m(x)=max({x1,...,xN}),那么显然我们的xi−m(x)≤0,则取到的值都是安全的.
那么整个softmax过程,可以由三个循环来表示:
- 遍历[1,N],求m(x),其中m(x)=max{xi};
- 遍历[1,N],求l(x),其中l(x)=∑i=1Nexi−m(x)
- 遍历[1,N],求softmax(x),其中softmax(x)=[l(x)exi−m(x),...,l(x)exN−m(x)]
online softmax(2-pass)
上面的safe softmax使用了3次循环,即需要访问gmem3次,如果可以融合中间的计算,就可以减少对gmem的访问次数,可以采用以下思路融合前两个计算:
- 遍历[1,j],其中mj(x)=max(mj−1(x),xj) ,其中j≤N;
- 遍历[1,j],其中lj(x)=lj−1(x)∗emj−1(x)−mj(x)+exj−mj(x)
上述式子可以进行融合,因此可以在1个循环内完成m(x)和l(x)的计算
注: 上述式子中的lj(x)的计算其实把它展开来看就很好理解,因为加上一个上一轮的最大值(就是把它消掉),减去这一轮更新的最大值
上述的式子实际上就实现了分段的softmax,即softmax所要处理的一整个向量,不需要计算完全,可以逐段计算,然后利用上述思想进行合并
FlashAttentionV1
1-pass Attention
Attention计算的公式忽略对QKT做scaled,以及忽略mask,可以得出简化后的Attention公式:
Attention=softmax(QKT)V
而MHA(Multi Head Attention)只是在这个基础上将输入的hidden_dimension切成了num_head份的head_dimension,即内部的注意力机制的计算都是一样的,只是切成了多份进行并行计算,这里我们只考虑一个头的Attention的优化,因为别的操作一样~
其中Attention中式子产生的中间阵S=QKT和P=softmax(S)这俩阵分别代表的是pre-softmax logits和注意力得分阵,最终期望得到的是O=PV.其中这里给定Q,K,V,O∈RN×d,N表示的是sequence_length,d表示的是head_dimension
根据online softmax的思路,可以描述出一个下图的2-pass的公式来计算Attention:

其中上图中的di′其实就是上述的li(x),其中的xi就是Q阵的第k行与KT的第i行计算得到的元素,这个元素进行softmax后,作为标量ai与V阵第i行进行相乘,累加到一个向量元素o上,最终这个值累加完后存放到结果O阵的第k行,这里需要解释的点其实就是aiV[i,:]表示的含义(其实就是Pk∈R1×N,与V∈RN×d进行相乘得到的Ok=PkV拆分后得到的结果),如下图所示,当i这个变量遍历完整个sequence_length则底下OT中的浅蓝块会变成深蓝块,即计算完毕,体现在公式则是oi的更新到oN:
这个2pass的Attention其实第二个循环中的值依赖于第一重循环中得到的mN和dN′,由此得出ai,即对xi做softmax后,与V[i,:]中对应的第i行相乘,从而更新oi
如果是以ai作为最后所求的值,则无法将2pass融合至1pass,但是我们所求的是O[k,:]=oN,因此可以试着把中间不断迭代的oi展开来看,有:
oi=j=1∑i(dN′exj−mNV[j,:])
将它分段来看,从而舍去整一行向量得到的mN,dN′,我们取到第i个元素的值为mi,累和为di′,则原式子改写成:
oi′=j=1∑i(di′exj−miV[j,:])
经过裂项(裂出第i项,以寻找递推式)并整理可得:
oi′=oi−1′di′di−1′emi−1−mi+di′exi−miV[i,:]
上面的递推式可以这么理解:将第i-1项得到的结果,将其softmax的原分母di−1′约去,更新内部每个元素减去的最大值,然后除于最新的累和值di′,完成前[1,i-1]的更新;并且加上此次计算得到的第i项的结果,得出[1,i]项的结果oi′
根据上述递推式,可以使得Attention只需要1pass.过程如下:

Tiled Attention
这里的分块,是将QKVO进行分块,它们初始的shape都是RN×d,想要在1个block内的smem放下,这里假设了smem的大小是M,通过Bc,Br的约束使得共享内存被最大程度的利用(/4d是因为四个阵,它们都沿着N的方向分块,可以使得):
Kj+Vj<2M
Qi+Oi=d∗min(4dM,d)<2M
这样划分以尽可能多地利用共享内存
在代码实现中,由于传入的Q阵的tensor是有四个维度的数据:Tensor(b,nh,N,d)
分别是batch_size,num_heads,sequence_length,head_dimension,
FlashAttentionV2
相较于V1的更新:
- 公式修改,以利用Tensor Core;
- 循环顺序调转,变成先切Q后切K,V
FlashAttentionV3
只有阅读,没有hopper
参考文件: