Swin Transformer and SwinIR

图像多头注意力基础Vit

也就是固定patch的感受野是16*16像素,然后根据输入分辨率的变化,patch总数变多。

transformer分成计算MHA 和MLP

MHA把 196个 768维度patch投影到qkv 每个patch独立,但是采用相同nn.linear(dim, dim * 3, bias=qkv_bias)参数计算qkv,所有patch的qk再计算相关度的到196*196 的相关度矩阵attn,attn再和v计算,得到196 * 768维度的结果,这个结果是每个patch都考虑到了全局其他patch影响后的结果。

 class Attention(nn.Module):
 """Multi-head self-attention (MHSA)"""
 def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
 super().__init__()
 self.num_heads = num_heads
 head_dim = dim // num_heads
 self.scale = head_dim ** -0.5 # 缩放因子 (√d_k)
    self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)  # Q, K, V 投影
    self.attn_drop = nn.Dropout(attn_drop)
    self.proj = nn.Linear(dim, dim)  # 输出投影
    self.proj_drop = nn.Dropout(proj_drop)

def forward(self, x):
    B, N, C = x.shape  # [batch_size, num_patches, embed_dim]
    # 对于每一个patch,把qkv计算出来(同patch 一样维度)
    # 再reshape成 B patch
    qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
    q, k, v = qkv.unbind(0)  # [B, num_heads, N, head_dim]

    attn = (q @ k.transpose(-2, -1)) * self.scale  # QK^T / √d_k
    attn = attn.softmax(dim=-1)
    attn = self.attn_drop(attn)

    x = (attn @ v).transpose(1, 2).reshape(B, N, C)  # 合并多头
    x = self.proj(x)
    x = self.proj_drop(x)
    return x

kqv 一开始是从 B N embed_dim 计算成 B N 3embed_dim,然后分成 B N 3 head embed_dim//3; 再之后分成q,k,v。再之后交换维度顺序,变成B heads N head_dim。每一个head独立计算。

总计算量 不算batch,假设一共有N个patch MHA中 patchpatch + patch * dimdim 在MLP模块为patch * dim*dim

Swin transformer

主要贡献

  1. 由于自然语言中,一个单词就是一个token,但是在视觉问题中,一个像素还是3*3个像素块还是更大的像素块被看作一个token不确定,所以需要层次化的,多尺度的结构,类似Unet那样,不同的感受野当作一个token,计算attention。所以本文提出了层次化的结构。

  2. 图片往往都是高清的,全局做attention计算量太大,因此提出了滑动窗口做attention,并且提出了shift windows,更好处理边缘的patch。

整体网络结构

逐步下采样,窗口的多头attention和shift 窗口的attention交替进行。

其他细节

相对位置编码

SwinIR

多个任务共用 shallow feature 和 deep feature 提取,不同任务不容的reconstruction模块,backbone是RSTL