Histogram Transformer

Restoring Images in Adverse Weather Conditions via Histogram Transformer

Main contributions

  1. Previous transformers compute attention within local windows, but weather degradations such as rain and haze tend to leave degraded regions scattered across the whole image that share the same characteristics and call for the same treatment. This paper therefore computes attention globally, and its "patches" are not spatial pixel blocks: pixels are sorted by value, similar values fall into the same bin, and each bin forms a patch.

  2. In recent PyTorch versions, sort participates in autograd: the sorting permutation is recorded, and gradients flow back along it during backpropagation (see the sketch after this list).

  3. The overall architecture is U-Net-like, so it is not quite accurate to say the transformer serves as the backbone.
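
A minimal sketch of points 1 and 2, independent of the paper's code: torch.sort returns the sorted values plus the permutation indices, gradients flow back through that permutation, and reshaping the sorted sequence groups similar values into bins. The tensor size and bin count here are arbitrary illustration choices.

import torch

# Differentiable sort: gradients are routed back along the sorting permutation.
x = torch.randn(12, requires_grad=True)
x_sorted, idx = x.sort()            # idx records where each sorted value came from
loss = (x_sorted[:6] ** 2).sum()    # a loss on the smallest half of the values
loss.backward()                     # x.grad is nonzero exactly at those positions
print(x.grad)

# Binning by value: reshape the sorted sequence into bins of similar values.
num_bins = 4
bins = x_sorted.detach().view(num_bins, -1)  # each row holds values of similar magnitude
print(bins)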

Attention module

def forward(self, x):
    b, c, h, w = x.shape
    # Spatially sort the first half of the channels, first along the height
    # axis and then along the width axis, keeping the permutation indices so
    # the sort can be undone at the end. Afterwards, h and w are unchanged;
    # half of the channels are value-ordered (smallest values toward the
    # top-left), the other half keep their original spatial layout.
    x_sort, idx_h = x[:, :c//2].sort(-2)
    x_sort, idx_w = x_sort.sort(-1)
    x[:, :c//2] = x_sort

    # A Conv2d(dim, 5*dim) followed by a depthwise conv expands x into
    # qkv: b, 5c, h, w, split into q1, k1, q2, k2, v (each b, c, h, w).
    qkv = self.qkv_dwconv(self.qkv(x))
    q1, k1, q2, k2, v = qkv.chunk(5, dim=1)

    # Flatten the spatial dims and sort the tokens by the magnitude of v
    # (b, c, hw); q and k are reordered with the same permutation.
    v, idx = v.view(b, c, -1).sort(dim=-1)
    q1 = torch.gather(q1.view(b, c, -1), dim=2, index=idx)
    k1 = torch.gather(k1.view(b, c, -1), dim=2, index=idx)
    q2 = torch.gather(q2.view(b, c, -1), dim=2, index=idx)
    k2 = torch.gather(k2.view(b, c, -1), dim=2, index=idx)

    # Two binning schemes: ifBox=True and ifBox=False (see reshape_attn below).
    out1 = self.reshape_attn(q1, k1, v, True)
    out2 = self.reshape_attn(q2, k2, v, False)

    # Undo the token sort. Since idx is a permutation, scattering a tensor
    # into itself overwrites every position exactly once.
    out1 = torch.scatter(out1, 2, idx, out1).view(b, c, h, w)
    out2 = torch.scatter(out2, 2, idx, out2).view(b, c, h, w)
    # The two branches are fused by element-wise multiplication; honestly it
    # is not clear why a product is used here rather than, say, a sum.
    out = out1 * out2
    out = self.project_out(out)

    # Undo the spatial sort on the first half of the channels, in reverse
    # order: width first, then height.
    out_replace = out[:, :c//2]
    out_replace = torch.scatter(out_replace, -1, idx_w, out_replace)
    out_replace = torch.scatter(out_replace, -2, idx_h, out_replace)
    out[:, :c//2] = out_replace
    return out
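
To see why the scatter calls invert the sorts, here is a tiny standalone demo (shapes are arbitrary): gather with the sort indices applies the permutation, and scatter with the same indices puts every element back, because a sort permutation touches each position exactly once. The same trick is applied per axis to undo the spatial sorts with idx_w and idx_h.

import torch

x = torch.randn(2, 3, 8)                    # b, c, hw
x_sorted, idx = x.sort(dim=-1)              # sort the tokens per channel

# gather applies the permutation: identical to x_sorted
assert torch.equal(torch.gather(x, 2, idx), x_sorted)

# scatter inverts it: writing x_sorted back through idx recovers x exactly
restored = torch.scatter(x_sorted, 2, idx, x_sorted)
assert torch.equal(restored, x)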

def reshape_attn(self, q, k, v, ifBox):
    # self.factor = num_heads; q is b, c, hw, where c is the feature dimension
    # and hw the total token count. The tokens are grouped into bins below.
    b, c = q.shape[:2]
    q, t_pad = self.pad(q, self.factor)
    k, t_pad = self.pad(k, self.factor)
    v, t_pad = self.pad(v, self.factor)
    hw = q.shape[-1] // self.factor
    # After the rearrange, hw acts as the feature dimension and (c factor) is
    # the number of bins. The two shape strings are two different binning
    # schemes over the sorted token sequence: "(factor hw)" cuts it into
    # factor contiguous chunks, "(hw factor)" samples it with stride factor.
    shape_ori = "b (head c) (factor hw)" if ifBox else "b (head c) (hw factor)"
    shape_tar = "b head (c factor) hw"

    q = rearrange(q, '{} -> {}'.format(shape_ori, shape_tar), factor=self.factor, hw=hw, head=self.num_heads)
    k = rearrange(k, '{} -> {}'.format(shape_ori, shape_tar), factor=self.factor, hw=hw, head=self.num_heads)
    v = rearrange(v, '{} -> {}'.format(shape_ori, shape_tar), factor=self.factor, hw=hw, head=self.num_heads)
    q = torch.nn.functional.normalize(q, dim=-1)
    k = torch.nn.functional.normalize(k, dim=-1)
    # Attention is computed between (channel, bin) pairs; the hw entries
    # inside each bin serve as the feature vector.
    attn = (q @ k.transpose(-2, -1)) * self.temperature
    attn = self.softmax_1(attn, dim=-1)
    out = attn @ v
    out = rearrange(out, '{} -> {}'.format(shape_tar, shape_ori), factor=self.factor, hw=hw, b=b, head=self.num_heads)
    out = self.unpad(out, t_pad)
    return out
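
self.pad, self.unpad and self.softmax_1 are not shown above. From how they are used, pad zero-pads the token dimension to a multiple of factor, unpad slices that padding back off, and softmax_1 appears to be a softmax variant with an extra 1 in the denominator (so a query can attend to "nothing"). A minimal sketch of these helpers under those assumptions; the official repo may differ in details:

import torch
import torch.nn.functional as F

def pad(self, x, factor):
    # Zero-pad the last (token) dimension up to a multiple of factor.
    hw = x.shape[-1]
    t_pad = [0, 0] if hw % factor == 0 else [0, (hw // factor + 1) * factor - hw]
    return F.pad(x, t_pad, 'constant', 0), t_pad

def unpad(self, x, t_pad):
    # Slice off the padding recorded by pad().
    hw = x.shape[-1]
    return x[..., t_pad[0]:hw - t_pad[1]]

def softmax_1(self, x, dim=-1):
    # Softmax with an extra 1 in the denominator ("softmax-plus-one").
    logit = x.exp()
    return logit / (logit.sum(dim=dim, keepdim=True) + 1)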

This is what the two reshapes mean. With ifBox=True, "(factor hw)" cuts the sorted token sequence into contiguous chunks, so each bin gathers pixels of similar value, which is roughly the "large-scale" (similar-area?) information. With ifBox=False, "(hw factor)" samples the sequence with stride factor, so each bin spans the whole dynamic range, which is roughly the "similar-frequency" information.
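
A tiny demo of the two patterns on a toy sorted sequence (shapes chosen only for illustration) makes the difference concrete:

import torch
from einops import rearrange

x = torch.arange(12).view(1, 1, 12)   # a sorted "token" sequence, b=1, c=1

# ifBox=True: "(factor hw)" -> contiguous chunks become bins
box = rearrange(x, 'b c (factor hw) -> b c factor hw', factor=3)
print(box[0, 0])
# tensor([[ 0,  1,  2,  3],
#         [ 4,  5,  6,  7],
#         [ 8,  9, 10, 11]])   # each bin holds neighbouring (similar) values

# ifBox=False: "(hw factor)" -> strided sampling becomes bins
strided = rearrange(x, 'b c (hw factor) -> b c factor hw', factor=3)
print(strided[0, 0])
# tensor([[ 0,  3,  6,  9],
#         [ 1,  4,  7, 10],
#         [ 2,  5,  8, 11]])   # each bin spans the whole value range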