// Runs training steps.
train_steps(train_ds_iterator,
tf.convert_to_tensor(train_steps_per_eval, dtype=tf.int32))
current_step += train_steps_per_eval
train_loss = train_loss_metric.result().numpy().astype(float)
logging.info("Train Step: %d/%d / loss = %s",
current_step, flags_obj.train_steps, train_loss)
checkpoint_name = checkpoint.save(
os.path.join(
flags_obj.model_dir,
"ctl_step_{}.ckpt".format(current_step)))
logging.info("Saved checkpoint to %s", checkpoint_name)
else:
if self.use_tpu:
raise NotImplementedError(
"Keras model.fit on TPUs is not implemented.")
history = model.fit(
train_ds,
initial_epoch=current_iteration,
epochs=current_iteration + 1,
steps_per_epoch=train_steps_per_eval,
callbacks=callbacks,
// If TimeHistory is enabled, progress bar would be messy. Increase
// the verbose level to get rid of it.
verbose=(2 if flags_obj.enable_time_history else 1))
current_step += train_steps_per_eval
logging.info("Train history: {}".format(history.history))
print("End train iteration at global step:{}".format(current_step))
if (flags_obj.bleu_source and flags_obj.bleu_ref):
uncased_score, cased_score = self.eval()