def forward(self, x):
    b, c, h, w = x.shape

    # Sort the first half of the channels along height, then along width,
    # keeping both permutations so they can be undone at the end.
    # (Note: this writes back into x in place.)
    x_sort, idx_h = x[:, :c//2].sort(-2)
    x_sort, idx_w = x_sort.sort(-1)
    x[:, :c//2] = x_sort

    # Project to two query/key pairs and a shared value.
    qkv = self.qkv_dwconv(self.qkv(x))
    q1, k1, q2, k2, v = qkv.chunk(5, dim=1)

    # Sort the flattened values and apply the same permutation to the
    # queries and keys so all five tensors stay aligned.
    v, idx = v.view(b, c, -1).sort(dim=-1)
    q1 = torch.gather(q1.view(b, c, -1), dim=2, index=idx)
    k1 = torch.gather(k1.view(b, c, -1), dim=2, index=idx)
    q2 = torch.gather(q2.view(b, c, -1), dim=2, index=idx)
    k2 = torch.gather(k2.view(b, c, -1), dim=2, index=idx)

    # Two attention branches over the same sorted values: one with the
    # box-style reshaping (ifBox=True), one with the interleaved variant.
    out1 = self.reshape_attn(q1, k1, v, True)
    out2 = self.reshape_attn(q2, k2, v, False)

    # Undo the value sort on each branch, gate the branches against each
    # other elementwise, and project back to the input channel count.
    out1 = torch.scatter(out1, 2, idx, out1).view(b, c, h, w)
    out2 = torch.scatter(out2, 2, idx, out2).view(b, c, h, w)
    out = out1 * out2
    out = self.project_out(out)

    # Undo the width sort, then the height sort, on the first half of
    # the channels.
    out_replace = out[:, :c//2]
    out_replace = torch.scatter(out_replace, -1, idx_w, out_replace)
    out_replace = torch.scatter(out_replace, -2, idx_h, out_replace)
    out[:, :c//2] = out_replace
    return out
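# Aside (not in the original listing): torch.scatter(t, d, idx, t) is used
# above to undo a sort, scattering each sorted value back to its pre-sort
# position. A tiny worked example:
#     v = torch.tensor([3., 1., 2.])
#     s, idx = v.sort(-1)              # s = [1., 2., 3.], idx = [1, 2, 0]
#     torch.scatter(s, -1, idx, s)     # -> [3., 1., 2.], i.e. v restored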
def reshape_attn(self, q, k, v, ifBox):
    b, c = q.shape[:2]

    # Pad the flattened spatial dimension to a multiple of self.factor.
    q, t_pad = self.pad(q, self.factor)
    k, t_pad = self.pad(k, self.factor)
    v, t_pad = self.pad(v, self.factor)
    hw = q.shape[-1] // self.factor

    # ifBox=True splits the sequence into `factor` contiguous chunks of
    # length hw; ifBox=False interleaves instead. Either way, the factor
    # axis is folded into the per-head channel dimension.
    shape_ori = "b (head c) (factor hw)" if ifBox else "b (head c) (hw factor)"
    shape_tar = "b head (c factor) hw"
    q = rearrange(q, '{} -> {}'.format(shape_ori, shape_tar),
                  factor=self.factor, hw=hw, head=self.num_heads)
    k = rearrange(k, '{} -> {}'.format(shape_ori, shape_tar),
                  factor=self.factor, hw=hw, head=self.num_heads)
    v = rearrange(v, '{} -> {}'.format(shape_ori, shape_tar),
                  factor=self.factor, hw=hw, head=self.num_heads)

    # Channel-wise attention on L2-normalized features with a learned
    # per-head temperature.
    q = torch.nn.functional.normalize(q, dim=-1)
    k = torch.nn.functional.normalize(k, dim=-1)
    attn = (q @ k.transpose(-2, -1)) * self.temperature
    attn = self.softmax_1(attn, dim=-1)
    out = attn @ v

    # Restore the original layout and strip the padding.
    out = rearrange(out, '{} -> {}'.format(shape_tar, shape_ori),
                    factor=self.factor, hw=hw, b=b, head=self.num_heads)
    out = self.unpad(out, t_pad)
    return out
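# The helpers below are not part of the excerpt. This is a minimal sketch of
# what pad, unpad, and softmax_1 could look like, inferred from their call
# sites above; the actual definitions may differ.

def pad(self, x, factor):
    # Right-pad the flattened spatial dim so its length divides `factor`.
    t = x.shape[-1]
    t_pad = (factor - t % factor) % factor
    return torch.nn.functional.pad(x, (0, t_pad)), t_pad

def unpad(self, x, t_pad):
    # Remove whatever padding `pad` added (a no-op when t_pad == 0).
    return x[..., :x.shape[-1] - t_pad]

def softmax_1(self, x, dim=-1):
    # Softmax variant with an extra unit in the denominator, so attention
    # rows may sum to less than one (an assumption; the original could
    # equally use a plain softmax here).
    logit = x.exp()
    return logit / (logit.sum(dim, keepdim=True) + 1)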
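For context, here is one way the pieces above might be assembled and run end to end. Everything in the constructor is an assumption inferred from the call sites (a `qkv` projection producing five chunks, a depthwise `qkv_dwconv`, a `project_out`, a per-head `temperature`); the class name `SortedDualAttention` is made up for the example.

import torch
import torch.nn as nn
from einops import rearrange

class SortedDualAttention(nn.Module):
    def __init__(self, dim=32, num_heads=4, factor=4):
        super().__init__()
        self.num_heads = num_heads
        self.factor = factor
        # 1x1 projection to q1, k1, q2, k2, v (5 * dim channels), refined
        # by a depthwise 3x3 convolution.
        self.qkv = nn.Conv2d(dim, dim * 5, kernel_size=1)
        self.qkv_dwconv = nn.Conv2d(dim * 5, dim * 5, kernel_size=3,
                                    padding=1, groups=dim * 5)
        self.project_out = nn.Conv2d(dim, dim, kernel_size=1)
        self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))

    # forward, reshape_attn, pad, unpad, and softmax_1 as listed above.

x = torch.randn(2, 32, 16, 16)
block = SortedDualAttention(dim=32, num_heads=4, factor=4)
# out = block(x)   # expected shape: (2, 32, 16, 16)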