ViT理解及实验
本文最后更新于:1 年前
引言
ViT(Vision Transformer),是ICLR 2021上的一篇AN IMAGE IS WORTH 16x16 WORDS:TRANSFORMER FOR IMAGE RECOGNITION AT SCALE论文里提及的模型,将transformer应用于image classification上,并在Google的JFT-300M数据集进行预训练后,在ImageNet-1k上做分类达到当时的SOTA!
因此,本文将通过结构结合代码进行介绍,并于文末将采用与预训练好的模型应用于花分类数据集上进行实验.
注:本文采用的是ViT-B/16模型进行解析,B->Base 16->patch size:16*16,输入图片的shape是(224,224,3)
原论文链接:AN IMAGE IS WORTH 16x16 WORDS:TRANSFORMER IMAGE RECOGNITION AT SCALE
花分类数据集下载
ViT整体结构
首先看下ViT的整体结构图:
整个ViT由以下三部分组成:
- Linear Projection of Flattened Patches -> 转换为transformer接受的输入
- Transformer Encoder -> 进入Encoder学习图片的相关性,对应特征
- MLP Head -> 做分类,其实就是全连接层
Linear Projection of Flattened Patches
patch embedding
根据先前介绍的transformer的知识,我们输入的应该是token sequence,如由词向量组成的矩阵,row的个数表示词的个数,column宽度表示词的dimension.
因此输入图片是不符合要求的,我们需要将图片处理成对应的序列才行.将图片切割成一小块一小块的patch,再将patches延展成一维的(获得patch_num),再将每个patch通过线性变换映射到对应的维度(patch_dim),即完成了patch embedding的过程.最终得到的参数为:[patch_num,patch_dim]
其实上述过程,用卷积的思想解释就是kernel=16x16,stride=16,采用768个卷积核进行卷积计算,即可以将[224,224,3]->[14,14,768],然后再通过torch的flatten处理就可以得到[196,768]
注:这里的768并非通过如16*16*3计算出来的,而是作者规定的dimension,如ViT-H/16,其patch_dim是1280.
1 |
|
注: 代码中的类继承了nn.Module,这里面重写了魔法方法__call__,该方法里面调用了forward方法,因此子类重载forward可使得该方法通过类实例化对象如普通方法般调用.
1 |
|
CLS token
ViT中的分类采用的是Bert中的CLS的思想来进行的,因此我们的patchEmbedding需要变换成**[196+1,768]**,即多出一行用于分类,这个参数由网络学习得到.其与patchEmbedding是concat的关系
positional embedding
在transformer里讲过,我们输入的序列缺乏位置信息,因此需要增加positional embedding来使得其位置信息得以保持.在代码中,我们的位置信息是通过模型训练获得的(nn.Parameter()),其与patch embedding是直接add的
encoder block
流程如下图所示:
关于Multi-Head Attention层相关的部分跟transformer的一样,这里不赘述.以下是Multi-Head Attention的代码解析部分:
1 |
|
nn.Linear()就是全连接(如,是输入参数,可以是多维的,是权重矩阵,是偏置向量),对x.shape的最后一个维度最连接操作,也可以将其理解为矩阵乘,其参数由训练迭代更新(源码中将其加入了nn.Parameter()中)
nn.Droupout()防止过拟合,参数表示不被激活的神经元的占比
@->矩阵乘 *->矩阵点乘(又或者说逐向量内积)
MLP block
从上方ViT Encoder Block的图中可以看出,MLP是由两层全连接层和GELU激活函数构成,第一层将其维度变为,第二层将其维度变回768(没搞明白为啥这样变).代码如下:
1 |
|
MLP HEAD
切片,取第0行的向量(CLS),然后做全连接,流程如下图所示
关于上面的Pre-Logits,在ImageNet-1K无需用到,直接设置为None;在ImageNet-21K里用到了,就是一个全连接层+激活函数
ViT执行流程
详见下图,图片源于B站UP主:霹雳吧啦Wz
ViT实验
拿的别人的模型跑的,但是好像用CUDA的地方出了点问题,在调BUG,跑好了会摆上来
ViT与传统CNN的差异
关于二者的差异在Do Vision Transformers See Like Convolutional Neural Networks这篇论文中进行了详细的比对.下面将对其结果和主要使用分析工具进行介绍
原论文连接:Do Vision Transformers See Like Convolutional Neural Networks?
非常粗粒度的看了这篇论文
主要分析工具
CKA(Centered Kernel Alignment),可用于计算神经网络表征(neural network representation)的相似度.可以用于计算同一模型不同层间的表征相似度,或者是不同模型间的层的表征相似度,具体计算公式不细说(主要是我没弄明白),知其定义即可理解该论文对ViT和ResNet间的差异性分析.
ViT和ResNet的比较
同一模型间的不同层级的表征相似性比较
从上面这个图可以很清晰的看出来:ViT模型整体具有高度的表征相似性,而ResNet只是局部具有较高的表征相似性,如低层和高层,且lower layers和higher layers相似性很小
不同模型间的对应层级的表征相似性比较
从上面这个heatmap图可以看出来:ResNet中较多的较低层与ViT中较少的较低层具有相似性,但是在高层间(最高的那一片区域)二者相似性很低
以上对模型自身及模型间的层与层的表征相似性的说明了:
- ViT最高层对于ResNet来说,具有相当不同的表示
- ViT在较低层和较高层间的传播表示更为强烈
- ViT计算较低层表征的方式与ResNet较低层的不同
层表征中的局部和全局信息
这个标题主要的意思就是,不同模型中低层和高层中学习到了的是什么样的信息,局部的或是全局的或是二者皆有
因此这里的层表征主要指代的就是低层和高层的,于ViT而言就是最开始的encoder block和最后一次的encoder block
之于CNN而言,我们知道,它卷积其实受限于相邻的区域,即使stride有所调整,也是局部的(于低层而言,我觉得高层也是,某一块kernel对应的是一只猫的耳朵或是尾巴,其实也是局部信息)
以下是ViT的低层和高层部分的学习到的信息的图(注意,这里的Mean Distance是一种用单个头的注意力权重,就是之前说的Zi(i∈[0,15])来加权pixel distance(我觉得是对应patch的dimension),然后做平均得到的结果.大距离说明是全局信息,小距离说明是局部信息)
显而易见,encoder block在低层时,学到的既有局部信息也有全局信息,而高层的encoder block学习到的都是全局信息(个人认为就是自注意力机制脱离了邻域关注的问题,使得低层次也可以学习到全局的信息)
当然,这只是说明ViT跟CNN学习方式不同,并非说局部信息不好的意思.论文里作者也用了没有pretrain的ViT,其效果很烂之余,也发现其压根没学到啥局部信息,反而印证了前期学习中局部信息的重要性,效果图如下:
那么全局信息有啥用呢,作者通过对encoder block 1和2里面各自的16个head划分成多个子集,子集范围对应着[多数含局部信息的heads,多数含全局信息的heads],用这些子集和低层的ResNet计算CKA,得出下图:
结果显而易见,随着全局信息增多(即平均距离增大),基本上二者CKA单调递减
作者其实没明确给出ViT中encoder block低层且head较小时学到的全局信息有啥用,但我觉得这个局部信息(基于Mean Distance这种度量方式)可能正是源于我们多头机制想要规避开的对于自身所在词的过度关注,而随着后面cocat成一个完整的Z与WO做乘积时,其局部性被削弱(这也跟Multi-Head Attention这个机制有关)
而此刻输出值作为下一个encoder block的输入,其包含了较多的全局信息,这使得我们较低层与较高层构建出一定的相似性,当然这也与skip connection有关(其实这里说的就是skip connection)
接着作者还对有效感受野(ERF)进行了分析,如下图示:
我们知道,卷积的有效感受野受kernel大小以及下采样层影响,因此一开始很小;而自注意力机制使得ViT的有效感受野不受局部信息局限,还多了全局信息,因此有效感受野比较大;
而之后ResNet的ERF以局部扩增的方式增大(高度局部化),而ViT的ERF则是从局部转向全局,且高度依赖于中心的patch,这与skip connection有强烈关系!下图是pre-residual的感受野:
可以看出上图(比较Attention12),可以看出残差连接制约着感受野对于中心patch的依赖性
skip connection在ViT中发挥的作用
根据先前的ViT不同层级做CKA进行相似性比较的图,我们知道了它的表征具有高度一致性,这是由我们这里要讨论的skip connection 发挥的作用
我们通过范数比:来进行探讨,其中是来自于skip connection的第i个层的hidden representation,是经过long branch后的值,这里的long branch指的是MLP或是self-attention
至此,我们知道了,若是范数比比值大则意味着skip connection起主要作用,若是范数比比值小则意味着long branch起主要作用
以下是根据范数比所作的heatmap,需注意CLS token和别的spatial token是分开来讨论的:
根据左图,显而易见,CLS token(token[0])和别的spatial token的受影响方式恰好相反
CLS token是网络前期(Block index小的部分)范数比大,即受skip connection影响大,而网络后期则是受long branch影响大,spatial token则相反
根据右图,除了彰显了上述结论,也可以看出ViT较ResNet受skip connection影响更大
作者又做了个干预性实验来证明skip connection对ViT表征结构高度一致性的影响,即移除中间某一个block的skip connection,图示如下:
可见,若是移除了某一个block的skip connection,那在该block前后的层的表征相似性则非常低.由此佐证了skip connection对ViT层间表征相似性的作用!
ViT在higher layers的空间位置信息是否仍然保留
知道了前面ViT与ResNet的一些差别后,还想知道它的空间信息在较高的层是否仍然保留,这对transformer是否可以干除了图像分类之外的事很重要,如目标检测
我们通过对最后一个block的token与最开始输入的patch token进行比较(计算不同位置的CKA值),然后做heatmap,可以看出它们的相似性,即空间位置信息是否被high layers保留.图示如下:
显然,ViT的空间位置信息被保留下来了,而且所选的单个token与最开始的对应的patch相似性最强,而边缘部分的token也是如此,但其与其他边缘位置相似性也很高.可以看出ViT对空间位置信息有保留!相较之下,ResNet则体现不出来,按作者的说法就是significantly weaker的位置信息
然后作者还对ResNet为啥会位置信息保留得如此薄弱进行了实验,认为是分类所采用的方法导致的,ViT采用的是一个单独的token->CLS token,对原位置信息本就不影响,而ResNet在训练时分类用的是全局平均池化(GAP),把信息都杂糅在一起了,哪里还有原来规整的位置信息
因此,就把ViT里面的CLS token去掉,通过GAP来做分类,结果说明了确实是GAP的原因,图示如下:
本博客所有文章除特别声明外,均采用 CC BY-SA 4.0 协议 ,转载请注明出处!