.PER_HOST_V2))  # pylint: disable=line-too-long
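# Build the TPUEstimator. When --inference_with_all_cores is set, the exported
# SavedModel is configured to run inference across all TPU cores (experimental).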
if FLAGS.inference_with_all_cores:
resnet_classifier = tf.contrib.tpu.TPUEstimator(
use_tpu=FLAGS.use_tpu,
model_fn=resnet_model_fn,
config=config,
train_batch_size=FLAGS.train_batch_size,
eval_batch_size=FLAGS.eval_batch_size,
export_to_tpu=FLAGS.export_to_tpu,
experimental_exported_model_uses_all_cores=FLAGS
.inference_with_all_cores)
else:
resnet_classifier = tf.contrib.tpu.TPUEstimator(
use_tpu=FLAGS.use_tpu,
model_fn=resnet_model_fn,
config=config,
train_batch_size=FLAGS.train_batch_size,
eval_batch_size=FLAGS.eval_batch_size,
export_to_tpu=FLAGS.export_to_tpu)
assert FLAGS.precision == "bfloat16" or FLAGS.precision == "float32", (
"Invalid value for --precision flag; must be bfloat16 or float32.")
tf.logging.info("Precision: %s", FLAGS.precision)
use_bfloat16 = FLAGS.precision == "bfloat16"
# Input pipelines are slightly different (with regard to shuffling and
# preprocessing) between training and evaluation.
if FLAGS.bigtable_instance:
tf.logging.info("Using Bigtable dataset, table %s", FLAGS.bigtable_table)
select_train, select_eval = _select_tables_from_flags()
imagenet_train, imagenet_eval = [imagenet_input.ImageNetBigtableInput(
is_training=is_training,
use_bfloat16=use_bfloat16,
transpose_input=FLAGS.transpose_input,
selection=selection) for (is_training, selection) in
[(True, select_train),
(False, select_eval)]]
else:
if FLAGS.data_dir == FAKE_DATA_DIR:
tf.logging.info("Using fake dataset.")
else:
tf.logging.info("Using dataset: %s", FLAGS.data_dir)
imagenet_train, imagenet_eval = [
imagenet_input.ImageNetInput(
is_training=is_training,
data_dir=FLAGS.data_dir,
transpose_input=FLAGS.transpose_input,
cache=FLAGS.use_cache and is_training,
image_size=FLAGS.image_size,
num_parallel_calls=FLAGS.num_parallel_calls,
use_bfloat16=use_bfloat16) for is_training in [True, False]
]
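# Number of batches per training epoch and per full evaluation pass
# (integer division drops any remainder images).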
steps_per_epoch = FLAGS.num_train_images // FLAGS.train_batch_size
eval_steps = FLAGS.num_eval_images // FLAGS.eval_batch_size
if FLAGS.mode == "eval":
# Run evaluation when there's a new checkpoint
for ckpt in evaluation.checkpoints_iterator(
FLAGS.model_dir, timeout=FLAGS.eval_timeout):
tf.logging.info("Starting to evaluate.")
try:
start_timestamp = time.time()  # This time will include compilation time
eval_results = resnet_classifier.evaluate(
input_fn=imagenet_eval.input_fn,
steps=eval_steps,
checkpoint_path=ckpt)
elapsed_time = int(time.time() - start_timestamp)
tf.logging.info("Eval results: %s. Elapsed seconds: %d",
eval_results, elapsed_time)
# Terminate eval job when final checkpoint is reached
current_step = int(os.path.basename(ckpt).split("-")[1])
if current_step >= FLAGS.train_steps:
tf.logging.info(
"Evaluation finished after training step %d", current_step)
break
except tf.errors.NotFoundError:
# Since the coordinator is on a different job than the TPU worker,
# sometimes the TPU worker does not finish initializing until long after
# the CPU job tells it to start evaluating. In this case, the checkpoint
# file could have been deleted already.
tf.logging.info(
"Checkpoint %s no longer exists, skipping checkpoint", ckpt)
else:   # FLAGS.mode == "train" or FLAGS.mode == "train_and_eval"
current_step = estimator._load_global_step_from_checkpoint_dir(FLAGS.model_dir)  # pylint: disable=protected-access,line-too-long
steps_per_epoch = FLAGS.num_train_images // FLAGS.train_batch_size
tf.logging.info("Training for %d steps (%.2f epochs in total). Current"
" step %d.",
FLAGS.train_steps,
FLAGS.train_steps / steps_per_epoch,
current_step)
start_timestamp = time.time()  # This time will include compilation time
if FLAGS.mode == "train":
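# Pure training mode: a single train() call up to FLAGS.train_steps,
# optionally with async checkpointing and TPU profiling hooks attached.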
hooks = []
if FLAGS.use_async_checkpointing:
hooks.append(
async_checkpoint.AsyncCheckpointSaverHook(
checkpoint_dir=FLAGS.model_dir,
save_steps=max(100, FLAGS.iterations_per_loop)))
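# Periodically capture a TPU profile and write it under the model directory.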
if FLAGS.profile_every_n_steps > 0:
hooks.append(
tpu_profiler_hook.TPUProfilerHook(
save_steps=FLAGS.profile_every_n_steps,
output_dir=FLAGS.model_dir, tpu=FLAGS.tpu)
)
resnet_classifier.train(
input_fn=imagenet_train.input_fn,
max_steps=FLAGS.train_steps,
hooks=hooks)
else:
assert FLAGS.mode == "train_and_eval"
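# Alternate between training for FLAGS.steps_per_eval steps and running a
# full evaluation pass, until FLAGS.train_steps is reached.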
while current_step < FLAGS.train_steps:
# Train for up to steps_per_eval number of steps.
# At the end of training, a checkpoint will be written to --model_dir.
next_checkpoint = min(current_step + FLAGS.steps_per_eval,
FLAGS.train_steps)
resnet_classifier.train(
input_fn=imagenet_train.input_fn, max_steps=next_checkpoint)
current_step = next_checkpoint
tf.logging.info("Finished training up to step %d. Elapsed seconds %d.",
next_checkpoint, int(time.time() - start_timestamp))
# Evaluate the model on the most recent model in --model_dir.
# Since evaluation happens in batches of --eval_batch_size, some images
# may be excluded modulo the batch size. As long as the batch size is
# consistent, the evaluated images are also consistent.
tf.logging.info("Starting to evaluate.")
eval_results = resnet_classifier.evaluate(
input_fn=imagenet_eval.input_fn,
steps=FLAGS.num_eval_images // FLAGS.eval_batch_size)
tf.logging.info("Eval results at step %d: %s",
next_checkpoint, eval_results)
elapsed_time = int(time.time() - start_timestamp)
tf.logging.info("Finished training up to step %d. Elapsed seconds %d.",
FLAGS.train_steps, elapsed_time)
if FLAGS.export_dir is not None:
# The guide to serving an exported TensorFlow model is at:
# https://www.tensorflow.org/serving/serving_basic
tf.logging.info("Starting to export model.")
export_path = resnet_classifier.export_saved_model(
export_dir_base=FLAGS.export_dir,
serving_input_receiver_fn=imagenet_input.image_serving_input_fn)
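# Optionally write warm-up requests alongside the exported model so a
# serving binary can pre-warm it before taking traffic.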
if FLAGS.add_warmup_requests:
inference_warmup.write_warmup_requests(
export_path,
FLAGS.model_name,