// respect to each input dimension should be taken into account,
// but we ignore the differences and assume that the curvature is
// uniform with respect to all the input dimensions.
directions /= xp.sqrt(xp.square(directions).sum())
// Small elements in the direction vector lead to instability in
// the gradient comparison.
// To avoid that, absolute values are clamped from below at
// 0.1 / sqrt(N), where N is the number of elements.
// The remaining elements are scaled uniformly so that the total L2
// norm remains 1.
min_d = 0.1 / math.sqrt(size)
is_small = min_d > xp.abs(directions)
is_large = xp.logical_not(is_small)
n_small = is_small.sum()
sq_large = xp.square(directions[is_large]).sum()
scale = xp.sqrt((1 - n_small * min_d ** 2) / sq_large)
// Cap small elements.
directions[is_small] = xp.sign(directions[is_small]) * min_d
After Change
// Small elements in the direction vector lead to instability in
// the gradient comparison. To avoid that, make every absolute value
// at least 0.1 / sqrt(size).
sq_directions = xp.square(directions)
sq_norm = sq_directions.sum()
return xp.copysign(
// Weighted quadratic mean (weights 0.99 and 0.01) of
// abs(directions / norm) and xp.full(size, 1 / xp.sqrt(size)),
// where norm = xp.sqrt(sq_norm)
xp.sqrt(
(0.99 / sq_norm) * sq_directions
+ 0.01 / size),
directions)