log_device_placement=True))
# Report progress around the construction of this worker's agent.
print("Created server")
worker_agent = DistributedAgent(self.agent_config, scope, self.task_index, cluster)
print("Created agent")

# Global variables are partitioned by name: everything NOT prefixed with
# "local" goes into variables_to_save and gets its own initializer op.
variables_to_save = [
    var for var in tf.global_variables()
    if not var.name.startswith("local")
]
init_op = tf.variables_initializer(variables_to_save)

# The local initializer covers tf.local_variables() plus those global
# variables whose names DO start with "local".
local_collection = tf.local_variables()
local_prefixed_globals = [
    var for var in tf.global_variables()
    if var.name.startswith("local")
]
local_init_op = tf.variables_initializer(local_collection + local_prefixed_globals)
After Change
config=tf.ConfigProto(intra_op_parallelism_threads=1,
inter_op_parallelism_threads=2,
log_device_placement=True))
# Log progress around the construction of this worker's agent.
self.logger.debug("Created server")
worker_agent = DistributedAgent(self.agent_config, scope, self.task_index, cluster)
self.logger.debug("Created agent")

# Global variables are partitioned by name: everything NOT prefixed with
# "local" goes into variables_to_save and gets its own initializer op.
variables_to_save = [
    var for var in tf.global_variables()
    if not var.name.startswith("local")
]
init_op = tf.variables_initializer(variables_to_save)

# The local initializer covers tf.local_variables() plus those global
# variables whose names DO start with "local".
local_collection = tf.local_variables()
local_prefixed_globals = [
    var for var in tf.global_variables()
    if var.name.startswith("local")
]
local_init_op = tf.variables_initializer(local_collection + local_prefixed_globals)

# Separate op that initializes every global variable, regardless of prefix.
init_all_op = tf.global_variables_initializer()
def init_fn(session):
    """Run the all-global-variables initializer op in ``session``.

    Args:
        session: Object exposing ``run``; it is used to execute
            ``init_all_op`` from the enclosing scope.
    """
    # Disabled alternative, kept for reference:
    # sess.run(worker_agent.model.init_op)
    session.run(init_all_op)
# Session config: the device filters name the parameter-server job and this
# worker's own CPU device (selected by task index).
own_cpu_device = "/job:worker/task:{}/cpu:0".format(self.task_index)
config = tf.ConfigProto(device_filters=["/job:ps", own_cpu_device])

# The worker with task index 0 is flagged as chief.
chief = self.task_index == 0
# The ready op only looks at the variables collected in variables_to_save.
readiness_check = tf.report_uninitialized_variables(variables_to_save)

supervisor = tf.train.Supervisor(
    is_chief=chief,
    logdir="/tmp/train_logs",
    global_step=worker_agent.model.global_step,
    init_op=init_op,
    local_init_op=local_init_op,
    init_fn=init_fn,
    ready_op=readiness_check,
    saver=worker_agent.model.saver
)
# Disabled Supervisor arguments, kept for reference:
# summary_op=tf.summary.merge_all(),
# summary_writer=worker_agent.model.summary_writer)
# The runner receives a deep copy of self.environment, not the instance
# held on self.
environment_copy = deepcopy(self.environment)
runner = ThreadRunner(
    worker_agent,
    environment_copy,
    self.max_episode_steps,
    self.local_steps,
    preprocessor=self.preprocessor,
    repeat_actions=self.repeat_actions
)
# Connecting to parameter server
self.logger.debug("Connecting to session..")
# Values are handed to the logger as arguments (not pre-formatted) so the
# message string is only built when the record is actually emitted.
# NOTE(review): assumes self.logger follows the standard logging.Logger
# call convention (msg, *args) — confirm.
self.logger.info("Server target = %s", server.target)
with supervisor.managed_session(server.target, config=config) as session, session.as_default():
    self.logger.info("Established session, starting runner..")
    # Run the model's assign_global_to_local op once before the runner
    # thread is started.
    session.run(worker_agent.model.assign_global_to_local)
    runner.start_thread(session)
    self.logger.debug("Runner started")
    global_step_count = worker_agent.increment_global_step()
    self.logger.debug("Got global step count")

    # Keep updating until the supervisor requests a stop or the global
    # step count reaches self.global_steps.
    while not supervisor.should_stop() and global_step_count < self.global_steps:
        runner.update()
        global_step_count = worker_agent.increment_global_step()
        self.logger.debug("Global step count: %s", global_step_count)

# NOTE(review): the original indentation is not recoverable from this view;
# these two statements are placed after the `with` block — confirm.
self.logger.info("Stopping supervisor")
supervisor.stop()