网络架构设计:CNN based和Transformer based

2023-05-29 0 688

↑ 点选白字 网络架构设计:CNN based和Transformer based译者丨陀飞轮@chan(已许可)撰稿丨极市网络平台

极市编者按

 

责任撰稿主要就导出了CNN based和Transformer based的互联网体系结构,当中CNN based牵涉ResNet和BoTNet,Transformer based牵涉ViT和T2T-ViT。>>重新加入极市CV经验交流群,走在计算机系统听觉的前沿

从DETR到ViT等组织工作都校正了Transformer在计算机系统听觉应用领域的发展潜力,所以很大自然的就须要考量两个捷伊难题,影像的求逆,到底是CNN好却是Transformer好?

怎样有效率的紧密结合global和local重要信息,前段时间的两篇该文主要就分为了三个路径:CNN based和Transformer based。下列主要就导出呵呵CNN based和Transformer based的互联网体系结构,当中CNN based牵涉ResNet和BoTNetTransformer based牵涉ViT和T2T-ViT

互联网体系结构的相关联

网络架构设计:CNN based和Transformer based

BoTNet在ResNet的基础上将Bottlenneck的3×3传递函数换成MHSA,减少CNN based的互联网构架的global重要信息裂解潜能。T2T-ViT在ViT的基础上将patch的linear projection换成T2T,减少Transformer based的互联网构架的local重要信息裂解潜能。

ResNet&BoTNet

网络架构设计:CNN based和Transformer based

ResNet的结构设计,ResNet主要就由Bottleneck结构堆叠而成,一层Bottlenneck由1x1conv、3x3conv和1x1conv堆叠构成残差分支,然后和skip connect分支相加。BoTNet在Bottlenneck结构的基础上将中间的3x3conv换成MHSA结构,跟之间的Non-local等组织工作非常相似,本质上在CNN中引入global重要信息裂解。

网络架构设计:CNN based和Transformer based

MHSA结构如上图所示,代码如下。

class MHSA(nn.Module): def __init__(self, n_dims, width=14, height=14): super(MHSA, self).__init__() self.query = nn.Conv2d(n_dims, n_dims, kernel_size=1) self.key = nn.Conv2d(n_dims, n_dims, kernel_size=1) self.value = nn.Conv2d(n_dims, n_dims, kernel_size=1) self.rel_h = nn.Parameter(torch.randn([1, n_dims, 1, height]), requires_grad=True) self.rel_w = nn.Parameter(torch.randn([1, n_dims, width, 1]), requires_grad=True) self.softmax = nn.Softmax(dim=-1) def forward(self, x): n_batch, C, width, height = x.size() q = self.query(x).view(n_batch, C, –1) k = self.key(x).view(n_batch, C, –1) v = self.value(x).view(n_batch, C, –1)content_content = torch.bmm(q.permute(0, 2, 1), k) content_position = (self.rel_h + self.rel_w).view(1, C, –1).permute(0, 2, 1)content_position = torch.matmul(content_position, q) energy = content_content + content_position attention = self.softmax(energy) out = torch.bmm(v, attention.permute(0, 2, 1)) out = out.view(n_batch, C, width, height)      return out

跟Transformer中的multi-head self-attention非常相似,区别在于MSHA将position encoding当成了spatial attention来处理,嵌入三个可学习的向量看成是横纵三个维度的空间注意力,然后将相加融合后的空间向量于q相乘得到contect-position(相当于是引入了空间先验),将content-position和content

ViT

ViT是第一篇纯粹的将Transformer用于影像特征抽取的该文。

网络架构设计:CNN based和Transformer based

Vision Transformer(ViT)将输入图片拆分为16×16个patches,每个patch做一次线性变换降维同时嵌入位置重要信息,然后送入Transformer。类似BERT[class]标记位的设置,ViT在Transformer输入序列前减少了两个额外可学习的[class]标记位,并且该位置的Transformer Encoder输出作为影像特征。

假设输入图片大小是256×256,打算分为64个patch,每个patch是32×32像素。

x = rearrange(img, b c (h p1) (w p2) -> b (h w) (p1 p2 c), p1=p, p2=p)# 将3072变成dim,假设是1024self.patch_to_embedding = nn.Linear(patch_dim, dim)x = self.patch_to_embedding(x)

这个写法是采用了爱因斯坦表达式,具体是采用了einops库实现,内部集成了各种算子,rearrange就是当中两个,非常高效。p就是patch大小,假设输入是b,3,256,256,则rearrange操作是先变成(b,3,8×32,8×32),最后变成(b,8×8,32x32x3)即(b,64,3072),将每张图片切分为64个小块,每个小块长度是32x32x3=3072,也就是说输入长度为64的影像序列,每个元素采用3072长度进行编码。考量到3072有点大,ViT使用linear projection对影像序列编码进行降维。

T2T-ViT

