常用loss合集

TV loss（Total Variation，全变分损失）

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
class TVLoss(nn.Module):
    """Total-variation (TV) regulariser for 4-D image batches.

    Penalises squared differences between vertically and horizontally
    adjacent pixels, normalised by the per-sample number of differences
    and averaged over the batch.

    Args:
        TVLoss_weight: scalar multiplier applied to the final loss.
    """

    def __init__(self, TVLoss_weight=1):
        super().__init__()
        self.TVLoss_weight = TVLoss_weight

    def forward(self, x):
        """Return the weighted TV loss of ``x``.

        Args:
            x: 4-D tensor indexed as (batch, channel, height, width).

        Returns:
            0-dim tensor equal to
            ``weight * 2 * (h_tv / count_h + w_tv / count_w) / batch``,
            where ``h_tv`` / ``w_tv`` are the summed squared differences along
            height / width and ``count_h`` / ``count_w`` are the per-sample
            numbers of such differences. A direction with no neighbours
            (height or width of 1) contributes 0 instead of producing NaN,
            and an empty batch yields 0 instead of NaN.
        """
        batch_size = x.size(0)
        h_x = x.size(2)
        w_x = x.size(3)
        count_h = self._tensor_size(x[:, :, 1:, :])  # channel * (h - 1) * w
        count_w = self._tensor_size(x[:, :, :, 1:])  # channel * h * (w - 1)
        h_tv = (x[:, :, 1:, :] - x[:, :, :h_x - 1, :]).pow(2).sum()
        w_tv = (x[:, :, :, 1:] - x[:, :, :, :w_x - 1]).pow(2).sum()
        # With height/width 1 the slices are empty: the sum is already 0,
        # but so is the count, and dividing would give 0/0 = NaN.
        h_term = h_tv / count_h if count_h > 0 else h_tv
        w_term = w_tv / count_w if count_w > 0 else w_tv
        # Factor 2: only down/right differences are computed (each adjacent
        # pair once); a symmetric 4-neighbour sum would count each pair twice.
        loss = self.TVLoss_weight * 2 * (h_term + w_term)
        # Guard an empty batch, which would otherwise divide 0 by 0.
        if batch_size > 0:
            loss = loss / batch_size
        return loss

    def _tensor_size(self, t):
        """Return the per-sample element count (product of dims 1-3) of ``t``."""
        return t.size(1) * t.size(2) * t.size(3)

个人理解：乘以 2 是因为 TV 考虑每个像素上、下、左、右四个方向的相邻差分，而代码只计算了向下（h 方向）和向右（w 方向）两个方向，即每对相邻像素只算了一次；若按四个方向统计，每对相邻像素会被计算两次，因此结果需要 × 2。