# ---- option bookkeeping ------------------------------------------------
self.opt = opt
self.isTrain = opt.isTrain
# Training losses to print; base_model.get_current_losses reads these names.
self.loss_names = ["G_GAN", "G_L1", "D_real", "D_fake"]
# Images to save/display; base_model.get_current_visuals reads these names.
self.visual_names = ["real_A", "fake_B", "real_B"]
# Models to save to disk; base_model.save_networks / load_networks use these.
if self.isTrain:
    self.model_names = ["G", "D"]
else:  # during test time, only load G
    self.model_names = ["G"]

# ---- mask construction -------------------------------------------------
# Batch size must be 1 for mask_global.
self.mask_global = torch.ByteTensor(1, 1, opt.fineSize, opt.fineSize)
# Start from an artificial center mask so the model never sees a
# degenerate (empty) hole.
self.mask_global.zero_()
self.mask_global[:, :,
                 int(self.opt.fineSize / 4) + self.opt.overlap:
                 int(self.opt.fineSize / 2) + int(self.opt.fineSize / 4) - self.opt.overlap,
                 int(self.opt.fineSize / 4) + self.opt.overlap:
                 int(self.opt.fineSize / 2) + int(self.opt.fineSize / 4) - self.opt.overlap] = 1

self.mask_type = opt.mask_type
self.gMask_opts = {}
self.fixed_mask = opt.fixed_mask if opt.mask_type == "center" else 0
if opt.mask_type == "center":
    assert opt.fixed_mask == 1, "Center mask must be fixed mask!"

if self.mask_type == "random":
    # res: the lower it is, the more continuous the output will be.
    # 0.01 is too small and 0.1 is too large.
    res = 0.06
    density = 0.25
    MAX_SIZE = 10000
    maxPartition = 30
    low_pattern = torch.rand(1, 1, int(res * MAX_SIZE), int(res * MAX_SIZE)).mul(255)
    # FIX: with the conventional `import torch.nn.functional as F`, the call
    # is F.interpolate — the original `F.functional.interpolate` would raise
    # AttributeError. Confirm against this file's imports.
    pattern = F.interpolate(low_pattern, (MAX_SIZE, MAX_SIZE), mode="bilinear").detach()
    low_pattern = None  # release the low-resolution tensor
    pattern.div_(255)
    pattern = torch.lt(pattern, density).byte()  # ~25% ones and ~75% zeros
    pattern = torch.squeeze(pattern).byte()
    print("...Random pattern generated")
    self.gMask_opts["pattern"] = pattern
    self.gMask_opts["MAX_SIZE"] = MAX_SIZE
    self.gMask_opts["fineSize"] = opt.fineSize
    self.gMask_opts["maxPartition"] = maxPartition
    self.gMask_opts["mask_global"] = self.mask_global
    # Create an initial random mask.
    self.mask_global = util.create_gMask(self.gMask_opts)

# ---- WGAN-GP settings --------------------------------------------------
self.wgan_gp = False
if opt.gan_type == "wgan_gp":
    self.gp_lambda = opt.gp_lambda
    self.ncritic = opt.ncritic
    self.wgan_gp = True

if len(opt.gpu_ids) > 0:
    self.use_gpu = True
    self.mask_global = self.mask_global.to(self.device)

# ---- network definition ------------------------------------------------
# self.ng_innerCos_list is the constraint list in netG's inner layers.
# self.ng_shift_list is the mask list constructing the shift operation.
# `opt` is passed through because define_G needs opt.shift_sz and friends.
self.netG, self.ng_innerCos_list, self.ng_shift_list = networks.define_G(
    opt.input_nc, opt.output_nc, opt.ngf, opt.which_model_netG, opt,
    self.mask_global, opt.norm, opt.use_dropout, opt.init_type,
    self.gpu_ids, opt.init_gain)

if self.isTrain:
    # Only the vanilla GAN uses BCECriterion, which needs a sigmoid output.
    use_sigmoid = opt.gan_type == "vanilla"
    # The discriminator is unconditional (no cGAN).
    self.netD = networks.define_D(
        opt.input_nc, opt.ndf, opt.which_model_netD, opt.n_layers_D,
        opt.norm, use_sigmoid, opt.init_type, self.gpu_ids, opt.init_gain)

if self.isTrain:
    self.old_lr = opt.lr
    # Loss functions.
    self.criterionGAN = networks.GANLoss(gan_type=opt.gan_type).to(self.device)
    self.criterionL1 = torch.nn.L1Loss()
    # Optimizers and LR schedulers. The two original branches built identical
    # Adam optimizers; the only difference was forcing beta1 = 0 for WGAN-GP
    # first, so the duplication is collapsed here.
    self.schedulers = []
    self.optimizers = []
    if self.wgan_gp:
        # WGAN-GP training uses beta1 = 0 (Gulrajani et al., 2017).
        opt.beta1 = 0
    self.optimizer_G = torch.optim.Adam(self.netG.parameters(),
                                        lr=opt.lr, betas=(opt.beta1, 0.999))
    self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
                                        lr=opt.lr, betas=(opt.beta1, 0.999))
    self.optimizers.append(self.optimizer_G)
    self.optimizers.append(self.optimizer_D)
    for optimizer in self.optimizers:
        self.schedulers.append(networks.get_scheduler(optimizer, opt))

if not self.isTrain or opt.continue_train:
    self.load_networks(opt.which_epoch)

self.print_networks(opt.verbose)
def set_input(self, input):
real_A = input["A"].to(self.device)
real_B = input["B"].to(self.device)