cdf_g0, cdf_g1 = minmax(cdf)
denom = (cdf_g1 - cdf_g0)
denom = jnp.where(denom < eps, 1., denom)
t = (u - cdf_g0) / denom
z_samples = bins_g0 + t * (bins_g1 - bins_g0)
// Prevent gradient from backprop-ing through samples
After Change
// Padding each weight vector so its sum is at least `eps` avoids NaNs when the input is zeros or small, but has no effect otherwise.
eps = 1e-5
weight_sum = jnp.sum(weights, axis=-1, keepdims=True)
padding = jnp.maximum(0, eps - weight_sum)
weights += padding / weights.shape[-1]
weight_sum += padding
// Compute the PDF and CDF for each weight vector.