        Swin Transformer Explained



        Howard Yin · 2022-06-12 12:22:02 · Tags: artificial intelligence, Transformer, SwinTransformer, neural networks, attention mechanism

        # Top Level

        Let's start at the top level:

        # Class and Initialization

        The initializer's parameters are documented in detail in the docstring, so they need little extra explanation.

        class SwinTransformer(nn.Module):
            r""" Swin Transformer
                A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows`  -
                  https://arxiv.org/pdf/2103.14030
            Args:
                img_size (int | tuple(int)): Input image size. Default 224
                patch_size (int | tuple(int)): Patch size. Default: 4
                in_chans (int): Number of input image channels. Default: 3
                num_classes (int): Number of classes for classification head. Default: 1000
                embed_dim (int): Patch embedding dimension. Default: 96
                depths (tuple(int)): Depth of each Swin Transformer layer.
                num_heads (tuple(int)): Number of attention heads in different layers.
                window_size (int): Window size. Default: 7
                mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
                qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
                qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
                drop_rate (float): Dropout rate. Default: 0
                attn_drop_rate (float): Attention dropout rate. Default: 0
                drop_path_rate (float): Stochastic depth rate. Default: 0.1
                norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
                ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
                patch_norm (bool): If True, add normalization after patch embedding. Default: True
                use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
            """
        
            def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000,
                         embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24],
                         window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
                         drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
                         norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
                         use_checkpoint=False, **kwargs):
                super().__init__()
        
                self.num_classes = num_classes
                self.num_layers = len(depths)
                self.embed_dim = embed_dim
                self.ape = ape
                self.patch_norm = patch_norm
                self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
                self.mlp_ratio = mlp_ratio
        

        It then initializes the patch embedding. This PatchEmbed module implements the patch-embedding step and is covered in detail later.

                # split image into non-overlapping patches
                self.patch_embed = PatchEmbed(
                    img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
                    norm_layer=norm_layer if self.patch_norm else None)
                num_patches = self.patch_embed.num_patches
                patches_resolution = self.patch_embed.patches_resolution
                self.patches_resolution = patches_resolution
        
                # absolute position embedding
                if self.ape:
                    self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
                    trunc_normal_(self.absolute_pos_embed, std=.02)
        

        Dropout is used (the dpr list below is the stochastic-depth decay schedule):

                self.pos_drop = nn.Dropout(p=drop_rate)
        
                # stochastic depth
                dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule
        

        Next the model backbone is built. BasicLayer is the main Swin Transformer stage module and is covered in detail later.

                # build layers
                self.layers = nn.ModuleList()
                for i_layer in range(self.num_layers):
                    layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer),
                                       input_resolution=(patches_resolution[0] // (2 ** i_layer),
                                                         patches_resolution[1] // (2 ** i_layer)),
                                       depth=depths[i_layer],
                                       num_heads=num_heads[i_layer],
                                       window_size=window_size,
                                       mlp_ratio=self.mlp_ratio,
                                       qkv_bias=qkv_bias, qk_scale=qk_scale,
                                       drop=drop_rate, attn_drop=attn_drop_rate,
                                       drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
                                       norm_layer=norm_layer,
                                       downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
                                       use_checkpoint=use_checkpoint)
                    self.layers.append(layer)
        

        Next come the LayerNorm, average pooling, and linear head, which need no introduction; finally there is a weight-initialization helper, which also speaks for itself.

                self.norm = norm_layer(self.num_features)
                self.avgpool = nn.AdaptiveAvgPool1d(1)
                self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
        
                self.apply(self._init_weights)
        
            def _init_weights(self, m):
                if isinstance(m, nn.Linear):
                    trunc_normal_(m.weight, std=.02)
                    if isinstance(m, nn.Linear) and m.bias is not None:
                        nn.init.constant_(m.bias, 0)
                elif isinstance(m, nn.LayerNorm):
                    nn.init.constant_(m.bias, 0)
                    nn.init.constant_(m.weight, 1.0)
        

        # Forward Pass

        The forward pass consists of:

        1. Patch embedding
        2. Dropout
        3. The Swin Transformer backbone
        4. LayerNorm
        5. Average pooling
        6. The classification head

            def forward_features(self, x):
                x = self.patch_embed(x)
                if self.ape:
                    x = x + self.absolute_pos_embed
                x = self.pos_drop(x)
        
                for layer in self.layers:
                    x = layer(x)
        
                x = self.norm(x)  # B L C
                x = self.avgpool(x.transpose(1, 2))  # B C 1
                x = torch.flatten(x, 1)
                return x
        
            def forward(self, x):
                x = self.forward_features(x)
                x = self.head(x)
                return x
        

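        As a quick sanity check of the shapes, here is a minimal usage sketch (assuming SwinTransformer and its dependencies such as PatchEmbed and BasicLayer from the official repo are in scope; the numbers follow the default arguments above):

        import torch

        # default Swin-T style configuration from the signature above
        model = SwinTransformer(img_size=224, patch_size=4, embed_dim=96,
                                depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24],
                                window_size=7, num_classes=1000)
        x = torch.randn(1, 3, 224, 224)        # B, C, H, W
        feats = model.forward_features(x)      # (1, 768): num_features = 96 * 2**3
        logits = model(x)                      # (1, 1000)
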
        # Patch Embedding PatchEmbed

        class PatchEmbed(nn.Module):
            r""" Image to Patch Embedding
            Args:
                img_size (int): Image size.  Default: 224.
                patch_size (int): Patch token size. Default: 4.
                in_chans (int): Number of input image channels. Default: 3.
                embed_dim (int): Number of linear projection output channels. Default: 96.
                norm_layer (nn.Module, optional): Normalization layer. Default: None
            """
        
            def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
                super().__init__()
                img_size = to_2tuple(img_size)
                patch_size = to_2tuple(patch_size)
                patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
                self.img_size = img_size
                self.patch_size = patch_size
                self.patches_resolution = patches_resolution
                self.num_patches = patches_resolution[0] * patches_resolution[1]
        
                self.in_chans = in_chans
                self.embed_dim = embed_dim
        

        Patch embedding is really just a convolution: each patch_size x patch_size x in_chans patch is mapped to a 1 x 1 x embed_dim output, and because the stride equals patch_size the convolution runs over non-overlapping patches.

                self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
                if norm_layer is not None:
                    self.norm = norm_layer(embed_dim)
                else:
                    self.norm = None
        

        The forward computation is also straightforward: the patch-embedding convolution yields an (img_size/patch_size) x (img_size/patch_size) x embed_dim tensor, which flatten collapses into an (img_size^2/patch_size^2) x embed_dim matrix.

            def forward(self, x):
                B, C, H, W = x.shape
                # FIXME look at relaxing size constraints
                assert H == self.img_size[0] and W == self.img_size[1], \
                    f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
                x = self.proj(x).flatten(2).transpose(1, 2)  # B Ph*Pw C
                if self.norm is not None:
                    x = self.norm(x)
                return x
        

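        A quick shape check of the above (a sketch, assuming PatchEmbed from the repo is in scope):

        import torch

        embed = PatchEmbed(img_size=224, patch_size=4, in_chans=3, embed_dim=96)
        x = torch.randn(2, 3, 224, 224)
        print(embed(x).shape)  # torch.Size([2, 3136, 96]) -- 3136 = (224/4) * (224/4) patches
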
        # Main Module: BasicLayer

        class BasicLayer(nn.Module):
            """ A basic Swin Transformer layer for one stage.
            Args:
                dim (int): Number of input channels.
                input_resolution (tuple[int]): Input resolution.
                depth (int): Number of blocks.
                num_heads (int): Number of attention heads.
                window_size (int): Local window size.
                mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
                qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
                qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
                drop (float, optional): Dropout rate. Default: 0.0
                attn_drop (float, optional): Attention dropout rate. Default: 0.0
                drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
                norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
                downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
                use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
            """
        
            def __init__(self, dim, input_resolution, depth, num_heads, window_size,
                         mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
                         drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False):
        
                super().__init__()
                self.dim = dim
                self.input_resolution = input_resolution
                self.depth = depth
                self.use_checkpoint = use_checkpoint
        

        This just wraps another module, SwinTransformerBlock, so we still have not reached the core; SwinTransformerBlock is analyzed in detail below.

                # build blocks
                self.blocks = nn.ModuleList([
                    SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
                                         num_heads=num_heads, window_size=window_size,
                                         shift_size=0 if (i % 2 == 0) else window_size // 2,
                                         mlp_ratio=mlp_ratio,
                                         qkv_bias=qkv_bias, qk_scale=qk_scale,
                                         drop=drop, attn_drop=attn_drop,
                                         drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                                         norm_layer=norm_layer)
                    for i in range(depth)])
        

        The downsample argument is passed in from outside; from the top-level call we can see it is the PatchMerging class, which is analyzed later as well.

                # patch merging layer
                if downsample is not None:
                    self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
                else:
                    self.downsample = None
        

        The forward pass simply runs the SwinTransformerBlocks first and then the downsample layer.

            def forward(self, x):
                for blk in self.blocks:
                    if self.use_checkpoint:
                        x = checkpoint.checkpoint(blk, x)
                    else:
                        x = blk(x)
                if self.downsample is not None:
                    x = self.downsample(x)
                return x
        

        # Patch Merging PatchMerging

        PatchMerging is one of Swin Transformer's key innovations; the code is short but important.

        # Initialization

        class PatchMerging(nn.Module):
            r""" Patch Merging Layer.
            Args:
                input_resolution (tuple[int]): Resolution of input feature.
                dim (int): Number of input channels.
                norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
            """
        
            def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
                super().__init__()
                self.input_resolution = input_resolution
                self.dim = dim
        

        Next are the LayerNorm and the linear projection. The patch-merging operation below stacks 4 patches together, so the input here has 4*dim channels; and, as the paper describes, the linear layer then halves the channel count, so the output has 2*dim channels.

                self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
                self.norm = norm_layer(4 * dim)
        

        One thing to note: by this point each patch is already a one-dimensional vector, and this nn.Linear depends only on the input dim, not on the number of patches. Following the call chain, dim = int(embed_dim * 2 ** i_layer), where embed_dim is the length of the patch embedding and i_layer is the stage index. As the architecture figure in the paper shows, the feature map's channel count doubles after every Swin Transformer stage, so dim really depends only on the patch-embedding length, which means every patch goes through the same linear computation. Since the image size is fixed, the number of patches is fixed too, so in principle each patch could get its own linear-layer parameters, but the authors did not do that.

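        To make the per-patch point concrete, a small sketch: nn.Linear acts only on the last dimension, so the same (4*dim x 2*dim) weight matrix is applied to every patch token regardless of how many tokens there are.

        import torch
        import torch.nn as nn

        reduction = nn.Linear(4 * 96, 2 * 96, bias=False)
        tokens = torch.randn(1, 784, 4 * 96)   # B, H/2*W/2, 4*C -- any token count works
        print(reduction(tokens).shape)          # torch.Size([1, 784, 192])
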
        # Forward Pass

            def forward(self, x):
                """
                x: B, H*W, C
                """
                H, W = self.input_resolution
                B, L, C = x.shape
                assert L == H * W, "input feature has wrong size"
                assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."
        
                x = x.view(B, H, W, C)
        

        The core is the snippet below: it stacks each neighboring 2x2 group of 4 patches together along the channel dimension and then halves the channel count with the linear layer.

                x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C
                x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C
                x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C
                x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C
                x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C
                x = x.view(B, -1, 4 * C)  # B H/2*W/2 4*C
        
                x = self.norm(x)
                x = self.reduction(x)
        
                return x
        

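        A shape sketch for one merge step (assuming PatchMerging from the repo is in scope): the token count drops by 4x while the channel count doubles.

        import torch

        merge = PatchMerging(input_resolution=(56, 56), dim=96)
        x = torch.randn(2, 56 * 56, 96)   # B, H*W, C
        print(merge(x).shape)             # torch.Size([2, 784, 192]) -- 28*28 tokens, 2*C channels
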
        # Main Module: SwinTransformerBlock

        # Window Partition/Reverse

        Before getting into the main module, it is worth covering the window partition/reverse mechanism, i.e. the functions window_partition and window_reverse.

        window_partition splits a tensor into windows of a given size, reshaping it from B x H x W x C to (num_windows*B) x window_size x window_size x C, where num_windows = H*W / (window_size*window_size) is the number of windows. window_reverse is the corresponding inverse. Both functions are used later in the window attention.

        window_partition looks like this:

        def window_partition(x, window_size):
            """
            Args:
                x: (B, H, W, C)
                window_size (int): window size
            Returns:
                windows: (num_windows*B, window_size, window_size, C)
            """
            B, H, W, C = x.shape
            x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
            windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
            return windows
        

        Step by step:

        1. First, view splits each of the height and width dimensions of the input tensor into an extra window_size dimension, expanding B x H x W x C into B x H/window_size x window_size x W/window_size x window_size x C.
        2. Then permute swaps the W/window_size and window_size dimensions, giving B x H/window_size x W/window_size x window_size x window_size x C.
        3. Neither view nor permute changes the underlying values; contiguous then makes a copy of the transposed tensor.
        4. The final view(-1, window_size, window_size, C) collapses the B, H/window_size and W/window_size dimensions, leaving just a stack of window_size x window_size x C windows.

        window_reverse is the exact inverse (a round-trip sketch follows its code below):

        def window_reverse(windows, window_size, H, W):
            B = int(windows.shape[0] / (H * W / window_size / window_size))
            x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
            x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
            return x
        

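        A round-trip sketch (assuming both functions from the repo are in scope); since only views and permutes are involved, the reconstruction is exact:

        import torch

        x = torch.randn(2, 56, 56, 96)                 # B, H, W, C
        windows = window_partition(x, window_size=7)   # (2*8*8, 7, 7, 96) = (128, 7, 7, 96)
        restored = window_reverse(windows, 7, 56, 56)  # back to (2, 56, 56, 96)
        print(torch.equal(x, restored))                # True
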
        # Initialization

        class SwinTransformerBlock(nn.Module):
            r""" Swin Transformer Block.
            Args:
                dim (int): Number of input channels.
                input_resolution (tuple[int]): Input resulotion.
                num_heads (int): Number of attention heads.
                window_size (int): Window size.
                shift_size (int): Shift size for SW-MSA.
                mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
                qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
                qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
                drop (float, optional): Dropout rate. Default: 0.0
                attn_drop (float, optional): Attention dropout rate. Default: 0.0
                drop_path (float, optional): Stochastic depth rate. Default: 0.0
                act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
                norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
            """
        
            def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
                         mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
                         act_layer=nn.GELU, norm_layer=nn.LayerNorm):
                super().__init__()
                self.dim = dim
                self.input_resolution = input_resolution
                self.num_heads = num_heads
                self.window_size = window_size
                self.shift_size = shift_size
                self.mlp_ratio = mlp_ratio
                if min(self.input_resolution) <= self.window_size:
                    # if window size is larger than input resolution, we don't partition windows
                    self.shift_size = 0
                    self.window_size = min(self.input_resolution)
                assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
        

        This WindowAttention is clearly the in-window attention described in the paper; it is examined in detail later.

                self.norm1 = norm_layer(dim)
                self.attn = WindowAttention(
                    dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
                    qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
        

        Below are a DropPath, a second LayerNorm, and a fully connected MLP with dropout; not much to say.

                self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
                self.norm2 = norm_layer(dim)
                mlp_hidden_dim = int(dim * mlp_ratio)
                self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
        

        # Core Shifted-Windows Code

        The core of the shifted-windows mechanism is the code below.

                if self.shift_size > 0:
        
        # Attention mask for SW-MSA

        First comes the part labeled "calculate attention mask for SW-MSA":

                    # calculate attention mask for SW-MSA
        
                    H, W = self.input_resolution
                    img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
                    h_slices = (slice(0, -self.window_size),
                                slice(-self.window_size, -self.shift_size),
                                slice(-self.shift_size, None))
                    w_slices = (slice(0, -self.window_size),
                                slice(-self.window_size, -self.shift_size),
                                slice(-self.shift_size, None))
                    cnt = 0
                    for h in h_slices:
                        for w in w_slices:
                            img_mask[:, h, w, :] = cnt
                            cnt += 1
        

        Judging from the two for loops, img_mask is a label map that numbers each region involved in the shift. It is equivalent to:

        img_mask[:,            0:-window_size,            0:-window_size, :] = 0
        img_mask[:,            0:-window_size, -window_size: -shift_size, :] = 1
        img_mask[:,            0:-window_size, -shift_size :            , :] = 2
        img_mask[:, -window_size: -shift_size,            0:-window_size, :] = 3
        img_mask[:, -window_size: -shift_size, -window_size: -shift_size, :] = 4
        img_mask[:, -window_size: -shift_size, -shift_size :            , :] = 5
        img_mask[:, -shift_size :            ,            0:-window_size, :] = 6
        img_mask[:, -shift_size :            , -window_size: -shift_size, :] = 7
        img_mask[:, -shift_size :            , -shift_size :            , :] = 8
        

        The labels carve the feature map into a 3x3 grid of regions, numbered 0 through 8 from the top-left corner; the toy sketch below prints the result.

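        A toy sketch of the labeling, using window_size=4 and shift_size=2 on an 8x8 map (plain torch, run outside the class):

        import torch

        H = W = 8; window_size = 4; shift_size = 2
        img_mask = torch.zeros((1, H, W, 1))
        slices = (slice(0, -window_size), slice(-window_size, -shift_size), slice(-shift_size, None))
        cnt = 0
        for h in slices:
            for w in slices:
                img_mask[:, h, w, :] = cnt
                cnt += 1
        print(img_mask[0, :, :, 0].int())
        # tensor([[0, 0, 0, 0, 1, 1, 2, 2],
        #         [0, 0, 0, 0, 1, 1, 2, 2],
        #         [0, 0, 0, 0, 1, 1, 2, 2],
        #         [0, 0, 0, 0, 1, 1, 2, 2],
        #         [3, 3, 3, 3, 4, 4, 5, 5],
        #         [3, 3, 3, 3, 4, 4, 5, 5],
        #         [6, 6, 6, 6, 7, 7, 8, 8],
        #         [6, 6, 6, 6, 7, 7, 8, 8]], dtype=torch.int32)
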
        Tracing the call chain, window_size is specified by the user, while input_resolution and shift_size are both computed:

        input_resolution=(patches_resolution[0] // (2 ** i_layer), patches_resolution[1] // (2 ** i_layer)),
        

        Here i_layer is the index of the BasicLayer, so input_resolution halves at every stage, just as the paper describes.

        shift_size=0 if (i % 2 == 0) else window_size // 2
        

        Here i is the index of the SwinTransformerBlock within its BasicLayer, so the shift is performed only in every other block, matching the figure in the paper where a W-MSA block is always followed by an SW-MSA block.

        In other words, every two SwinTransformerBlocks form a basic unit, with one shift operation in between.

        This also shows that W-MSA is simply SW-MSA without the shift.

        Next comes a window_partition call:

                    mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1
                    mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
        

        Then some values get filled in? -100 and 0?

                    attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
                    attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
                else:
                    attn_mask = None
        
                self.register_buffer("attn_mask", attn_mask)
        

        Clearly this operates on the mask built above, but what does it actually mean?

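        Continuing the toy example above, a sketch of what this produces (window_partition from the repo assumed in scope): mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) is zero exactly where two tokens in a window carry the same region label, so attn_mask ends up 0 for same-region token pairs and -100 for cross-region pairs.

        mask_windows = window_partition(img_mask, window_size)           # (4, 4, 4, 1) for the 8x8 toy map
        mask_windows = mask_windows.view(-1, window_size * window_size)  # (4, 16)
        attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
        attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
        print(attn_mask.shape)            # torch.Size([4, 16, 16]): one 16x16 mask per window
        print((attn_mask[0] == 0).all())  # tensor(True) -- the top-left window holds only region 0
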
        # Forward Pass

            def forward(self, x):
                H, W = self.input_resolution
                B, L, C = x.shape
                assert L == H * W, "input feature has wrong size"
        
                shortcut = x
                x = self.norm1(x)
                x = x.view(B, H, W, C)
        

        Here comes the shift operation, done directly with torch.roll.

                # cyclic shift
                if self.shift_size > 0:
                    shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
                else:
                    shifted_x = x
        

        torch.roll cyclically shifts elements along the given dimensions: whatever rolls off one edge wraps around to the opposite edge.

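        For example, a standalone sketch on a 4x4 map:

        import torch

        x = torch.arange(16).view(1, 4, 4)
        print(torch.roll(x, shifts=(-1, -1), dims=(1, 2)))
        # tensor([[[ 5,  6,  7,  4],
        #          [ 9, 10, 11,  8],
        #          [13, 14, 15, 12],
        #          [ 1,  2,  3,  0]]])
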
        This clearly corresponds to the cyclic-shift operation in the paper.

        But what does this have to do with the img_mask from before?

        Next, window_partition splits the input feature map into windows; nothing new here.

                # partition windows
                x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C
                x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # nW*B, window_size*window_size, C
        

        Then WindowAttention is applied:

                # W-MSA/SW-MSA
                attn_windows = self.attn(x_windows, mask=self.attn_mask)  # nW*B, window_size*window_size, C
        

        Then the windowed feature map is stitched back together:

                # merge windows
                attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
                shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C
        

        Then the shifted feature map is rolled back into place:

                # reverse cyclic shift
                if self.shift_size > 0:
                    x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
                else:
                    x = shifted_x
                x = x.view(B, H * W, C)
        

        Finally, the residual connection with DropPath, followed by the FFN:

                x = shortcut + self.drop_path(x)
        
                # FFN
                x = x + self.drop_path(self.mlp(self.norm2(x)))
        
                return x
        

        # Summary

        To sum up, this is clearly the core of Swin Transformer's shifted windows, but after reading it I still could not figure out what the shifted windows are actually doing. Let's read the rest of the code and come back to it.

        # Core Module: WindowAttention

        # Initialization

        class WindowAttention(nn.Module):
            r""" Window based multi-head self attention (W-MSA) module with relative position bias.
            It supports both of shifted and non-shifted window.
            Args:
                dim (int): Number of input channels.
                window_size (tuple[int]): The height and width of the window.
                num_heads (int): Number of attention heads.
                qkv_bias (bool, optional):  If True, add a learnable bias to query, key, value. Default: True
                qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
                attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
                proj_drop (float, optional): Dropout ratio of output. Default: 0.0
            """
        
            def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
        
                super().__init__()
                self.dim = dim
                self.window_size = window_size  # Wh, Ww
                self.num_heads = num_heads
                head_dim = dim // num_heads
                self.scale = qk_scale or head_dim ** -0.5
        

        A long stretch of relative-position-bias setup:

                # define a parameter table of relative position bias
                self.relative_position_bias_table = nn.Parameter(
                    torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  # 2*Wh-1 * 2*Ww-1, nH
                    
                # get pair-wise relative position index for each token inside the window
                coords_h = torch.arange(self.window_size[0])
                coords_w = torch.arange(self.window_size[1])
                coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
                coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
                relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
                relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
                relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
                relative_coords[:, :, 1] += self.window_size[1] - 1
                relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
                relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
                self.register_buffer("relative_position_index", relative_position_index)
        

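        A small sketch of what this indexing produces for a 2x2 window (the same arithmetic as above, just run outside the class):

        import torch

        Wh = Ww = 2
        coords = torch.stack(torch.meshgrid([torch.arange(Wh), torch.arange(Ww)]))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)                                   # 2, Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]   # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()
        relative_coords[:, :, 0] += Wh - 1
        relative_coords[:, :, 1] += Ww - 1
        relative_coords[:, :, 0] *= 2 * Ww - 1
        print(relative_coords.sum(-1))
        # tensor([[4, 3, 1, 0],
        #         [5, 4, 2, 1],
        #         [7, 6, 4, 3],
        #         [8, 7, 5, 4]])
        # Each entry indexes one of the (2*Wh-1)*(2*Ww-1) = 9 learnable bias values;
        # token pairs with the same relative offset share the same bias.
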
        A few linear layers, dropout, and the like; nothing to dwell on:

                self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
                self.attn_drop = nn.Dropout(attn_drop)
                self.proj = nn.Linear(dim, dim)
                self.proj_drop = nn.Dropout(proj_drop)
        
                trunc_normal_(self.relative_position_bias_table, std=.02)
                self.softmax = nn.Softmax(dim=-1)
        

        # Forward Pass

            def forward(self, x, mask=None):
                """
                Args:
                    x: input features with shape of (num_windows*B, N, C)
                    mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
                """
                B_, N, C = x.shape
                qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
                q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)
        
                q = q * self.scale
                attn = (q @ k.transpose(-2, -1))
        

        Here the relative position bias is added:

                relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
                    self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH
                relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
                attn = attn + relative_position_bias.unsqueeze(0)
        

        Then the mask is added? As discussed earlier, the mask holds -100 and 0, so what does that mean here?

        The mask is added to the attention scores before the softmax: positions set to -100 come out of the softmax essentially as zero, so those token pairs are ignored.

                if mask is not None:
                    nW = mask.shape[0]
                    attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
                    attn = attn.view(-1, self.num_heads, N, N)
                    attn = self.softmax(attn)
                else:
                    attn = self.softmax(attn)
        

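        A one-line illustration of why -100 works (a standalone sketch):

        import torch

        scores = torch.tensor([1.0, 2.0, 3.0])
        print(torch.softmax(scores + torch.tensor([0.0, -100.0, 0.0]), dim=-1))
        # ≈ tensor([0.1192, 0.0000, 0.8808]) -- the masked position contributes essentially nothing
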
        Apply attention dropout:

                attn = self.attn_drop(attn)
        

        Apply the final linear projection and dropout:

                x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
                x = self.proj(x)
                x = self.proj_drop(x)
                return x
        

        # Summary

        So that mask from earlier is actually used here? I still have not fully worked it out.

        # Overall Summary

        In the end, it is probably best to read an expert's write-up: 《图解Swin Transformer》 ("Illustrated Swin Transformer").
