wuvin
Always take risks!
Toggle navigation
wuvin
主页
实验室的搬砖生活
机器学习
公开的学术内容
公开的其他内容
About Me
归档
标签
友情链接
ZYQN
ihopenot
enigma_aw
hzwer
杨宗翰
WGAN
2019-12-07 15:34:13
1080
0
0
wuvin
# Abstract 解决了GAN难训的问题。 # Key Points 普通GAN有的问题及原因: * D训练太好时,G没法学习。 采用CrossEntropy,当D收敛到最优时,G的梯度就没了。原因是当生成图片的分布与真实分布几乎没有交集时, KL散度或JS散度基本为定值。   当$f$为L-Lipschitz时,  所以  * 对应到网络结构中的解决方案就是—— * 判别器最后一层去掉激活函数 * 生成器和判别器的loss不取log(相对CrossEntropy而言) * 每次更新判别器的参数之后把它们的绝对值截断到不超过一个固定常数c(尝试保证L-Lipschitz) * -D_Loss可以反应训练进程 以下是实验的其他结果 * 不要用基于动量的优化算法(包括momentum和Adam),推荐RMSProp,SGD也行(无理论证明,后续发展证明这不影响) # 后续发展 * 在[Improved Training of Wasserstein GANs](https://arxiv.org/abs/1704.00028)中,改进了限制L-Lipschitz的方法,从weight clipping变到gradient penalty. * weight clipping和W度量导致网络所以权值要么是c要么是-c,从而影响网络表达能力。以及会导致梯度消失或者爆炸。 * gradient penalty通过设计一个loss来限制Lipschitz,采用$ReLU(|\nabla_x D(x)|_p - K)$加入损失函数来解决。 *  * 其中$x\sim X$要从样本空间采样,过于困难。所以只考虑$P_r,P_g$即中间分布。故$x = t x_r + (1-t) x_g$来线性插值采样。 * 由于gradient penalty,故避免使用BN带来样本间的依赖关系。Paper推荐使用Layer Normalization代替BN。 ```python def gradient_penalty(self, y, x): """Compute gradient penalty: (L2_norm(dy/dx) - 1)**2.""" weight = torch.ones(y.size()).to(self.device) dydx = torch.autograd.grad(outputs=y, inputs=x, grad_outputs=weight, retain_graph=True, create_graph=True, only_inputs=True)[0] dydx = dydx.view(dydx.size(0), -1) dydx_l2norm = torch.sqrt(torch.sum(dydx**2, dim=1)) return torch.mean((dydx_l2norm-1)**2) ```
上一篇:
Paper Readings
下一篇:
小米运动APP 注册地无法使用NFC问题
0
赞
1080 人读过
新浪微博
微信
腾讯微博
QQ空间
人人网
提交评论
立即登录
, 发表评论.
没有帐号?
立即注册
0
条评论
More...
文档导航
没有帐号? 立即注册