if steps_per_loop >= _MIN_SUMMARY_STEPS:
  # Only write summaries when stats have been collected over a sufficient
  # number of steps.
  train_summary_writer = tf.summary.create_file_writer(
      os.path.join(model_dir, "summaries/train"))
else:
  train_summary_writer = None
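# Per-replica training step: forward pass, loss, gradient computation, and
# optimizer update, plus metric updates for reporting.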
def _replicated_step(inputs):
  """Replicated training step."""
  inputs, labels = inputs
  with tf.GradientTape() as tape:
    model_outputs = model(inputs)
    loss = loss_fn(labels, model_outputs)
  tvars = model.trainable_variables
  grads = tape.gradient(loss, tvars)
  optimizer.apply_gradients(zip(grads, tvars))
  # For reporting, the metric takes the mean of losses.
  train_loss_metric.update_state(loss)
  if train_metric:
    train_metric.update_state(labels, model_outputs)
@tf.function
def train_steps(iterator, steps):
  """Performs distributed training steps in a loop.

  Args:
    iterator: the distributed iterator of training datasets.
    steps: a tf.int32 integer tensor specifying the number of steps to run
      inside the host training loop.

  Raises:
    ValueError: If any of the arguments or tensor shapes are invalid.
  """
  if not isinstance(steps, tf.Tensor):
    raise ValueError("steps should be a Tensor. A Python object may cause "
                     "retracing.")
  for _ in tf.range(steps):
    strategy.experimental_run_v2(_replicated_step, args=(next(iterator),))
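# A dedicated single-step function is kept separate from train_steps to work
# around tf.while_loop GPU performance issues (see the TODO in the loop below).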
@tf.function
def train_single_step(iterator):
  """Performs a distributed training step.

  Args:
    iterator: the distributed iterator of training datasets.

  Raises:
    ValueError: If any of the arguments or tensor shapes are invalid.
  """
  strategy.experimental_run_v2(_replicated_step, args=(next(iterator),))
@tf.function
def test_step(iterator):
  """Calculates evaluation metrics on distributed devices."""

  def _test_step_fn(inputs):
    """Replicated accuracy calculation."""
    inputs, labels = inputs
    model_outputs = model(inputs, training=False)
    eval_metric.update_state(labels, model_outputs)

  strategy.experimental_run_v2(_test_step_fn, args=(next(iterator),))
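# Evaluation helper: runs eval_steps test steps, then logs the aggregated
# metric and writes it to the eval summary writer.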
def _run_evaluation(current_training_step, test_iterator):
  """Runs validation steps and aggregates metrics."""
  for _ in range(eval_steps):
    test_step(test_iterator)
  eval_metric_value = _float_metric_value(eval_metric)
  logging.info("Step: [%d] Validation metric = %f", current_training_step,
               eval_metric_value)
  with eval_summary_writer.as_default():
    tf.summary.scalar(
        eval_metric.name, eval_metric_value, step=current_training_step)
    eval_summary_writer.flush()
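# Callback hooks. Note that in the training loop below they fire once per host
# loop iteration, which may cover several training steps.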
def _run_callbacks_on_batch_begin(batch):
  """Runs custom callbacks at the start of every step."""
  if not custom_callbacks:
    return
  for callback in custom_callbacks:
    callback.on_batch_begin(batch)

def _run_callbacks_on_batch_end(batch):
  """Runs custom callbacks at the end of every step."""
  if not custom_callbacks:
    return
  for callback in custom_callbacks:
    callback.on_batch_end(batch)
# Training loop starts here.
checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
latest_checkpoint_file = tf.train.latest_checkpoint(model_dir)
if latest_checkpoint_file:
  logging.info(
      "Checkpoint file %s found and restoring from "
      "checkpoint", latest_checkpoint_file)
  checkpoint.restore(latest_checkpoint_file)
  logging.info("Loading from checkpoint file completed")
current_step = optimizer.iterations.numpy()
checkpoint_name = "ctl_step_{step}.ckpt"
while current_step < total_training_steps:
  # The training loss/metric take the average over the steps run inside the
  # micro training loop, so we reset their values before each round.
  train_loss_metric.reset_states()
  if train_metric:
    train_metric.reset_states()

  _run_callbacks_on_batch_begin(current_step)
  # Runs several steps in the host while loop.
  steps = _steps_to_run(current_step, steps_per_epoch, steps_per_loop)

  if steps == 1:
    # TODO(zongweiz): merge with train_steps once tf.while_loop
    # GPU performance bugs are fixed.
    train_single_step(train_iterator)
  else:
    # Converts steps to a Tensor to avoid tf.function retracing.
    train_steps(train_iterator,
                tf.convert_to_tensor(steps, dtype=tf.int32))
  _run_callbacks_on_batch_end(current_step)
  current_step += steps
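  # The loss/metric objects have accumulated state over the steps just run;
  # read their mean values for logging and summaries.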
  train_loss = _float_metric_value(train_loss_metric)
  # Updates training logging.
  training_status = "Train Step: %d/%d / loss = %s" % (
      current_step, total_training_steps, train_loss)
  if train_metric:
    train_metric_value = _float_metric_value(train_metric)
    training_status += " training metric = %f" % train_metric_value
  else:
    train_metric_value = None
  logging.info(training_status)

  if train_summary_writer:
    with train_summary_writer.as_default():
      tf.summary.scalar(
          train_loss_metric.name, train_loss, step=current_step)
      if train_metric_value:
        tf.summary.scalar(
            train_metric.name, train_metric_value, step=current_step)
      train_summary_writer.flush()

  # Saves model checkpoints and runs validation steps at the end of every
  # epoch.
  if current_step % steps_per_epoch == 0:
    # To avoid repeated model saving, we do not save after the last
    # step of training.
    if current_step < total_training_steps: