loss = self._loss_fn
# Build a cache key for this set of trainable variables so we can reuse the
# compiled gradient function across batches. var_key stays None when the
# caller did not restrict training to a specific variable list.
var_key = None
if variables is not None:
    var_key = tuple(v.experimental_ref() for v in variables)

    # The optimizer creates internal variables the first time apply_gradients()
    # is called for a new set of variables. If that happens inside a function
    # annotated with tf.function it throws an exception, so call it once here
    # with all-zero gradients (a no-op update) to force slot-variable creation.
    zero_grads = [tf.zeros(v.shape) for v in variables]
    self._tf_optimizer.apply_gradients(zip(zero_grads, variables))
if var_key not in self._gradient_fn_for_vars:
    self._gradient_fn_for_vars[var_key] = self._create_gradient_fn(variables)
apply_gradient_for_batch = self._gradient_fn_for_vars[var_key]
time1 = time.time()
# Main training loop.