# Open a session against the TPU master, then initialise the TPU system and
# the graph's global variables before training starts.
sess = tf.Session(tpu_grpc_url)
sess.run(tf.contrib.tpu.initialize_system())
sess.run(tf.global_variables_initializer())

# Run the training op once per step, for args.max_steps steps.
for i in range(args.max_steps):
    # The tensor values in the TPU function are returned in a list; the
    # operations in the TPU function are called with no return value.
    global_step, loss = sess.run(train_on_tpu)
# After Change
save_steps=args.save_checkpoints_steps,
)
# Scalar summary op for the loss tensor, tagged "loss".
summary = tf.summary.scalar(name="loss", tensor=loss_tensor)

# Hook that saves the summary op above into args.model_dir every
# args.save_checkpoints_steps steps.
summary_saver_hook = tf.train.SummarySaverHook(
    summary_op=summary,
    output_dir=args.model_dir,
    save_steps=args.save_checkpoints_steps,
)
# loss_summary = tf.summary.scalar("loss", loss_tensor)
# summary = tf.summary.merge_all()
# Get the TPU resource's gRPC URL.
# Note: when running on CMLE, args.tpu should be left as None.
tpu_grpc_url = TPUClusterResolver(tpu=args.tpu).get_master()

# Use a MonitoredSession rather than a bare tf.Session(tpu_grpc_url) so that
# the checkpoint and summary saver hooks are attached to the session.
sess = tf.train.MonitoredSession(
    session_creator=tf.train.ChiefSessionCreator(master=tpu_grpc_url),
    hooks=[checkpoint_saver_hook, summary_saver_hook],
)