tf.argmax(predictions,
axis=tf.rank(predictions) - 1))
# Mean of the per-example correctness flags, i.e. accuracy on one fed batch.
# tf.cast replaces the deprecated tf.to_float (same float32 conversion).
acc_value = tf.reduce_mean(tf.cast(correct_preds, tf.float32))
# Init result var: running sum of (batch size * batch accuracy); it is
# normalised by the number of examples after the loop.
accuracy = 0.0
with sess.as_default():
    if len(X_test) == 0:
        # Accuracy over zero examples is undefined. Fail early with a clear
        # message instead of a NameError on `end` after the (empty) loop.
        raise ValueError("X_test must contain at least one example")
    # Compute number of batches (the last one may be partial).
    nb_batches = int(math.ceil(float(len(X_test)) / args.batch_size))
    assert nb_batches * args.batch_size >= len(X_test)
    for batch in range(nb_batches):
        if batch % 100 == 0 and batch > 0:
            _logger.debug("Batch " + str(batch))
        # Must not use the `batch_indices` function here, because it
        # repeats some examples.
        # It's acceptable to repeat during training, but not eval.
        start = batch * args.batch_size
        end = min(len(X_test), start + args.batch_size)
        cur_batch_size = end - start
        # The last batch may be smaller than all others, so the batch-mean
        # accuracy is weighted by the actual number of examples in it.
        feed_dict = {x: X_test[start:end], y: Y_test[start:end]}
        if feed is not None:
            feed_dict.update(feed)
        cur_acc = acc_value.eval(feed_dict=feed_dict)
        accuracy += (cur_batch_size * cur_acc)
    # The final batch must have reached the end of X_test.
    assert end >= len(X_test)
    # Divide by number of examples to get final value
# After change
if len(X_test) == 0:
    # Accuracy over zero examples is undefined. Fail early with a clear
    # message instead of a NameError on `end` after the (empty) loop.
    raise ValueError("X_test must contain at least one example")
# Compute number of batches (the last one may be partial).
nb_batches = int(math.ceil(float(len(X_test)) / args.batch_size))
assert nb_batches * args.batch_size >= len(X_test)
# Fixed-size input buffers reused for every batch, so every evaluation is
# fed arrays of shape (batch_size, ...); a partial final batch is padded.
X_cur = np.zeros((args.batch_size,) + X_test.shape[1:],
                 dtype=X_test.dtype)
Y_cur = np.zeros((args.batch_size,) + Y_test.shape[1:],
                 dtype=Y_test.dtype)
for batch in range(nb_batches):
    if batch % 100 == 0 and batch > 0:
        _logger.debug("Batch " + str(batch))
    # Must not use the `batch_indices` function here, because it
    # repeats some examples.
    # It's acceptable to repeat during training, but not eval.
    start = batch * args.batch_size
    end = min(len(X_test), start + args.batch_size)
    cur_batch_size = end - start
    X_cur[:cur_batch_size] = X_test[start:end]
    Y_cur[:cur_batch_size] = Y_test[start:end]
    # When the last batch is smaller than the others, the rows of
    # X_cur/Y_cur past `cur_batch_size` still hold zeros or stale data
    # from the previous batch, so only the first `cur_batch_size`
    # predictions are counted below.
    feed_dict = {x: X_cur, y: Y_cur}
    if feed is not None:
        feed_dict.update(feed)
    cur_corr_preds = correct_preds.eval(feed_dict=feed_dict)
    # Count correct predictions among the real (non-padding) examples.
    accuracy += cur_corr_preds[:cur_batch_size].sum()
# The final batch must have reached the end of X_test.
assert end >= len(X_test)
# Divide by number of examples to get final value