ViT虽然校正了Transformer在影像分类互联网体系结构的发展潜力,但是须要额外的大规模数据来进行pre-train,而在中等规模数据集如imagenet上效果却不理想。T2T-ViT引入了local的重要信息裂解来增强ViT局部结构建模的潜能,使得T2T-ViT在中等规模imagenet上训练能达到更高的精度。

网络架构设计:CNN based和Transformer based

在T2T模块中,先将输入影像软分割为小块,然后将其展开成两个tokens T0序列。然后tokens的长度在T2T模块中逐步减少(该文中使用两次迭代然后输出Tf)。后续跟ViT基本上一致。

网络架构设计:CNN based和Transformer based

一次迭代T2T结构由re-structurization和soft split构成,re-structurization将一维序列reshape成二维影像, soft split对二维影像进行滑窗操作,拆分为重叠块。

以token transformer为例,先将输入影像拆分为7×7的重叠块,然后通过token transformer,进行块内的global重要信息裂解,然后通过re-structurization和soft split进行token重组和拆分为3×3的重叠块,得到长度更短的token序列,重复迭代两次,最后linear projection进一步降低token序列长度。

class T2T_module(nn.Module): “”” Tokens-to-Token encoding module “”” def __init__(self, img_size=224, in_chans=3, embed_dim=768, token_dim=64): super().__init__()self.soft_split0 = nn.Unfold(kernel_size=(7, 7), stride=(4, 4), padding=(2, 2))self.soft_split1 = nn.Unfold(kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)) self.soft_split2 = nn.Unfold(kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)) self.attention1 = Token_transformer(dim=in_chans * 7 * 7, in_dim=token_dim, num_heads=1, mlp_ratio=1.0) self.attention2 = Token_transformer(dim=token_dim * 3 * 3, in_dim=token_dim, num_heads=1, mlp_ratio=1.0) self.project = nn.Linear(token_dim * 3 * 3, embed_dim)self.num_patches = (img_size // (4 * 2 * 2)) * (img_size // (4 * 2 * 2)) # there are 3 soft split, stride are 4,2,2 seperately def forward(self, x): # step0: soft split x = self.soft_split0(x).transpose(1, 2) # iteration1: restricturization/reconstruction x = self.attention1(x) B, new_HW, C = x.shape x = x.transpose(1,2).reshape(B, C, int(np.sqrt(new_HW)), int(np.sqrt(new_HW))) # iteration1: soft split x = self.soft_split1(x).transpose(1, 2) # iteration2: restricturization/reconstruction x = self.attention2(x) B, new_HW, C = x.shape x = x.transpose(1,2).reshape(B, C, int(np.sqrt(new_HW)), int(np.sqrt(new_HW))) # iteration2: soft splitx = self.soft_split2(x).transpose(1, 2) # final tokens x = self.project(x) return x

总结

1.global和local重要信息裂解的关系

global和local应该相互补充来同时balance 速度和精度,同时提升速度和精度的上限

2.CNN based和Transformer based的关系,CNN based 和 Transformer based哪个好

本质上是互联网体系结构是以CNN为主好却是Transformer为主好的难题,CNN为主却是将输入当成二维的影像信号来处理,Transformer为主则将输入当成一维的序列信号来处理,所以想要研究清楚CNN为主好却是Transformer为主好的难题,须要去探索哪种输入信号更加具有优势,之前不少研究都表明CNN的padding可能透露了位置重要信息,而Transformer因为没有归纳偏见,须要减少position encoding来引入位置重要信息。CNN为主和Transformer为主各有优劣,目前来看暂无定论,且看后续发展。

Reference

[1] Deep Residual Learning for Image Recognition[2] Bottleneck Transformers for Visual Recognition[3] An image is worth 16×16 words: Transformers for image recognition at scale[4] Tokens-to-Token ViT: Training Vision Transformers from Scratch on ImageNet

推荐阅读

传递函数神经互联网与Transformer紧密结合,东南大学提出视频帧合成新构架

2020 神经互联网构架搜索(NAS)最新技术综述

何恺明团队新作!深度学习互联网构架新视角:通过相关图表达理解神经互联网

添加极市小助手微信(ID : cvmart2),备注:姓名-学校/公司-研究路径-城市(如:小极-北大-目标检测-深圳),即可申请重新加入极市目标检测/影像分割/工业检测/人脸/医学影像/3D/SLAM/自动驾驶/超分辨率/姿态估计/ReID/GAN/影像增强/OCR/视频理解等经验交流群:月大咖直播分享、真实项目需求对接、求职内推、算法竞赛、干货资讯汇总、与 10000+来自港科大、北大、清华、中科院、CMU、腾讯、百度等名校名企听觉开发者互动交流~网络架构设计:CNN based和Transformer based△长按添加极市小助手网络架构设计:CNN based和Transformer based最新CV干货觉得有用麻烦给个在看啦~网络架构设计:CNN based和Transformer based

相关文章

发表评论
暂无评论
官方客服团队

为您解决烦忧 - 24小时在线 专业服务