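// fix values where xyz has zero norm: the l = 0 entry is recomputed with spherical_harmonics and every higher order is set to zeros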
if (xyz.view(-1, 3).norm(2, -1) == 0).nonzero().numel() > 0: // this `if` is not needed with version 1.0 of pytorch
val = torch.cat([spherical_harmonics(0, xyz.flatten()[0], 321) if l == 0 else xyz.new_zeros(2 * l + 1) for l in order]) // [m]
out[:, xyz.norm(2, -1) == 0] = val.view(-1, 1)
return out
After Change
out = spherical_harmonics(order, alpha, beta) // [m, ...]
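// fix values where xyz has zero norm: the l = 0 entry is set to the constant 1 / sqrt(4 * pi) and every higher order is set to zeros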
val = torch.cat([xyz.new_tensor([1 / math.sqrt(4 * math.pi)]) if l == 0 else xyz.new_zeros(2 * l + 1) for l in order]) // [m]
out[:, xyz.norm(2, -1) == 0] = val.view(-1, 1)
return out