V2EX = way to explore
V2EX 是一个关于分享和探索的地方
现在注册
已注册用户请  登录
Waihinchan
V2EX  ›  机器学习

关于 pytorch TVloss 代码实现的一些疑惑

  •  
  •   Waihinchan · 2020-08-08 23:00:48 +08:00 · 2444 次点击
    这是一个创建于 1328 天前的主题,其中的信息可能已经有所发展或是发生改变。

    网上看到普遍的答案是这个

    class TVLoss(nn.Module):
        def __init__(self,TVLoss_weight=1):
            super(TVLoss,self).__init__()
            self.TVLoss_weight = TVLoss_weight
    
        def forward(self,x):
            batch_size = x.size()[0]
            h_x = x.size()[2]
            w_x = x.size()[3]
            count_h = self._tensor_size(x[:,:,1:,:])
            count_w = self._tensor_size(x[:,:,:,1:])
            h_tv = torch.pow((x[:,:,1:,:]-x[:,:,:h_x-1,:]),2).sum()
            w_tv = torch.pow((x[:,:,:,1:]-x[:,:,:,:w_x-1]),2).sum()
            return self.TVLoss_weight*2*(h_tv/count_h+w_tv/count_w)/batch_size
    
        def _tensor_size(self,t):
            return t.size()[1]*t.size()[2]*t.size()[3]
    

    这里给出的说的是β=2,且不支持变更. 所以按照这里给出的公式 https://blog.csdn.net/yexiaogu1104/article/details/88395475 β/2, 当β=2 那就是 1 也就是不进行任何操作. 所以最后 return 这里为什么会返回一个 self.TVLoss_weight2, 为啥要2 呢..

    目前尚无回复
    关于   ·   帮助文档   ·   博客   ·   API   ·   FAQ   ·   我们的愿景   ·   实用小工具   ·   975 人在线   最高记录 6543   ·     Select Language
    创意工作者们的社区
    World is powered by solitude
    VERSION: 3.9.8.5 · 25ms · UTC 20:27 · PVG 04:27 · LAX 13:27 · JFK 16:27
    Developed with CodeLauncher
    ♥ Do have faith in what you're doing.