8f6a6f153781d0908fb0904349aae844494026ea,models/shiftnet_model.py,ShiftNetModel,initialize,#ShiftNetModel#Any#,17

Before Change


            print("Loading pre-trained network!")
            self.load_network(self.netG, "G", opt.which_epoch)
            if self.isTrain:
                self.load_network(self.netD, "D", opt.which_epoch)

        if self.isTrain:
            self.old_lr = opt.lr
            // define loss functions
            self.criterionGAN = networks.GANLoss(gan_type=opt.gan_type, tensor=self.Tensor)
            self.criterionL1 = torch.nn.L1Loss()

            // initialize optimizers
            self.schedulers = []
            self.optimizers = []
            if self.wgan_gp:
                opt.beta1 = 0
                self.optimizer_G = torch.optim.Adam(self.netG.parameters(),
                                    lr=opt.lr, betas=(opt.beta1, 0.999))
                self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
                                                    lr=opt.lr, betas=(opt.beta1, 0.999))
            else:
                self.optimizer_G = torch.optim.Adam(self.netG.parameters(),
                                                    lr=opt.lr, betas=(opt.beta1, 0.999))
                self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
                                                    lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizers.append(self.optimizer_G)
            self.optimizers.append(self.optimizer_D)
            for optimizer in self.optimizers:
                self.schedulers.append(networks.get_scheduler(optimizer, opt))

            print("---------- Networks initialized -------------")
            networks.print_network(self.netG)
            if self.isTrain:
                networks.print_network(self.netD)
            print("-----------------------------------------------")

    def set_input(self, input):
        input_A = input["A"]

