# Data-dependent initialization: pull exactly init_batch_size examples,
# run the init op on them, then rewind the iterator so the upcoming epoch
# still sees the full dataset from the start.
x = train_data.next(args.init_batch_size)
train_data.reset()
print("initializing the model...")
sess.run(initializer, {x_init: prepro(x)})

# Optionally restore previously saved parameters before training.
if args.load_params:
    ckpt_file = args.save_dir + "/params_" + args.data_set + ".ckpt"
    print("restoring parameters from", ckpt_file)

# train for one epoch
train_losses = []
for d in train_data:
    feed_dict = make_feed_dict(d)
    # forward/backward/update model on each gpu
    # NOTE(review): lr decays once per update step (not per epoch) — this
    # matches the statement order in the original fragment; confirm intended.
    lr *= args.lr_decay
    feed_dict.update({tf_lr: lr})
    l, _ = sess.run([bits_per_dim, optimizer], feed_dict)