raise ValueError("Expected %s gradients, but got %d" % (
len(self._gradients), len(gradients)))
for accum_gradient, gradient in zip(self._get_replica_gradients(), gradients):
accum_gradient.assign_add(gradient)
self._accum_steps.assign_add(1)
# --- After change (updated implementation below) ---
# In a replica context, we want to accumulate gradients on each replica
# without synchronization, so we directly assign the value of the
# current replica.
for accum_gradient, gradient in zip(_get_replica_local_variables(self._gradients), gradients):
accum_gradient.assign_add(gradient)
_get_replica_local_variables(self._accum_steps).assign_add(1)