# For masking inputs with offline (pre-generated) random masks; only active
# when not training and offline testing is enabled.
if not self.opt.isTrain and self.opt.offline_testing:
    # The mask is looked up as masks/<stem of the first image path>_mask.png.
    # NOTE(review): only image_paths[0] is used, so every image in the batch
    # shares the first image's mask -- presumably batch size is 1 at test
    # time; confirm.
    mask_path = os.path.join(
        "masks",
        os.path.splitext(os.path.basename(self.image_paths[0]))[0] + "_mask.png",
    )
    # Open inside a context manager so the file handle is released as soon as
    # the pixels have been converted, instead of lingering until GC.
    with Image.open(mask_path) as mask_img:
        mask = transforms.ToTensor()(mask_img)
    # unsqueeze(0) adds a leading dimension; type_as(real_A) casts to
    # real_A's tensor type (which also selects CPU vs. CUDA) before the final
    # byte conversion.
    # NOTE(review): ToTensor scales 8-bit images to [0, 1], so .byte() keeps
    # only fully-white pixels as 1 -- assumes masks are strictly binary; confirm.
    self.mask_global = mask.unsqueeze(0).type_as(real_A).byte()
    self.set_latent_mask(self.mask_global)
After Change
real_B = input["B"].to(self.device)
# Directly load the mask offline from the batch.
# TODO: make masks vary per image in a batch; for now only index 0 of
# dimension 0 and index 0 of dimension 1 are kept (presumably the first
# sample's first channel -- confirm the layout of input["M"]).
# Narrow before moving to the device so only the kept slice is transferred
# and cast, rather than the whole batch.
self.mask_global = (
    input["M"].narrow(0, 0, 1).narrow(1, 0, 1).to(self.device).byte()
)
# create mask online