| 12
 3
 4
 5
 6
 7
 8
 9
 10
 11
 12
 13
 14
 15
 16
 17
 
 | class TVLoss(nn.Module):def __init__(self,TVLoss_weight=1):
 super(TVLoss,self).__init__()
 self.TVLoss_weight = TVLoss_weight
 
 def forward(self,x):
 batch_size = x.size()[0]
 h_x = x.size()[2]
 w_x = x.size()[3]
 count_h = self._tensor_size(x[:,:,1:,:]) #channel * (h-1) * w
 count_w = self._tensor_size(x[:,:,:,1:]) #channel * h * (w-1)
 h_tv = torch.pow((x[:,:,1:,:]-x[:,:,:h_x-1,:]),2).sum()
 w_tv = torch.pow((x[:,:,:,1:]-x[:,:,:,:w_x-1]),2).sum()
 return self.TVLoss_weight*2*(h_tv/count_h+w_tv/count_w)/batch_size
 
 def _tensor_size(self,t):
 return t.size()[1]*t.size()[2]*t.size()[3]
 
 |