sq_distance.mul_(2 * math.pi ** 2) // sq_distance = -2*pi^2*(x-z)^2
res = torch.zeros(n, m)
for weight, mean, scale in zip(mixture_weights, mixture_means, mixture_scales):
weight = weight.expand(n, m)
mean = mean.expand(n, m)
scale = scale.expand(n, m)
After Change
sq_distance.mul_(2 * math.pi ** 2) // sq_distance = -2*pi^2*(x-z)^2
res = x1.data.new(n, m).zero_()
for weight, mean, scale in zip(mixture_weights, mixture_means, mixture_scales):
weight = weight.expand(n, m)
mean = mean.expand(n, m)
scale = scale.expand(n, m)