Before Change

# Main optimization loop.
for i in range(steps):
    # Zero the in-graph gradient accumulators before each meta-batch.
    self._session.run(self._clear_gradients)
    for j in range(self.meta_batch_size):
        self.learner.select_task()
        inputs = self.learner.get_batch()
        feed_dict = {self._global_step: i}
        # The same batch feeds both the inner and the meta placeholders.
        for k in range(len(inputs)):
            feed_dict[self._input_placeholders[k]] = inputs[k]
            feed_dict[self._meta_placeholders[k]] = inputs[k]
        # Accumulate this task's gradients into the in-graph buffers.
        self._session.run(self._add_gradients, feed_dict=feed_dict)
    # Apply the accumulated meta-gradient once per outer step.
    self._session.run(self._meta_train_op)
    # Do checkpointing.
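The _clear_gradients, _add_gradients, and _meta_train_op ops used above are built elsewhere and not shown in this excerpt. A minimal sketch, assuming TF1 graph mode, of how such in-graph gradient accumulation can be wired up; the toy quadratic loss and all names below are illustrative, not the original implementation:

import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

# Toy model: a single weight vector with a quadratic loss.
w = tf.get_variable('w', shape=[3], initializer=tf.zeros_initializer())
loss = tf.reduce_sum(tf.square(w - 1.0))

variables = [w]
grads = tf.gradients(loss, variables)

# One non-trainable accumulator per trainable variable.
accumulators = [
    tf.Variable(tf.zeros_like(v), trainable=False) for v in variables]
# Op that zeroes the accumulators (analogous to _clear_gradients).
clear_gradients = tf.group(
    *[a.assign(tf.zeros_like(a)) for a in accumulators])
# Op that adds the current batch's gradients (analogous to _add_gradients).
add_gradients = tf.group(
    *[a.assign_add(g) for a, g in zip(accumulators, grads)])
# Op that applies the accumulated gradients (analogous to _meta_train_op).
meta_train_op = tf.train.GradientDescentOptimizer(0.1).apply_gradients(
    zip(accumulators, variables))

Running clear_gradients, then add_gradients once per task, then meta_train_op reproduces the accumulate-and-apply pattern of the loop above.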
After Change
# Main optimization loop.
learner = self.learner
variables = learner.variables
for i in range(steps):
    for j in range(self.meta_batch_size):
        learner.select_task()
        # Compute the meta-loss and meta-gradients on a fresh train batch
        # and a fresh meta batch for the selected task.
        meta_loss, meta_gradients = self._compute_meta_loss(
            learner.get_batch(), learner.get_batch(), variables)
        # Sum the meta-gradients across the tasks in the meta-batch.
        if j == 0:
            summed_gradients = meta_gradients
        else:
            summed_gradients = [
                s + g for s, g in zip(summed_gradients, meta_gradients)
            ]
    self._tf_optimizer.apply_gradients(zip(summed_gradients, variables))
    # Do checkpointing.
    if i == steps - 1 or time.time() >= checkpoint_time + checkpoint_interval:
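_compute_meta_loss is also not shown in this excerpt. A minimal sketch of what a MAML-style version could look like in eager TensorFlow 2; loss_fn, inner_lr, and the (inputs, labels) batch format are assumptions, not names from the original code:

import tensorflow as tf

def compute_meta_loss(train_batch, meta_batch, variables,
                      loss_fn, inner_lr=0.01):
    """One inner gradient step on the train batch, then the meta-gradient
    of the adapted loss on the meta batch (second-order, via nested tapes)."""
    train_inputs, train_labels = train_batch
    meta_inputs, meta_labels = meta_batch
    with tf.GradientTape() as outer_tape:
        with tf.GradientTape() as inner_tape:
            inner_loss = loss_fn(train_inputs, train_labels, variables)
        inner_grads = inner_tape.gradient(inner_loss, variables)
        # Inner adaptation: a single SGD step on the task's train batch.
        adapted = [v - inner_lr * g for v, g in zip(variables, inner_grads)]
        # Meta-loss: evaluate the adapted parameters on the held-out batch.
        meta_loss = loss_fn(meta_inputs, meta_labels, adapted)
    # Differentiate through the adaptation step back to the original variables.
    meta_gradients = outer_tape.gradient(meta_loss, variables)
    return meta_loss, meta_gradients

Calling learner.get_batch() twice in the loop above suggests the train and meta batches are drawn independently, which matches the usual MAML split between adaptation data and meta-evaluation data.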