def _replace_target_params(self):
    """Copy the evaluation network's parameters into the target network.

    For every layer in ``self._eval_net_params`` the entries at index 0 and 1
    are assigned to the same positions of ``self._target_net_params``
    (presumably weights and bias -- confirm where these lists are built).

    The assign ops are built once and reused on later calls. ``tf.assign``
    adds a new node to the graph each time it is called, so rebuilding the
    ops on every call would make the graph grow with every target update.
    """
    replace_ops = getattr(self, "_replace_target_ops", None)
    # Rebuild while the cache is missing or empty so that a call made before
    # any parameters exist does not pin an empty op list forever.
    if not replace_ops:
        replace_ops = []
        for layer, params in enumerate(self._eval_net_params):
            target_params = self._target_net_params[layer]
            for W_b in range(2):
                replace_ops.append(tf.assign(target_params[W_b], params[W_b]))
        self._replace_target_ops = replace_ops
    self.sess.run(replace_ops)
After Change
return action
def _replace_target_params(self):
    """Copy the evaluation network's variables into the target network.

    Variables are read from the ``"target_net_params"`` and
    ``"eval_net_params"`` graph collections and paired positionally, so both
    collections are assumed to list matching variables in the same order
    (confirm where the collections are populated).

    The assign ops are built once and reused on later calls. ``tf.assign``
    adds a new node to the graph each time it is called, so rebuilding the
    ops on every call would make the graph grow with every target update.
    """
    assign_ops = getattr(self, "_replace_target_ops", None)
    # Rebuild while the cache is missing or empty so that a call made before
    # the collections are populated does not pin an empty op list forever.
    if not assign_ops:
        t_params = tf.get_collection("target_net_params")
        e_params = tf.get_collection("eval_net_params")
        assign_ops = [tf.assign(t, e) for t, e in zip(t_params, e_params)]
        self._replace_target_ops = assign_ops
    self.sess.run(assign_ops)
def learn(self):
# check whether the target parameters need to be replaced