// Case 2: Entries of z that are very small
if z_is_small.sum() > 0:
numerator = torch.Tensor([0.5641895835477550741]).expand_as(z[z_is_small])
denominator = torch.Tensor([1.0]).expand_as(z[z_is_small])
for r_i in self.r:
numerator = -z[z_is_small].mul(numerator.div(math.sqrt(2))) + r_i
for q_i in self.q:
denominator = -z[z_is_small].mul(denominator.div(math.sqrt(2))) + q_i
After Change
// Case 2: Entries of z that are very small
if z_is_small.sum() > 0:
z_where_z_is_small = z.masked_select(z_is_small)
numerator = z.new([0.5641895835477550741]).expand_as(z_where_z_is_small)
denominator = z.new([1.0]).expand_as(z_where_z_is_small)