图像多头注意力基础Vit
也就是固定patch的感受野是16*16像素,然后根据输入分辨率的变化,patch总数变多。
transformer分成计算MHA 和MLP
MHA把 196个 768维度patch投影到qkv 每个patch独立,但是采用相同nn.linear(dim, dim * 3, bias=qkv_bias)参数计算qkv,所有patch的qk再计算相关度的到196*196 的相关度矩阵attn,attn再和v计算,得到196 * 768维度的结果,这个结果是每个patch都考虑到了全局其他patch影响后的结果。
class Attention(nn.Module):
"""Multi-head self-attention (MHSA)"""
def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = head_dim ** -0.5 # 缩放因子 (√d_k)
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) # Q, K, V 投影
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim) # 输出投影
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x):
B, N, C = x.shape # [batch_size, num_patches, embed_dim]
# 对于每一个patch,把qkv计算出来(同patch 一样维度)
# 再reshape成 B patch
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv.unbind(0) # [B, num_heads, N, head_dim]
attn = (q @ k.transpose(-2, -1)) * self.scale # QK^T / √d_k
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, C) # 合并多头
x = self.proj(x)
x = self.proj_drop(x)
return x
kqv 一开始是从 B N embed_dim 计算成 B N 3embed_dim,然后分成 B N 3 head embed_dim//3; 再之后分成q,k,v。再之后交换维度顺序,变成B heads N head_dim。每一个head独立计算。
总计算量 不算batch,假设一共有N个patch MHA中 patchpatch + patch * dimdim 在MLP模块为patch * dim*dim
Swin transformer
主要贡献
由于自然语言中,一个单词就是一个token,但是在视觉问题中,一个像素还是3*3个像素块还是更大的像素块被看作一个token不确定,所以需要层次化的,多尺度的结构,类似Unet那样,不同的感受野当作一个token,计算attention。所以本文提出了层次化的结构。
图片往往都是高清的,全局做attention计算量太大,因此提出了滑动窗口做attention,并且提出了shift windows,更好处理边缘的patch。
整体网络结构
逐步下采样,窗口的多头attention和shift 窗口的attention交替进行。
其他细节
相对位置编码
SwinIR
多个任务共用 shallow feature 和 deep feature 提取,不同任务不容的reconstruction模块,backbone是RSTL