def __call__(self, x, y=None):
channel_means = x.mean(1)
for i in range(2,x.dim()):
channel_means = channel_means.mean(i)
channel_means = channel_means.expand_as(x)
x = (x - channel_means) * self.var + channel_means
if y is not None:
After Change
def __call__(self, x, y=None):
    """Adjust the contrast of ``x`` (and optionally ``y``) around per-channel means.

    Every element is scaled away from its channel mean when ``self.value > 1``,
    or towards it when ``self.value < 1``. The result is then clamped to
    ``[0, 1]``.

    Args:
        x: Tensor laid out as ``(C, *rest)``. Dim 0 is treated as the channel
            dim, and the mean is taken over all remaining dims. It needs at
            least two dims. Values are presumably expected in ``[0, 1]``,
            since the output is clamped there (TODO confirm with callers).
        y: Optional second tensor. It is adjusted with the same rule, using
            its own channel means.

    Returns:
        The adjusted ``x``, or the tuple ``(x, y)`` when ``y`` is given.
    """

    def _adjust(t):
        # Reduce every dim except the channel dim (0) in a single call.
        # keepdim=True matters here: since PyTorch 0.2, reductions drop the
        # reduced dim by default. That broke the old chained
        # ``t.mean(1).mean(2)``, which raised an error or produced the wrong
        # shape. The (C, 1, ..., 1) result broadcasts against ``t``, so no
        # expand_as is needed.
        channel_means = t.mean(dim=tuple(range(1, t.dim())), keepdim=True)
        return th.clamp((t - channel_means) * self.value + channel_means, 0, 1)

    x = _adjust(x)
    if y is not None:
        return x, _adjust(y)
    return x