for train_metric in train_metrics:
  train_metric.tf_summaries(step_metrics=train_metrics[:2])
summary_op = tf.contrib.summary.all_summary_ops()
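
# Build eval-metric summary ops under the eval summary writer so they are
# written to the eval log directory.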
with eval_summary_writer.as_default(), \
     tf.compat.v2.summary.record_if(True):
  for eval_metric in eval_metrics:
    eval_metric.tf_summaries()
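
# In graph mode tf_agent.initialize() returns an op; it is run later in the
# session via sess.run(init_agent_op).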
init_agent_op = tf_agent.initialize()
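
# Everything below runs under a tf.compat.v1.Session: restore checkpoints,
# initialize variables, collect initial experience, and set up the callables
# used by the training loop.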
with tf.compat.v1.Session() as sess:
  tf.contrib.summary.initialize(graph=tf.compat.v1.get_default_graph())

  # Initialize the graph.
  train_checkpointer.initialize_or_restore(sess)
  rb_checkpointer.initialize_or_restore(sess)
  sess.run(iterator.initializer)
  # TODO(sguada) Remove once Periodically can be saved.
  common_utils.initialize_uninitialized_variables(sess)

  sess.run(init_agent_op)
  tf.contrib.summary.initialize(session=sess)
logging.info("Collecting initial experience.")
sess.run(initial_collect_op)
// Compute evaluation metrics.
global_step_val = sess.run(global_step)
metric_utils.compute_summaries(
eval_metrics,
eval_py_env,
eval_py_policy,
num_episodes=num_eval_episodes,
global_step=global_step_val,
callback=eval_metrics_callback,
log=True,
)
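
  # make_callable pre-builds the run step for a fixed set of fetches, which
  # avoids the per-call setup overhead of sess.run in the training loop.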
  collect_call = sess.make_callable(collect_op)
  train_step_call = sess.make_callable([loss_info, summary_op, global_step])

  timed_at_step = sess.run(global_step)
  time_acc = 0
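
The training loop itself is not included in this snippet; the callables and
timing variables set up above are typically consumed by a loop along the
following lines. This is only a sketch: num_iterations,
train_steps_per_iteration, and log_interval are assumed names, not taken
from the code above.

import time

for _ in range(num_iterations):
  start_time = time.time()
  collect_call()
  for _ in range(train_steps_per_iteration):
    loss_info_value, _, global_step_val = train_step_call()
  time_acc += time.time() - start_time

  if global_step_val % log_interval == 0:
    steps_per_sec = (global_step_val - timed_at_step) / time_acc
    logging.info("%.3f steps/sec", steps_per_sec)
    timed_at_step = global_step_val
    time_acc = 0
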
After Change
init_agent_op = tf_agent.initialize()
with tf.compat.v1.Session() as sess:
  sess.run(train_summary_writer.init())
  sess.run(eval_summary_writer.init())

  # Initialize the graph.
  train_checkpointer.initialize_or_restore(sess)
  rb_checkpointer.initialize_or_restore(sess)
  sess.run(iterator.initializer)
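
The train_summary_writer and eval_summary_writer whose init() ops are run
above are created earlier, outside this snippet. A minimal sketch of that
setup, assuming tf.compat.v2.summary.create_file_writer and hypothetical log
directories:

import tensorflow as tf

# Hypothetical log directories; the real paths come from the surrounding code.
train_summary_writer = tf.compat.v2.summary.create_file_writer(
    '/tmp/train', flush_millis=10000)
eval_summary_writer = tf.compat.v2.summary.create_file_writer(
    '/tmp/eval', flush_millis=10000)

with tf.compat.v1.Session() as sess:
  # In graph mode each v2 summary writer exposes an init() op that has to be
  # run explicitly, which is what the code above does.
  sess.run(train_summary_writer.init())
  sess.run(eval_summary_writer.init())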