def __call__(self, input, target_is_real):
    """Score a prediction against a real/fake target.

    Parameters:
        input: the prediction to score (presumably a discriminator output;
            TODO confirm against callers).
        target_is_real: flag forwarded to ``self.get_target_tensor`` to select
            which target to build.

    Returns:
        The result of ``self.loss(input, target)``, where ``target`` is the
        tensor produced by ``self.get_target_tensor(input, target_is_real)``.
    """
    # Build the target first, then fetch the criterion; this keeps the same
    # lookup/call order as computing the target in its own statement.
    make_target = self.get_target_tensor
    target = make_target(input, target_is_real)
    criterion = self.loss
    return criterion(input, target)
# Defines a generator that consists of ResNet blocks placed between a few
# downsampling and upsampling operations.
# After change:
# NOTE(review): headerless function body; the enclosing `def` line is not in view.
# Presumably this is the post-change `__call__(self, prediction, target_is_real)`
# corresponding to the `__call__` at the top of this chunk. TODO confirm against the full file.
# Computes a loss for `prediction` given whether the target is real, branching on
# self.gan_mode.
if self.gan_mode in ["lsgan", "vanilla"]:
# Criterion-based modes: build a target tensor for the prediction via
# self.get_target_tensor and score the pair with self.loss.
target_tensor = self.get_target_tensor(prediction, target_is_real)
loss = self.loss(prediction, target_tensor)
elif self.gan_mode == "wgangp":
# "wgangp" mode builds no target tensor. The loss is the negated mean prediction
# for real targets and the plain mean for fake ones. Minimizing it therefore
# pushes the mean prediction up for "real" and down for "fake".
if target_is_real:
loss = -prediction.mean()
else:
loss = prediction.mean()
# NOTE(review): there is no `else` branch. For any other gan_mode, `loss` is never
# assigned and this line raises UnboundLocalError. Presumably the constructor
# rejects other modes; confirm, or add an explicit `raise NotImplementedError`.
return loss
def cal_gradient_penalty(netD, real_data, fake_data, device, type="mixed", constant=1.0, lambda_gp=10.0):
calculate the gradient penalty loss, used in WGAN-GP paper https://arxiv.org/abs/1704.00028