CyCADA: Cycle-Consistent Adversarial Domain Adaptation
2019-09-19 15:52:33
wuvin
# Introduction

Published at ICML 2018. The idea is to build a one-to-one mapping from distribution A to distribution B, so that a model trained on the original distribution does not degrade much in the new environment.

A change of domain usually hurts a model's performance, so the authors adapt images from the source domain to the target domain.

![title](https://leanote.com/api/file/getImage?fileId=5d83341cab6441313d0007fd)

# Model Structure

![title](https://leanote.com/api/file/getImage?fileId=5d8334e2ab6441313d000804)

The model consists of four parts: (1) the mapping cycle (cycle consistency), (2) semantic consistency, (3) the adversarial (GAN) part, which includes the discriminator $D_T$ on generated images and the discriminator $D_{feat}$ on features extracted by $f_T$, and (4) the original task, i.e. classification.

![title](https://leanote.com/api/file/getImage?fileId=5d8338f5ab644133350008ea)

![title](https://leanote.com/api/file/getImage?fileId=5d83390dab644133350008ec)

![title](https://leanote.com/api/file/getImage?fileId=5d833923ab6441313d00084d)

![title](https://leanote.com/api/file/getImage?fileId=5d83394cab644133350008ee)

![title](https://leanote.com/api/file/getImage?fileId=5d833963ab6441313d000850)

# Experiments

![title](https://leanote.com/api/file/getImage?fileId=5d89f8b7ab64411f16001b89)

* Below, a model trained on GTA5 is tested on real images.

![title](https://leanote.com/api/file/getImage?fileId=5d833ab3ab6441313d000861)

![title](https://leanote.com/api/file/getImage?fileId=5d833aa7ab644133350008fd)

[FCNs in the Wild](https://arxiv.org/abs/1612.02649) is a method proposed in 2016.

# Code

![title](https://leanote.com/api/file/getImage?fileId=5d89fa16ab6441211d001b00)

```python
import numpy as np
import torch
import torch.nn as nn
from torch.nn import init

from .util import init_weights
from .models import register_model, get_model


@register_model('AddaNet')
class AddaNet(nn.Module):
    "Defines an ADDA network."

    def __init__(self, num_cls=10, model='LeNet', src_weights_init=None,
                 weights_init=None):
        super(AddaNet, self).__init__()
        self.name = 'AddaNet'
        self.base_model = model
        self.num_cls = num_cls
        self.cls_criterion = nn.CrossEntropyLoss()
        self.gan_criterion = nn.CrossEntropyLoss()
        # forward() branches on this flag; default to discriminating on class
        # scores, which matches input_dim = num_cls in setup_net().
        self.discrim_feat = False

        self.setup_net()
        if weights_init is not None:
            self.load(weights_init)
        elif src_weights_init is not None:
            self.load_src_net(src_weights_init)
        else:
            raise Exception('AddaNet must be initialized with weights.')

    def forward(self, x_s, x_t):
        """Pass source and target images through their respective networks."""
        score_s, x_s = self.src_net(x_s, with_ft=True)
        score_t, x_t = self.tgt_net(x_t, with_ft=True)
        if self.discrim_feat:
            # Discriminate on intermediate features.
            d_s = self.discriminator(x_s)
            d_t = self.discriminator(x_t)
        else:
            # Discriminate on class scores.
            d_s = self.discriminator(score_s)
            d_t = self.discriminator(score_t)
        return score_s, score_t, d_s, d_t

    def setup_net(self):
        """Setup source, target and discriminator networks."""
        self.src_net = get_model(self.base_model, num_cls=self.num_cls)  # LeNet or DRN or FCN
        self.tgt_net = get_model(self.base_model, num_cls=self.num_cls)  # LeNet or DRN or FCN

        input_dim = self.num_cls
        self.discriminator = nn.Sequential(
            nn.Linear(input_dim, 500),
            nn.ReLU(),
            nn.Linear(500, 500),
            nn.ReLU(),
            nn.Linear(500, 2),
        )  # Isn't this D_feat a bit too simple?

        self.image_size = self.src_net.image_size
        self.num_channels = self.src_net.num_channels

    def load(self, init_path):
        "Loads full src and tgt models."
        net_init_dict = torch.load(init_path)
        self.load_state_dict(net_init_dict)

    def load_src_net(self, init_path):
        """Initialize source and target with source weights."""
        self.src_net.load(init_path)
        self.tgt_net.load(init_path)

    def save(self, out_path):
        torch.save(self.state_dict(), out_path)

    def save_tgt_net(self, out_path):
        torch.save(self.tgt_net.state_dict(), out_path)


class Discriminator(nn.Module):  # This D_T is also quite simple.
    def __init__(self, input_dim=4096, output_dim=2, pretrained=False, weights_init=''):
        super().__init__()
        dim1 = 1024 if input_dim == 4096 else 512
        dim2 = dim1 // 2
        # 1x1 convolutions: a per-location domain classifier over feature maps.
        self.D = nn.Sequential(
            nn.Conv2d(input_dim, dim1, 1),
            nn.Dropout2d(p=0.5),
            nn.ReLU(inplace=True),
            nn.Conv2d(dim1, dim2, 1),
            nn.Dropout2d(p=0.5),
            nn.ReLU(inplace=True),
            nn.Conv2d(dim2, output_dim, 1),
        )
        # weights_init defaults to '', so test truthiness rather than `is not None`.
        if pretrained and weights_init:
            self.load_weights(weights_init)

    def forward(self, x):
        d_score = self.D(x)
        return d_score

    def load_weights(self, weights):
        print('Loading discriminator weights')
        self.load_state_dict(torch.load(weights))
```
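The excerpt above builds a 2-way discriminator and a `gan_criterion = nn.CrossEntropyLoss()`, but does not show how they are used in training. Below is a minimal sketch of the usual ADDA-style alternating losses, assuming label 1 means "source" and label 0 means "target". The helpers `make_discriminator`, `discriminator_loss` and `target_net_loss` are hypothetical names, and random tensors stand in for the `score_s` / `score_t` returned by `AddaNet.forward()`.

```python
import torch
import torch.nn as nn

gan_criterion = nn.CrossEntropyLoss()  # same 2-class criterion as AddaNet


def make_discriminator(num_cls: int = 10) -> nn.Module:
    """Same layer stack as AddaNet.setup_net(): num_cls scores -> 2 domain logits."""
    return nn.Sequential(
        nn.Linear(num_cls, 500), nn.ReLU(),
        nn.Linear(500, 500), nn.ReLU(),
        nn.Linear(500, 2),
    )


def discriminator_loss(d_s: torch.Tensor, d_t: torch.Tensor) -> torch.Tensor:
    """D is trained to output label 1 for source and 0 for target (assumed convention)."""
    src = torch.ones(d_s.size(0), dtype=torch.long)
    tgt = torch.zeros(d_t.size(0), dtype=torch.long)
    return gan_criterion(d_s, src) + gan_criterion(d_t, tgt)


def target_net_loss(d_t: torch.Tensor) -> torch.Tensor:
    """Inverted labels: the target net wins when D mistakes target for source."""
    src = torch.ones(d_t.size(0), dtype=torch.long)
    return gan_criterion(d_t, src)


# Usage with random stand-ins for score_s / score_t from AddaNet.forward().
torch.manual_seed(0)
disc = make_discriminator(num_cls=10)
score_s = torch.randn(8, 10)
score_t = torch.randn(8, 10, requires_grad=True)

# Discriminator step: detach so this loss cannot reach the target network.
loss_d = discriminator_loss(disc(score_s), disc(score_t.detach()))

# Target-network step: gradients flow through D back into score_t.
loss_t = target_net_loss(disc(score_t))
loss_t.backward()
```

In a real loop each loss would be paired with its own optimizer: only the discriminator's parameters are stepped on `loss_d`, and only the target network's parameters on `loss_t`, so the discriminator is never pushed towards the inverted labels.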
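The four parts listed under Model Structure can be combined into one generator-side objective. The toy sketch below shows how the terms fit together; everything in it is an assumption for illustration: the tiny networks (`tiny_generator`, `tiny_classifier`, `tiny_image_disc`), the 16x16 inputs, the binary cross-entropy GAN loss and the unit weight on every term are hypothetical and not taken from the released code. Following the `AddaNet` excerpt, `D_feat` here looks at the class scores produced by `f_T`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
C, NUM_CLS, IMG = 3, 10, 16  # hypothetical: RGB 16x16 images, 10 classes


def tiny_generator():
    # Hypothetical translator standing in for G_{S->T} / G_{T->S}.
    return nn.Sequential(nn.Conv2d(C, 8, 3, padding=1), nn.ReLU(),
                         nn.Conv2d(8, C, 3, padding=1), nn.Tanh())


def tiny_classifier():
    # Hypothetical task model standing in for f_S / f_T: image -> class logits.
    return nn.Sequential(nn.Conv2d(C, 8, 3, padding=1), nn.ReLU(),
                         nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(8, NUM_CLS))


def tiny_image_disc():
    # Hypothetical image discriminator standing in for D_T / D_S: image -> 1 logit.
    return nn.Sequential(nn.Conv2d(C, 8, 4, stride=2, padding=1), nn.LeakyReLU(0.2),
                         nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(8, 1))


G_st, G_ts = tiny_generator(), tiny_generator()
f_S, f_T = tiny_classifier(), tiny_classifier()
D_T, D_S = tiny_image_disc(), tiny_image_disc()
D_feat = nn.Sequential(nn.Linear(NUM_CLS, 64), nn.ReLU(), nn.Linear(64, 1))
f_S.requires_grad_(False)  # f_S would be pretrained on the source domain and kept fixed


def fool(d_logits):
    # Generator-side GAN loss: push the discriminator towards "real" (1).
    return F.binary_cross_entropy_with_logits(d_logits, torch.ones_like(d_logits))


def cycada_generator_loss(x_s, y_s, x_t):
    fake_t = G_st(x_s)  # source image rendered in target style
    fake_s = G_ts(x_t)  # target image rendered in source style

    # (4) task loss: f_T learns from translated source images with source labels
    l_task = F.cross_entropy(f_T(fake_t), y_s)
    # (3) pixel-level GAN terms: D_T judges fake_t, D_S judges fake_s
    l_gan_img = fool(D_T(fake_t)) + fool(D_S(fake_s))
    # (3) feature-level GAN term: f_T outputs on target images should look
    # like its outputs on translated source images
    l_gan_feat = fool(D_feat(f_T(x_t)))
    # (1) cycle consistency: S->T->S and T->S->T reconstructions (L1)
    l_cyc = F.l1_loss(G_ts(fake_t), x_s) + F.l1_loss(G_st(fake_s), x_t)
    # (2) semantic consistency: the fixed f_S should predict the same labels
    # before and after translation (pseudo-labels = f_S argmax on originals)
    with torch.no_grad():
        p_s = f_S(x_s).argmax(dim=1)
        p_t = f_S(x_t).argmax(dim=1)
    l_sem = F.cross_entropy(f_S(fake_t), p_s) + F.cross_entropy(f_S(fake_s), p_t)

    # Unit weights are an assumption; in practice each term is weighted.
    return l_task + l_gan_img + l_gan_feat + l_cyc + l_sem


# Usage with random data in [-1, 1], matching the Tanh output range.
x_s = torch.rand(4, C, IMG, IMG) * 2 - 1
y_s = torch.randint(0, NUM_CLS, (4,))
x_t = torch.rand(4, C, IMG, IMG) * 2 - 1
loss = cycada_generator_loss(x_s, y_s, x_t)
loss.backward()
```

Freezing `f_S` with `requires_grad_(False)` stops its weights from changing, while gradients still flow through it into the generators, which is what the semantic-consistency term needs. In a full training loop, `D_T`, `D_S` and `D_feat` would be updated in alternation with this loss, each trained to separate real inputs from translated ones.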