Before Change

    if epoch_batches is None:
      batcher.reset()

    return np.mean(train_loss), global_step

  def train_epoch_tfr(self, sess, sum_writer=None, epoch_batches=None):
    """Execute one training epoch, using TFRecords data."""
After Change
    # initialize training loss
    train_loss = []
    batch_sizes = []
    global_step = 0

    # setup feed dict
    fd = self.set_mode("train")

    # get first batch
    Xb, Yb, NAb, Nb = batcher.next()

    batch_num = 0
    while Xb is not None and (epoch_batches is None or batch_num < epoch_batches):
      # update feed dict with this batch
      fd[self.inputs_ph] = Xb
      fd[self.targets_ph] = Yb

      if no_steps:  # no_steps is set by the enclosing method's arguments
        run_returns = sess.run([self.merged_summary, self.loss_train] + \
                               self.update_ops, feed_dict=fd)
        summary, loss_batch = run_returns[:2]
      else:
        run_ops = [self.merged_summary, self.loss_train, self.global_step, self.step_op]
        run_ops += self.update_ops
        summary, loss_batch, global_step = sess.run(run_ops, feed_dict=fd)[:3]

      # add summary
      if sum_writer is not None:
        sum_writer.add_summary(summary, global_step)

      # accumulate loss
      train_loss.append(loss_batch)
      batch_sizes.append(Nb)

      # next batch
      Xb, Yb, NAb, Nb = batcher.next()
      batch_num += 1

    # reset training batcher if epoch considered all of the data
    if epoch_batches is None:
      batcher.reset()

    # weight each batch's loss by its size so a short final batch
    # doesn't skew the epoch average
    avg_loss = np.average(train_loss, weights=batch_sizes)

    return avg_loss, global_step
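The substantive change is the return value: the old code averaged per-batch losses with np.mean, which gives a short remainder batch the same weight as a full one, while np.average weighted by batch size is equivalent to total loss over total examples. A minimal sketch of the difference, using only numpy (the loss values and batch sizes below are made up for illustration):

    import numpy as np

    # Per-batch mean losses and the number of examples in each batch;
    # the last batch is a small remainder.
    train_loss = [0.90, 0.80, 0.70, 2.00]
    batch_sizes = [64, 64, 64, 4]

    # Unweighted: the 4-example batch counts as much as a 64-example batch.
    print(np.mean(train_loss))                          # 1.1

    # Weighted: equivalent to total loss over total examples.
    print(np.average(train_loss, weights=batch_sizes))  # ~0.824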
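For context, the reason self.update_ops is appended to the fetches in sess.run is the TF1 graph-mode convention that batch-norm moving-average updates only execute when explicitly fetched alongside the training step. A minimal sketch of that pattern, assuming TensorFlow 1.x; the layer sizes and data here are illustrative, not taken from the original model:

    import numpy as np
    import tensorflow as tf  # assumes TensorFlow 1.x graph mode

    x_ph = tf.placeholder(tf.float32, [None, 4])
    y_ph = tf.placeholder(tf.float32, [None, 1])

    h = tf.layers.dense(x_ph, 8)
    h = tf.layers.batch_normalization(h, training=True)
    pred = tf.layers.dense(h, 1)
    loss = tf.losses.mean_squared_error(y_ph, pred)

    step_op = tf.train.AdamOptimizer(1e-3).minimize(loss)
    # Batch-norm moving mean/variance updates live in this collection
    # and run only if fetched together with the training step.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)

    with tf.Session() as sess:
      sess.run(tf.global_variables_initializer())
      fd = {x_ph: np.random.rand(16, 4), y_ph: np.random.rand(16, 1)}
      loss_batch = sess.run([loss, step_op] + update_ops, feed_dict=fd)[0]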