This article records my walk-through of the Swin Transformer paper (together with the PyTorch version of the code); the figures come from the original paper or from other related blogs.
The code is not the original source either, but fragments extracted from it and paired with simple data to make it easier to follow.
Of course, I may also have chosen the toy data badly and created misunderstandings; corrections are welcome.
I have only written part of this so far and am publishing it early; please point out any mistakes.
Figure 1. (a) The proposed Swin Transformer builds hierarchical feature maps by merging image patches (shown in gray) in deeper layers, and has linear computation complexity to input image size due to computation of self-attention only within each local window (shown in red). It can thus serve as a general-purpose backbone for both image classification and dense recognition tasks. (b) In contrast, previous vision Transformers produce feature maps of a single low resolution and have quadratic computation complexity to input image size due to computation of self attention globally.
Model architecture
Figure 3. (a) The architecture of a Swin Transformer (Swin-T); (b) two successive Swin Transformer Blocks (notation presented with Eq. (3)). W-MSA and SW-MSA are multi-head self attention modules with regular and shifted windowing configurations, respectively.
Stage 1 – Patch Embedding
It first splits an input RGB image into non-overlapping patches by a patch splitting module, like ViT.
Each patch is treated as a “token” and its feature is set as a concatenation of the raw pixel RGB values.
In our implementation, we use a patch size of 4×4 and thus the feature dimension of each patch is 4×4×3 = 48 (the input has 3 channels).
A linear embedding layer is applied on this raw-valued feature to project it to an arbitrary dimension (denoted as C). The phrase "linear embedding layer" feels slightly inaccurate to me, but the second half is right: in the code the channel dimension goes from 3 to 96, i.e. the 48-dimensional raw patch feature is projected to 96 dimensions (a quick equivalence check follows after this quoted passage).
Several Transformer blocks with modified self-attention computation (Swin Transformer blocks) are applied on these patch tokens.
The Transformer blocks maintain the number of tokens (H/4 × W/4), and together with the linear embedding are referred to as “Stage 1”.
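On the "linear embedding layer" wording questioned above: in the implementation the patch split and the linear projection are fused into a single 4×4 convolution with stride 4. A minimal sketch (my own check, not part of model.py) showing that this conv is equivalent to flattening each 4×4×3 patch into 48 values and applying one linear layer:

import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(1, 3, 224, 224)

# patch split + linear projection, fused as one strided conv (as in PatchEmbed below)
conv = nn.Conv2d(3, 96, kernel_size=4, stride=4)
out_conv = conv(x).flatten(2).transpose(1, 2)                  # [1, 56*56, 96]

# explicit version: cut out 4x4x3 patches (48 values each) and apply one linear layer
linear = nn.Linear(48, 96)
linear.weight.data = conv.weight.data.reshape(96, -1)          # reuse the conv parameters
linear.bias.data = conv.bias.data
patches = x.unfold(2, 4, 4).unfold(3, 4, 4)                    # [1, 3, 56, 56, 4, 4]
patches = patches.permute(0, 2, 3, 1, 4, 5).reshape(1, 56 * 56, 48)
out_linear = linear(patches)                                   # [1, 56*56, 96]

print(torch.allclose(out_conv, out_linear, atol=1e-5))         # True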
Code
The following code comes from model.py:
import torch
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.nn.functional as F


class PatchEmbed(nn.Module):
    """2D Image to Patch Embedding"""
    def __init__(self, patch_size=4, in_c=3, embed_dim=96, norm_layer=None):
        super().__init__()
        patch_size = (patch_size, patch_size)
        self.patch_size = patch_size
        self.in_chans = in_c
        self.embed_dim = embed_dim
        self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x):
        _, _, H, W = x.shape
        # padding
        # if H or W of the input image is not an integer multiple of patch_size, pad it
        pad_input = (H % self.patch_size[0] != 0) or (W % self.patch_size[1] != 0)
        if pad_input:
            # to pad the last 3 dimensions,
            # (W_left, W_right, H_top, H_bottom, C_front, C_back)
            x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1],
                          0, self.patch_size[0] - H % self.patch_size[0],
                          0, 0))
        # downsample by a factor of patch_size
        x = self.proj(x)
        _, _, H, W = x.shape
        # flatten: [B, C, H, W] -> [B, C, HW]
        # transpose: [B, C, HW] -> [B, HW, C]
        x = x.flatten(2).transpose(1, 2)
        x = self.norm(x)
        print(x.shape)
        # torch.Size([1, 3136, 96])
        # 224/4 * 224/4 = 3136
        return x, H, W


if __name__ == "__main__":
    img_path = "tulips.jpg"
    img = Image.open(img_path)
    plt.imshow(img)
    # [N, C, H, W]
    print(img.size)
    # (500, 375)

    img_size = 224
    data_transform = transforms.Compose(
        [transforms.Resize(int(img_size * 1.14)),
         transforms.CenterCrop(img_size),
         transforms.ToTensor(),
         transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
    img = data_transform(img)
    print(img.shape)
    # torch.Size([3, 224, 224])

    # expand batch dimension
    img = torch.unsqueeze(img, dim=0)
    print(img.shape)
    # torch.Size([1, 3, 224, 224])

    # split image into non-overlapping patches
    patch_embed = PatchEmbed(norm_layer=nn.LayerNorm)
    patch_embed(img)
Stage 2 – 3.2. Shifted Window based Self-Attention
Shifted window partitioning in successive blocks
The window-based self-attention module lacks connections across windows, which limits its modeling power.
To introduce cross-window connections while maintaining the efficient computation of non-overlapping windows, we propose a shifted window partitioning approach which alternates between two partitioning configurations in consecutive Swin Transformer blocks.
Figure 2. In layer l (left), a regular window partitioning scheme is adopted, and self-attention is computed within each window. In the next layer l+1 (right), the window partitioning is shifted, resulting in new windows. The self-attention computation in the new windows crosses the boundaries of the previous windows in layer l, providing connections among them.
Efficient batch computation for shifted configuration
An issue with shifted window partitioning is that it will result in more windows, and some of the windows will be smaller than M×M.
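For example, with the Swin-T numbers used throughout this article (a 56×56 feature map at stage 1 and M = 7): regular partitioning gives ⌈56/7⌉ × ⌈56/7⌉ = 8 × 8 = 64 windows, while the shifted partitioning naively produces (8+1) × (8+1) = 81 windows, some of them smaller than 7×7.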
Here, we propose a more efficient batch computation approach by cyclic-shifting toward the top-left direction, as illustrated in Figure 4.
"More efficient" here means relative to the naive approach of padding + masking.
A naive solution is to pad the smaller windows to a size of M×M and mask out the padded values when computing attention.
Figure 4. Illustration of an efficient batch computation approach for self-attention in shifted window partitioning. After this shift, a batched window may be composed of several sub-windows that are not adjacent in the feature map, so a masking mechanism is employed to limit self-attention computation to within each sub-window.
With the cyclic-shift, the number of batched windows remains the same as that of regular window partitioning, and thus is also efficient.
The figure and description above are not very intuitive, so I pulled in some other references to analyze it: after the shift is complete, region 4 forms a window on its own, regions 5 and 3 form one window, regions 7 and 1 form one window, and regions 8, 6, 2 and 0 form one window.
But regions 5 and 3 are the edges of two different parts of the image; wouldn't mixing them into one attention computation make a mess? Computing them together would actually also work, since ViT computes attention globally anyway.
To avoid this issue, however, Swin Transformer uses masked MSA in the code, so that a mask can isolate the information of different regions from each other.
Concretely, the source code handles this by subtracting 100 from the attention logits at positions that should not interact, i.e. adding -100 before the softmax.
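A tiny toy example (my own, not taken from model.py) of why -100 works: after the softmax, cross-region attention weights become effectively zero.

import torch

# toy logits for 4 tokens: tokens 0 and 1 belong to one sub-window region, tokens 2 and 3 to another
attn = torch.zeros(4, 4)
region = torch.tensor([0, 0, 1, 1])
mask = (region[:, None] != region[None, :]).float() * -100.0  # -100 wherever the regions differ
print((attn + mask).softmax(dim=-1))
# each token now attends (almost) only to tokens from its own region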
Also note that after the shifted windows have been processed, the data has to be restored, i.e. rolled back to its original positions.
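A minimal sketch (again a toy example of mine) of that restore step: the reverse torch.roll with positive shifts puts every element back where it came from.

import torch

shift_size = 3
x = torch.arange(7 * 7).reshape(1, 7, 7, 1).float()  # toy [B, H, W, C] feature map

# cyclic shift toward the top-left (done before SW-MSA)
shifted_x = torch.roll(x, shifts=(-shift_size, -shift_size), dims=(1, 2))
# ... windowed attention would run here ...
# reverse cyclic shift: roll back toward the bottom-right to restore the layout
restored = torch.roll(shifted_x, shifts=(shift_size, shift_size), dims=(1, 2))
print(torch.equal(x, restored))  # True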
Code
The following code comes from model.py:
def window_partition(x, window_size: int):
    """
    Split the feature map into non-overlapping windows of size window_size.
    The main idea is to reshape the feature into (num_windows*B, window_size*window_size, C),
    lining up every window that needs self-attention along dim 0, so that a single
    batched qkv computation covers all of them.
    Args:
        x: (B, H, W, C)
        window_size (int): window size (M)
    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    B, H, W, C = x.shape
    # B, 224, 224, C
    # B, 56, 56, C
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    # B, 32, 7, 32, 7, C
    # B, 8, 7, 8, 7, C
    # permute:
    # [B, H//Mh, Mh, W//Mw, Mw, C] ->
    # [B, H//Mh, W//Mw, Mh, Mw, C]
    # B, 32, 32, 7, 7, C
    # B, 8, 8, 7, 7, C
    # view:
    # [B, H//Mh, W//Mw, Mh, Mw, C] ->
    # [B*num_windows, Mh, Mw, C]
    # B*1024, 7, 7, C
    # B*64, 7, 7, C
    # 32*32 = 1024
    # 224 / 7 = 32
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows

Analysis: [B, C, 56, 56] finally becomes [64B, C, 7, 7]; the original B*C feature maps of size 56*56 turn into B*64*C feature maps of size 7*7.
In other words, we now have 64B samples, each containing C channels of size 7×7.
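A quick shape check of window_partition (assuming the function defined just above is in scope):

import torch

x = torch.randn(1, 56, 56, 96)               # [B, H, W, C] after patch embedding
windows = window_partition(x, window_size=7)
print(windows.shape)
# torch.Size([64, 7, 7, 96])                 # 64 = (56/7) * (56/7) windows per image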
Note that window_size (M = 7) is the size of each window (7×7), not 7×7 windows; I confused the two at first.

class BasicLayer(nn.Module):
    # A basic Swin Transformer layer for one stage.
    def __init__(self, dim, depth, num_heads, window_size,
                 mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0.,
                 drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False):
        super().__init__()
        self.dim = dim
        self.depth = depth
        self.window_size = window_size
        self.use_checkpoint = use_checkpoint
        self.shift_size = window_size // 2
        # 7 // 2 = 3

        # build blocks
        self.blocks = nn.ModuleList([
            SwinTransformerBlock(
                dim=dim,
                num_heads=num_heads,
                window_size=window_size,
                shift_size=0 if (i % 2 == 0) else self.shift_size,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                drop=drop,
                attn_drop=attn_drop,
                drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                norm_layer=norm_layer)
            for i in range(depth)])
        ...

    # depth: 2, 2, 6, 2
    # stage 1: depth = 2, two SwinTransformerBlocks with shift_size 0 and 3
    # stage 2: depth = 2, two SwinTransformerBlocks with shift_size 0 and 3
    # stage 3: depth = 6, six SwinTransformerBlocks with shift_size 0, 3, 0, 3, 0, 3
    # stage 4: depth = 2, two SwinTransformerBlocks with shift_size 0 and 3

    def create_mask(self, x, H, W):
        # calculate attention mask for SW-MSA
        ...

Pulling the create_mask logic out into a standalone demo (with H = W = 7 for simplicity):
import numpy as np
import torch

H = 7
W = 7
window_size = 7
shift_size = 3

Hp = int(np.ceil(H / window_size)) * window_size
Wp = int(np.ceil(W / window_size)) * window_size

# same channel layout as the feature map, which makes the later window_partition convenient
img_mask = torch.zeros((1, Hp, Wp, 1))
# [1, Hp, Wp, 1]
print(img_mask, "\n")

h_slices = (slice(0, -window_size),
            slice(-window_size, -shift_size),
            slice(-shift_size, None))
print(h_slices, "\n")
# (slice(0, -7, None), slice(-7, -3, None), slice(-3, None, None))

w_slices = (slice(0, -window_size),
            slice(-window_size, -shift_size),
            slice(-shift_size, None))
print(w_slices, "\n")
# (slice(0, -7, None), slice(-7, -3, None), slice(-3, None, None))

cnt = 0
for h in h_slices:
    for w in w_slices:
        img_mask[:, h, w, :] = cnt
        cnt += 1
print(img_mask)

The subtraction trick that is later applied to this mask is easier to see on a tiny tensor:

import torch

img_mask = torch.rand((2, 3))
print(img_mask)
# tensor([[0.7410, 0.6020, 0.5195],
#         [0.9214, 0.2777, 0.8418]])

attn_mask = img_mask.unsqueeze(1) - img_mask.unsqueeze(2)
print(attn_mask)
# tensor([[[ 0.0000, -0.1390, -0.2215],
#          [ 0.1390,  0.0000, -0.0825],
#          [ 0.2215,  0.0825,  0.0000]],
#
#         [[ 0.0000, -0.6437, -0.0796],
#          [ 0.6437,  0.0000,  0.5642],
#          [ 0.0796, -0.5642,  0.0000]]])

print(img_mask.unsqueeze(1))
# tensor([[[0.7410, 0.6020, 0.5195]],
#
#         [[0.9214, 0.2777, 0.8418]]])

print(img_mask.unsqueeze(2))
# tensor([[[0.7410],
#          [0.6020],
#          [0.5195]],
#
#         [[0.9214],
#          [0.2777],
#          [0.8418]]])

In the full create_mask, the labeled img_mask above is window-partitioned and turned into the actual attention mask with exactly this broadcasting subtraction, as sketched below.
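Roughly, this is how create_mask in model.py finishes (continuing the demo above; it assumes the img_mask, window_size and the window_partition function from earlier are in scope):

# continue from the img_mask demo above; window_size = 7, shift_size = 3
mask_windows = window_partition(img_mask, window_size)             # [nW, Mh, Mw, 1]
mask_windows = mask_windows.view(-1, window_size * window_size)    # [nW, Mh*Mw]
# token pairs with different region labels get a nonzero difference ...
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)  # [nW, Mh*Mw, Mh*Mw]
# ... which is then replaced by -100 so that softmax suppresses those attention weights
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
print(attn_mask.shape)
# torch.Size([1, 49, 49])   (H = W = 7 gives a single 7x7 window in this toy setup)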
The mask construction above should be read together with the SwinTransformerBlock.forward code below; shift_size is what drives torch.roll().

class SwinTransformerBlock(nn.Module):
    # Swin Transformer Block.
    ...
    def forward(self, x, attn_mask):
        H, W = self.H, self.W
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"

        shortcut = x
        x = self.norm1(x)
        x = x.view(B, H, W, C)

        # pad feature maps to multiples of window size
        pad_l = pad_t = 0
        pad_r = (self.window_size - W % self.window_size) % self.window_size
        pad_b = (self.window_size - H % self.window_size) % self.window_size
        # note that F.pad pads starting from the LAST dimension, e.g. with
        #   x.shape == (b, h, w, c)
        #   x = F.pad(x, (1, 1, 2, 2, 3, 3))
        #   x.shape == (b, h+6, w+4, c+2)
        # so for x laid out as (B, H, W, C) the order used in the source is correct:
        # (0, 0) pads C, (pad_l, pad_r) pads W, (pad_t, pad_b) pads H
        x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
        _, Hp, Wp, _ = x.shape

        # cyclic shift
        if self.shift_size > 0:
            # in the paper the shift is floor(window_size / 2)
            # for torch.roll on the H and W dimensions: negative shifts move toward the
            # top-left, positive shifts toward the bottom-right; elements that fall off
            # one edge reappear at the opposite (diagonal) edge, i.e. a cyclic shift
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
        else:
            shifted_x = x
            attn_mask = None

        # partition windows
        x_windows = window_partition(shifted_x, self.window_size)  # [nW*B, Mh, Mw, C]
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # [nW*B, Mh*Mw, C]

        # W-MSA / SW-MSA
        attn_windows = self.attn(x_windows, mask=attn_mask)  # [nW*B, Mh*Mw, C]
        ...

A simple example of torch.roll() is shown below:
import torch

x = torch.randn(1, 4, 4, 3)
print(x, "\n")

shifted_x = torch.roll(x, shifts=(-3, -3), dims=(1, 2))
print(shifted_x, "\n")

To make it easier to read, I swapped the dimension layout (channels first):
import torch

x = torch.randn(1, 3, 7, 7)
print(x, "\n")

shifted_x = torch.roll(x, shifts=(-3, -3), dims=(2, 3))
print(shifted_x, "\n")

Relative position bias
The relative position bias adds a bias term, determined by the relative positions of the tokens, to the self-attention scores, strengthening the model's ability to capture both local and global information.
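For reference, the paper writes this as Eq. (4):

Attention(Q, K, V) = SoftMax(Q K^T / sqrt(d) + B) V

where B holds the relative position bias for each token pair, with its values taken from a learnable table of size (2M-1) × (2M-1) per head.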
Implementation:
1. Build the relative position index
First, fix a window size and compute the relative positions between tokens within that window. A relative position index table makes it easy to look up the relative position between any two tokens.
2. Learnable bias table
Initialize a learnable parameter table (the relative position bias table); these parameters are optimized during training. The bias corresponding to each relative position index is looked up from this table and added to the attention scores.
3. Computation
In the self-attention computation, Q is multiplied with K^T to obtain the attention scores, the relative position bias is added to these scores, softmax is applied, and the result is multiplied with V to produce the final output.
Code
import torch

coords_h = torch.arange(7)
coords_w = torch.arange(7)

a, b = torch.meshgrid([coords_h, coords_w], indexing="ij")
coords = torch.stack(torch.meshgrid([coords_h, coords_w], indexing="ij"))
coords_flatten = torch.flatten(coords, 1)
# [2, Mh*Mw]
# print(coords_flatten)
# tensor([[0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3,
#          3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6],
#         [0, 1, 2, 3, 4, 5, 6, 0, 1, 2, 3, 4, 5, 6, 0, 1, 2, 3, 4, 5, 6, 0, 1, 2,
#          3, 4, 5, 6, 0, 1, 2, 3, 4, 5, 6, 0, 1, 2, 3, 4, 5, 6, 0, 1, 2, 3, 4, 5, 6]])

# [2, Mh*Mw, 1] - [2, 1, Mh*Mw]
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
relative_coords = relative_coords.permute(1, 2, 0).contiguous()
# [Mh*Mw, Mh*Mw, 2]
# tensor([[[ 0,  0], [ 0, -1], [ 0, -2], ..., [-6, -4], [-6, -5], [-6, -6]],
#         [[ 0,  1], [ 0,  0], [ 0, -1], ..., [-6, -3], [-6, -4], [-6, -5]],
#         ...,
#         [[ 6,  5], [ 6,  4], [ 6,  3], ..., [ 0,  1], [ 0,  0], [ 0, -1]],
#         [[ 6,  6], [ 6,  5], [ 6,  4], ..., [ 0,  2], [ 0,  1], [ 0,  0]]])

relative_coords[:, :, 0] += 6   # shift to start from 0 (M - 1 = 6)
relative_coords[:, :, 1] += 6
relative_coords[:, :, 0] *= 13  # 2*M - 1 = 13
relative_position_index = relative_coords.sum(-1)
print(relative_position_index.shape)
# torch.Size([49, 49])
# print(relative_position_index)
# tensor([[ 84,  83,  82,  ...,   2,   1,   0],
#         [ 85,  84,  83,  ...,   3,   2,   1],
#         [ 86,  85,  84,  ...,   4,   3,   2],
#         ...,
#         [166, 165, 164,  ...,  84,  83,  82],
#         [167, 166, 165,  ...,  85,  84,  83],
#         [168, 167, 166,  ...,  86,  85,  84]])
Here, the line

relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]

computes the relative coordinates (displacements) between every pair of points in the 2D coordinate set coords_flatten.
Before the row/column add and multiply operations, relative_coords is an array of shape (N, N, 2), where relative_coords[i, j, :] is the relative coordinate (displacement) from point i to point j.
Combining this with the analysis from other blogs:
As the figure shows, suppose we have a feature map with window_size = 2. The position indices can be expressed with absolute positions, or with relative positions; with relative positions there are 4 cases, each taking its own token as (0, 0) and computing the relative positions of the other tokens. Flattening these 4 relative-position maps gives a 4x4 matrix, as shown in the bottom matrix of the figure. Note that these are still position indices, not the final positional encodings: later we use the relative position index to fetch the corresponding parameters, and those fetched values are the actual relative position bias. In the source code the author also converts the 2D index into a 1D index. Simply adding row and column would make it 1D, but then (0, 1) and (1, 0) would both give 1, which obviously does not work. Here is what the source code does instead: first, add M-1 to all row and column indices; then multiply all row indices by 2M-1; finally add the row and column indices together. This preserves the relative position relationship and avoids the (0, 1) vs (1, 0) collision.
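To make the index trick concrete, here is the same construction run with a toy window size M = 2 (my own example):

import torch

M = 2
coords = torch.stack(torch.meshgrid([torch.arange(M), torch.arange(M)], indexing="ij"))
coords_flatten = torch.flatten(coords, 1)                          # [2, 4]
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
relative_coords = relative_coords.permute(1, 2, 0).contiguous()    # [4, 4, 2]

relative_coords[:, :, 0] += M - 1       # rows and columns now start from 0
relative_coords[:, :, 1] += M - 1
relative_coords[:, :, 0] *= 2 * M - 1   # make each (row, col) pair collapse to a unique 1D value
relative_position_index = relative_coords.sum(-1)
print(relative_position_index)
# tensor([[4, 3, 1, 0],
#         [5, 4, 2, 1],
#         [7, 6, 4, 3],
#         [8, 7, 5, 4]])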
As just mentioned, what we computed above is only the relative position index, not the actual position bias parameters.
The values actually used are fetched from the relative position bias table, whose length is (2M-1) × (2M-1). In the code it is a learnable parameter.
import torch
from torch import nn

window_size = (7, 7)
num_heads = 3

relative_position_bias_table = nn.Parameter(
    torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)
)
print(relative_position_bias_table)
# ......
# in the model the table is initialized with a truncated normal distribution:
nn.init.trunc_normal_(relative_position_bias_table, std=0.02)
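To see how the table and the index come together, here is a small standalone sketch of what WindowAttention.forward in model.py roughly does (the attn tensor below is a dummy, used only for shape checking):

import torch
from torch import nn

M, num_heads = 7, 3

# recap: bias table with (2M-1)^2 entries and the [M*M, M*M] relative position index, built as above
relative_position_bias_table = nn.Parameter(torch.zeros((2 * M - 1) ** 2, num_heads))
coords = torch.stack(torch.meshgrid([torch.arange(M), torch.arange(M)], indexing="ij"))
coords_flatten = torch.flatten(coords, 1)
relative_coords = (coords_flatten[:, :, None] - coords_flatten[:, None, :]).permute(1, 2, 0).contiguous()
relative_coords[:, :, 0] += M - 1
relative_coords[:, :, 1] += M - 1
relative_coords[:, :, 0] *= 2 * M - 1
relative_position_index = relative_coords.sum(-1)                       # [49, 49]

# look up a bias for every (query token, key token) pair and add it to the attention logits
attn = torch.zeros(1, num_heads, M * M, M * M)                          # dummy [nW*B, nH, Mh*Mw, Mh*Mw] logits
bias = relative_position_bias_table[relative_position_index.view(-1)]   # [49*49, nH]
bias = bias.view(M * M, M * M, -1).permute(2, 0, 1).contiguous()        # [nH, 49, 49]
attn = attn + bias.unsqueeze(0)                                         # broadcast over nW*B
print(attn.shape)
# torch.Size([1, 3, 49, 49])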
Stage 3 – patch merging layers
To produce a hierarchical representation, the number of tokens is reduced by patch merging layers as the network gets deeper.
The first patch merging layer concatenates the features of each group of 2×2 neighboring patches, and applies a linear layer on the 4C-dimensional concatenated features.
This reduces the number of tokens by a multiple of 2×2 = 4 (a 2× downsampling of resolution), and the output dimension is set to 2C.
Swin Transformer blocks are applied afterwards for feature transformation, with the resolution kept at H/8 × W/8.
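A trimmed-down sketch of the patch merging layer, paraphrasing model.py (the padding branch for odd H or W is omitted here):

import torch
import torch.nn as nn

class PatchMerging(nn.Module):
    # Merge each 2x2 group of neighboring patches and project 4C -> 2C.
    def __init__(self, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x, H, W):
        # x: [B, H*W, C]
        B, L, C = x.shape
        x = x.view(B, H, W, C)
        x0 = x[:, 0::2, 0::2, :]  # top-left patch of every 2x2 group   [B, H/2, W/2, C]
        x1 = x[:, 1::2, 0::2, :]  # bottom-left
        x2 = x[:, 0::2, 1::2, :]  # top-right
        x3 = x[:, 1::2, 1::2, :]  # bottom-right
        x = torch.cat([x0, x1, x2, x3], -1)  # [B, H/2, W/2, 4*C]
        x = x.view(B, -1, 4 * C)             # [B, H/2*W/2, 4*C]
        x = self.norm(x)
        x = self.reduction(x)                # [B, H/2*W/2, 2*C]
        return x

x = torch.randn(1, 56 * 56, 96)              # stage-1 output: H/4 x W/4 tokens, C = 96
print(PatchMerging(96)(x, 56, 56).shape)
# torch.Size([1, 784, 192])                  # 28*28 tokens, 2C = 192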
Again, combining analysis diagrams from other experts, the result is shown below.
Related Work
Self-attention based backbone architectures
Instead of using sliding windows, we propose to shift windows between consecutive layers, which allows for a more efficient implementation in general hardware.
...
Cited link or paper name
Swin Transformer: Hierarchical Vision Transformer using Shifted Windows. https://blog.csdn.net/weixin_42392454/article/details/141395092