1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17
class TVLoss(nn.Module):
    """Squared total-variation loss over the last two dimensions of a 4-D tensor.

    For each sample, sums the squared differences between neighbouring
    elements along dim 2 ("height") and dim 3 ("width"). Each sum is
    normalised by the number of neighbour pairs per sample in that direction.

    Args:
        TVLoss_weight: Scalar multiplier applied to the final loss.
    """

    def __init__(self, TVLoss_weight=1):
        super().__init__()
        self.TVLoss_weight = TVLoss_weight

    def forward(self, x):
        """Return the weighted TV loss of ``x``.

        Args:
            x: 4-D tensor indexed as ``x[batch, channel, h, w]``. This layout
               is assumed from the dimensions the code reads; callers should
               confirm they pass NCHW data.

        Returns:
            A scalar tensor equal to
            ``TVLoss_weight * 2 * (h_tv / count_h + w_tv / count_w) / batch_size``.
            If a direction has no neighbour pairs (height or width equal to 1),
            its term is 0 rather than NaN.
        """
        batch_size = x.size(0)
        count_h = self._tensor_size(x[:, :, 1:, :])  # channel * (h-1) * w
        count_w = self._tensor_size(x[:, :, :, 1:])  # channel * h * (w-1)
        h_tv = (x[:, :, 1:, :] - x[:, :, :-1, :]).pow(2).sum()
        w_tv = (x[:, :, :, 1:] - x[:, :, :, :-1]).pow(2).sum()
        # A zero count means the difference slice was empty, so the matching
        # sum is exactly 0. Reuse that zero tensor (which keeps dtype, device
        # and graph) instead of evaluating 0/0 = NaN.
        h_term = h_tv / count_h if count_h else h_tv
        w_term = w_tv / count_w if count_w else w_tv
        return self.TVLoss_weight * 2 * (h_term + w_term) / batch_size

    def _tensor_size(self, t):
        """Return the element count of one sample of ``t`` (product of dims 1-3)."""
        return t.size()[1] * t.size()[2] * t.size()[3]
|