bd155f2d58357de72c4a80699fde34e2c515c540,codes/models/SFTGAN_ACD_model.py,SFTGAN_ACD_Model,__init__,#SFTGAN_ACD_Model#Any#,18
Before Change
super(SFTGAN_ACD_Model, self).__init__(opt)
train_opt = opt["train"]
self.input_L = self.Tensor()
self.input_H = self.Tensor()
self.input_seg = self.Tensor()
self.input_cat = self.Tensor().long() // category
// define networks and load pretrained models
self.netG = networks.define_G(opt) // G
if self.is_train:
self.netD = networks.define_D(opt) // D
self.netG.train()
self.netD.train()
self.load() // load G and D if needed
// define losses, optimizer and scheduler
if self.is_train:
// G pixel loss
if train_opt["pixel_weight"] > 0:
l_pix_type = train_opt["pixel_criterion"]
if l_pix_type == "l1":
self.cri_pix = nn.L1Loss()
elif l_pix_type == "l2":
self.cri_pix = nn.MSELoss()
else:
raise NotImplementedError("Loss type [%s] is not recognized." % l_pix_type)
self.l_pix_w = train_opt["pixel_weight"]
else:
print("Remove pixel loss.")
self.cri_pix = None
// G feature loss
if train_opt["feature_weight"] > 0:
l_fea_type = train_opt["feature_criterion"]
if l_fea_type == "l1":
self.cri_fea = nn.L1Loss()
elif l_fea_type == "l2":
self.cri_fea = nn.MSELoss()
else:
raise NotImplementedError("Loss type [%s] is not recognized." % l_fea_type)
self.l_fea_w = train_opt["feature_weight"]
else:
print("Remove feature loss.")
self.cri_fea = None
if self.cri_fea: // load VGG perceptual loss
self.netF = networks.define_F(opt, use_bn=False)
// GD gan loss
self.cri_gan = GANLoss(train_opt["gan_type"], 1.0, 0.0, self.Tensor)
self.l_gan_w = train_opt["gan_weight"]
self.D_update_ratio = train_opt["D_update_ratio"] if train_opt["D_update_ratio"] else 1
self.D_init_iters = train_opt["D_init_iters"] if train_opt["D_init_iters"] else 0
if train_opt["gan_type"] == "wgan-gp":
self.random_pt = Variable(self.Tensor(1, 1, 1, 1))
// gradient penalty loss
self.cri_gp = GradientPenaltyLoss(tensor=self.Tensor)
self.l_gp_w = train_opt["gp_weigth"]
After Change
self.D_init_iters = train_opt["D_init_iters"] if train_opt["D_init_iters"] else 0
if train_opt["gan_type"] == "wgan-gp":
self.random_pt = torch.Tensor(1, 1, 1, 1).to(self.device)
// gradient penalty loss
self.cri_gp = GradientPenaltyLoss(tensor=self.Tensor)
self.l_gp_w = train_opt["gp_weigth"]
In pattern: SUPERPATTERN
Frequency: 3
Non-data size: 12
Instances
Project Name: xinntao/BasicSR
Commit Name: bd155f2d58357de72c4a80699fde34e2c515c540
Time: 2018-06-16
Author: wxt1994@126.com
File Name: codes/models/SFTGAN_ACD_model.py
Class Name: SFTGAN_ACD_Model
Method Name: __init__
Project Name: Zhaoyi-Yan/Shift-Net_pytorch
Commit Name: 8f6a6f153781d0908fb0904349aae844494026ea
Time: 2018-12-03
Author: yanzhaoyi@outlook.com
File Name: models/shiftnet_model.py
Class Name: ShiftNetModel
Method Name: initialize
Project Name: xinntao/BasicSR
Commit Name: bd155f2d58357de72c4a80699fde34e2c515c540
Time: 2018-06-16
Author: wxt1994@126.com
File Name: codes/models/SRGAN_model.py
Class Name: SRGANModel
Method Name: __init__
Project Name: xinntao/BasicSR
Commit Name: bd155f2d58357de72c4a80699fde34e2c515c540
Time: 2018-06-16
Author: wxt1994@126.com
File Name: codes/models/SFTGAN_ACD_model.py
Class Name: SFTGAN_ACD_Model
Method Name: __init__