| 12
 3
 4
 5
 6
 7
 8
 9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 
 def forward(self, x):
     """Run the sort-based two-branch attention over a 4-D input.

     The first ``c // 2`` channels of ``x`` are sorted along dim -2 and then
     along dim -1 before the qkv projection. The projected tensor is split
     into five equal channel chunks (q1, k1, q2, k2, v); ``v`` is sorted
     along the flattened last dimension and the same permutation is applied
     to every q/k chunk. Two ``reshape_attn`` passes (``ifBox`` True and
     False) are multiplied together, projected, and every sort performed
     above is undone with ``torch.scatter``.

     Args:
         x: tensor whose shape unpacks to ``(b, c, h, w)``. It is NOT
             modified; the sorted half is placed into a new tensor.

     Returns:
         The projected attention output with the spatial ordering of the
         first ``c // 2`` channels restored.
     """
     b, c, h, w = x.shape
     half = c // 2

     # Sort the leading half of the channels along H, then along W, keeping
     # both index tensors so the ordering can be restored at the end.
     x_sort, idx_h = x[:, :half].sort(-2)
     x_sort, idx_w = x_sort.sort(-1)
     # Assemble a new tensor rather than assigning into ``x``: a slice
     # assignment here would overwrite the caller's input in place.
     x = torch.cat([x_sort, x[:, half:]], dim=1)

     qkv = self.qkv_dwconv(self.qkv(x))
     # Five equal chunks along channels; each is reshaped to (b, c, -1)
     # below, so the projection is expected to emit 5 * c channels.
     q1, k1, q2, k2, v = qkv.chunk(5, dim=1)

     # Sort v over the flattened spatial axis and reorder q/k identically.
     # ``reshape`` (not ``view``) so non-contiguous layouts do not raise.
     v, idx = v.reshape(b, c, -1).sort(dim=-1)
     q1 = torch.gather(q1.reshape(b, c, -1), dim=2, index=idx)
     k1 = torch.gather(k1.reshape(b, c, -1), dim=2, index=idx)
     q2 = torch.gather(q2.reshape(b, c, -1), dim=2, index=idx)
     k2 = torch.gather(k2.reshape(b, c, -1), dim=2, index=idx)

     out1 = self.reshape_attn(q1, k1, v, True)
     out2 = self.reshape_attn(q2, k2, v, False)

     # ``idx`` is a permutation per (batch, channel) row, so scattering a
     # tensor through it is the exact inverse of the gather above.
     out1 = torch.scatter(out1, 2, idx, out1).reshape(b, c, h, w)
     out2 = torch.scatter(out2, 2, idx, out2).reshape(b, c, h, w)
     out = self.project_out(out1 * out2)

     # Undo the two input sorts in reverse order: W first, then H.
     restored = out[:, :half]
     restored = torch.scatter(restored, -1, idx_w, restored)
     restored = torch.scatter(restored, -2, idx_h, restored)
     return torch.cat([restored, out[:, half:]], dim=1)
 
 def reshape_attn(self, q, k, v, ifBox):
     """Attend over q/k/v after regrouping their last axis by ``self.factor``.

     Each input goes through ``self.pad(tensor, self.factor)``. The last
     axis is then split into ``(factor, hw)`` when ``ifBox`` is true, or
     ``(hw, factor)`` otherwise, and ``factor`` is folded into the
     per-head channel axis. q and k are L2-normalised over the last axis,
     the scores are scaled by ``self.temperature`` and passed through
     ``self.softmax_1``. The result is rearranged back to the input layout
     and handed to ``self.unpad``.

     Args:
         q, k, v: tensors laid out as ``b (head c) n``.
         ifBox: selects which of the two last-axis groupings is used.

     Returns:
         Whatever ``self.unpad`` returns for the attended tensor.
     """
     b, _ = q.shape[:2]

     # Pad in q, k, v order; the pad info kept for unpad is the one
     # returned by the final (v) call.
     padded = []
     for tensor in (q, k, v):
         tensor, t_pad = self.pad(tensor, self.factor)
         padded.append(tensor)
     q, k, v = padded

     hw = q.shape[-1] // self.factor
     sizes = {"factor": self.factor, "hw": hw, "head": self.num_heads}

     if ifBox:
         flat_layout = "b (head c) (factor hw)"
     else:
         flat_layout = "b (head c) (hw factor)"
     head_layout = "b head (c factor) hw"
     to_heads = f"{flat_layout} -> {head_layout}"
     to_flat = f"{head_layout} -> {flat_layout}"

     q, k, v = (rearrange(t, to_heads, **sizes) for t in (q, k, v))
     q = torch.nn.functional.normalize(q, dim=-1)
     k = torch.nn.functional.normalize(k, dim=-1)

     scores = torch.matmul(q, k.transpose(-2, -1)) * self.temperature
     weights = self.softmax_1(scores, dim=-1)
     attended = torch.matmul(weights, v)

     attended = rearrange(attended, to_flat, b=b, **sizes)
     return self.unpad(attended, t_pad)
 
 |