Before change:
Returns: np.array of shape (batch_size, n_dcc).
proportions = np.sum(data, axis=2) / data.shape[2]
shannon = -np.sum(proportions * np.log(proportions + 1e-9), axis=2) // axis 3 is now 2
return shannon
After change:
total_obs = np.sum(data, axis=2, keepdims=True)
with np.errstate(divide="ignore", invalid="ignore"):
proportions = np.nan_to_num(total_obs / np.sum(total_obs, axis=3, keepdims=True))
proportions[proportions == 0] = 1
shannon = (-np.sum(proportions * np.log(proportions), axis=3))[:, :, 0]
return shannon

