# is called for a new set of variables. If that happens inside a function
# annotated with tf.function it throws an exception, so call it once here.
zero_grads = [tf.zeros(v.shape) for v in variables]
self._tf_optimizer.apply_gradients(zip(zero_grads, variables))
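# Create the gradient-update function for this set of variables only once
# and cache it, so it can be reused on later batches instead of being rebuilt.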
if var_key not in self._gradient_fn_for_vars:
    self._gradient_fn_for_vars[var_key] = self._create_gradient_fn(variables)
apply_gradient_for_batch = self._gradient_fn_for_vars[var_key]