flashAttention-with-cuda

本文最后更新于:5 个月前

从softmax变形说起

safe softmax(3-pass)

xRNx \in R^Nxx是一个长度为NN的向量,对它进行softmaxsoftmax即有如下公式:

softmax({x1,...,xN})={exij=1Nexj}i=1Nsoftmax(\{x_1,...,x_N\}) = \{\frac{e^{x_i}}{\sum_{j=1}^{N}{e^{x_j}}}\}^{N}_{i=1}

我们知道指数函数它的值随着xx的增长是爆炸性增长的,当xix_i大于某个值就会爆float,那么为了将它约束在fp32,即得到的值是安全的,没有上溢,可以减去一个m(x)=max({x1,...,xN})m(x) = max(\{x_1,...,x_N\}),那么显然我们的xim(x)0x_i-m(x)\le0,则取到的值都是安全的.

那么整个softmax过程,可以由三个循环来表示:

  • 遍历[1,N][1,N],求m(x)m(x),其中m(x)=max{xi}m(x) = \max{\{x_i\}};
  • 遍历[1,N][1,N],求l(x)l(x),其中l(x)=i=1Nexim(x)l(x) = \sum_{i=1}^N{e^{x_i-m(x)}}
  • 遍历[1,N][1,N],求softmax(x)softmax(x),其中softmax(x)=[exim(x)l(x),...,exNm(x)l(x)]softmax(x) = [\frac{e^{x_i-m(x)}}{l(x)},...,\frac{e^{x_N-m(x)}}{l(x)}]

online softmax(2-pass)

上面的safe softmax使用了3次循环,即需要访问gmem3次,如果可以融合中间的计算,就可以减少对gmem的访问次数,可以采用以下思路融合前两个计算:

  • 遍历[1,j][1,j],其中mj(x)=max(mj1(x),xj)m_{j}(x) = \max(m_{j-1}(x),x_j) ,其中jNj\le N;
  • 遍历[1,j][1,j],其中lj(x)=lj1(x)emj1(x)mj(x)+exjmj(x)l_j(x) = l_{j-1}(x)*e^{m_{j-1}(x) - m_j(x)} + e^{x_j - m_j(x)}

上述式子可以进行融合,因此可以在1个循环内完成m(x)m(x)l(x)l(x)的计算

注: 上述式子中的lj(x)l_j(x)的计算其实把它展开来看就很好理解,因为加上一个上一轮的最大值(就是把它消掉),减去这一轮更新的最大值

上述的式子实际上就实现了分段的softmax,即softmax所要处理的一整个向量,不需要计算完全,可以逐段计算,然后利用上述思想进行合并

FlashAttentionV1

1-pass Attention

Attention计算的公式忽略对QKTQK^T做scaled,以及忽略mask,可以得出简化后的Attention公式:

Attention=softmax(QKT)VAttention = softmax(QK^{T})V

而MHA(Multi Head Attention)只是在这个基础上将输入的hidden_dimensionhidden\_dimension切成了num_headnum\_head份的head_dimensionhead\_dimension,即内部的注意力机制的计算都是一样的,只是切成了多份进行并行计算,这里我们只考虑一个头的Attention的优化,因为别的操作一样~

其中Attention中式子产生的中间阵S=QKTS = QK^TP=softmax(S)P = softmax(S)这俩阵分别代表的是pre-softmax logits和注意力得分阵,最终期望得到的是O=PVO = PV.其中这里给定Q,K,V,ORN×dQ,K,V,O \in R^{N \times d},NN表示的是sequence_lengthsequence\_length,dd表示的是head_dimensionhead\_dimension

根据online softmax的思路,可以描述出一个下图的2-pass的公式来计算Attention:

attention-2pass

其中上图中的did_i'其实就是上述的li(x)l_i(x),其中的xix_i就是QQ阵的第kk行与KTK^T的第ii行计算得到的元素,这个元素进行softmax后,作为标量aia_iVV阵第ii行进行相乘,累加到一个向量元素oo上,最终这个值累加完后存放到结果OO阵的第kk行,这里需要解释的点其实就是aiV[i,:]a_iV[i,:]表示的含义(其实就是PkR1×NP_k \in R^{1\times N},与VRN×dV \in R^{N \times d}进行相乘得到的Ok=PkVO_k = P_kV拆分后得到的结果),如下图所示,当i这个变量遍历完整个sequence_lengthsequence\_length则底下OTO^T中的浅蓝块会变成深蓝块,即计算完毕,体现在公式则是oio_i的更新到oNo_N:

attention-2pass-oi

这个2pass的Attention其实第二个循环中的值依赖于第一重循环中得到的mNm_NdNd'_N,由此得出aia_i,即对xix_i做softmax后,与V[i,:]V[i,:]中对应的第i行相乘,从而更新oio_i

如果是以aia_i作为最后所求的值,则无法将2pass融合至1pass,但是我们所求的是O[k,:]=oNO[k,:] = o_N,因此可以试着把中间不断迭代的oio_i展开来看,有:

oi=j=1i(exjmNdNV[j,:])o_i = \sum_{j=1}^{i}{(\frac{e^{x_j-m_N}}{d'_N}V[j,:])}

将它分段来看,从而舍去整一行向量得到的mNm_N,dNd'_N,我们取到第i个元素的值为mim_i,累和为did'_i,则原式子改写成:

oi=j=1i(exjmidiV[j,:])o'_i = \sum_{j=1}^{i}{(\frac{e^{x_j-m_i}}{d'_i}V[j,:])}

经过裂项(裂出第i项,以寻找递推式)并整理可得:

oi=oi1di1emi1midi+eximidiV[i,:]o'_i = o'_{i-1}\frac{d'_{i-1}e^{m_{i-1}-m_i}}{d'_i} + \frac{e^{x_i-m_i}}{d'_i}V[i,:]

上面的递推式可以这么理解:将第i-1项得到的结果,将其softmax的原分母di1d'_{i-1}约去,更新内部每个元素减去的最大值,然后除于最新的累和值did'_i,完成前[1,i-1]的更新;并且加上此次计算得到的第i项的结果,得出[1,i]项的结果oio'_i

根据上述递推式,可以使得Attention只需要1pass.过程如下:

attention-1pass

Tiled Attention

flashattnv1-algo

这里的分块,是将QKVO进行分块,它们初始的shape都是RN×dR^{N \times d},想要在1个block内的smem放下,这里假设了smem的大小是MM,通过BcB_c,BrB_r的约束使得共享内存被最大程度的利用(/4d/4d是因为四个阵,它们都沿着N的方向分块,可以使得):

Kj+Vj<M2K_j + V_j \lt \frac{M}{2}

Qi+Oi=dmin(M4d,d)<M2Q_i + O_i = d * min(\frac{M}{4d},d) < \frac{M}{2}

这样划分以尽可能多地利用共享内存

在代码实现中,由于传入的Q阵的tensor是有四个维度的数据:Tensor(b,nh,N,d)分别是batch_size,num_heads,sequence_length,head_dimension,flashattnv1-split

FlashAttentionV2

相较于V1的更新:

  • 公式修改,以利用Tensor Core;
  • 循环顺序调转,变成先切Q后切K,V

FlashAttentionV3

只有阅读,没有hopper

参考文件:


本博客所有文章除特别声明外,均采用 CC BY-SA 4.0 协议 ,转载请注明出处!