After Change


        self.opt = opt
        self.isTrain = opt.isTrain
        // specify the training losses you want to print out. The program will call base_model.get_current_losses
        self.loss_names = ["G_GAN", "G_L1", "D_real", "D_fake"]
        // specify the images you want to save/display. The program will call base_model.get_current_visuals
        self.visual_names = ["real_A", "fake_B", "real_B"]
        // specify the models you want to save to the disk. The program will call base_model.save_networks and base_model.load_networks
        if self.isTrain:
            self.model_names = ["G", "D"]
        else:  // during test time, only load Gs
            self.model_names = ["G"]


        // batchsize should be 1 for mask_global
        self.mask_global = torch.ByteTensor(1, 1, \
                                 opt.fineSize, opt.fineSize)

        // Here we need to set an artificial mask_global so that it is never left broken (a center hole is fine).
        self.mask_global.zero_()
        self.mask_global[:, :, int(self.opt.fineSize/4) + self.opt.overlap : int(self.opt.fineSize/2) + int(self.opt.fineSize/4) - self.opt.overlap,\
                                int(self.opt.fineSize/4) + self.opt.overlap: int(self.opt.fineSize/2) + int(self.opt.fineSize/4) - self.opt.overlap] = 1

        self.mask_type = opt.mask_type
        self.gMask_opts = {}
        self.fixed_mask = opt.fixed_mask if opt.mask_type == "center" else 0
        if opt.mask_type == "center":
            assert opt.fixed_mask == 1, "Center mask must be fixed mask!"

        if self.mask_type == "random":
            res = 0.06 // the lower it is, the more continuous the output will be. 0.01 is too small and 0.1 is too large
            density = 0.25
            MAX_SIZE = 10000
            maxPartition = 30
            low_pattern = torch.rand(1, 1, int(res*MAX_SIZE), int(res*MAX_SIZE)).mul(255)
            pattern = F.functional.interpolate(low_pattern, (MAX_SIZE, MAX_SIZE), mode="bilinear").detach()
            low_pattern = None
            pattern.div_(255)
            pattern = torch.lt(pattern,density).byte()  // 25% 1s and 75% 0s
            pattern = torch.squeeze(pattern).byte()
            print("...Random pattern generated")
            self.gMask_opts["pattern"] = pattern
            self.gMask_opts["MAX_SIZE"] = MAX_SIZE
            self.gMask_opts["fineSize"] = opt.fineSize
            self.gMask_opts["maxPartition"] = maxPartition
            self.gMask_opts["mask_global"] = self.mask_global
            self.mask_global = util.create_gMask(self.gMask_opts) // create an initial random mask.


        self.wgan_gp = False
        // added for wgan-gp
        if opt.gan_type == "wgan_gp":
            self.gp_lambda = opt.gp_lambda
            self.ncritic = opt.ncritic
            self.wgan_gp = True


        if len(opt.gpu_ids) > 0:
            self.use_gpu = True
            self.mask_global = self.mask_global.to(self.device)

        // load/define networks
        // self.ng_innerCos_list is the list of constraints in the inner layers of netG.
        // self.ng_shift_list is the list used to construct the shift operation.
        self.netG, self.ng_innerCos_list, self.ng_shift_list = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf,
                                      opt.which_model_netG, opt, self.mask_global, opt.norm, opt.use_dropout, opt.init_type, self.gpu_ids, opt.init_gain) // add opt, we need opt.shift_sz and other stuffs
        if self.isTrain:
            use_sigmoid = False
            if opt.gan_type == "vanilla":
                use_sigmoid = True  // only vanilla GAN using BCECriterion
            // don't use cGAN
            self.netD = networks.define_D(opt.input_nc, opt.ndf,
                                          opt.which_model_netD,
                                          opt.n_layers_D, opt.norm, use_sigmoid, opt.init_type, self.gpu_ids, opt.init_gain)

        if self.isTrain:
            self.old_lr = opt.lr
            // define loss functions
            self.criterionGAN = networks.GANLoss(gan_type=opt.gan_type).to(self.device)
            self.criterionL1 = torch.nn.L1Loss()

            // initialize optimizers
            self.schedulers = []
            self.optimizers = []
            if self.wgan_gp:
                opt.beta1 = 0
                self.optimizer_G = torch.optim.Adam(self.netG.parameters(),
                                    lr=opt.lr, betas=(opt.beta1, 0.999))
                self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
                                                    lr=opt.lr, betas=(opt.beta1, 0.999))
            else:
                self.optimizer_G = torch.optim.Adam(self.netG.parameters(),
                                                    lr=opt.lr, betas=(opt.beta1, 0.999))
                self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
                                                    lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizers.append(self.optimizer_G)
            self.optimizers.append(self.optimizer_D)
            for optimizer in self.optimizers:
                self.schedulers.append(networks.get_scheduler(optimizer, opt))

        if not self.isTrain or opt.continue_train:
            self.load_networks(opt.which_epoch)

        self.print_networks(opt.verbose)

    def set_input(self, input):
        real_A = input["A"].to(self.device)
        real_B = input["B"].to(self.device)
[image: Italian Trulli]
In pattern: SUPERPATTERN

Frequency: 3

Non-data size: 17

Instances


Project Name: Zhaoyi-Yan/Shift-Net_pytorch
Commit Name: 8f6a6f153781d0908fb0904349aae844494026ea
Time: 2018-12-03
Author: yanzhaoyi@outlook.com
File Name: models/shiftnet_model.py
Class Name: ShiftNetModel
Method Name: initialize


Project Name: Zhaoyi-Yan/Shift-Net_pytorch
Commit Name: 8f6a6f153781d0908fb0904349aae844494026ea
Time: 2018-12-03
Author: yanzhaoyi@outlook.com
File Name: models/shiftnet_model.py
Class Name: ShiftNetModel
Method Name: initialize


Project Name: richzhang/colorization-pytorch
Commit Name: 843d68642bd15d5737e3eb39abd76c748d6b52e8
Time: 2018-04-19
Author: junyanzhu89@gmail.com
File Name: models/test_model.py
Class Name: TestModel
Method Name: initialize