if not self.is_training:
dataset = dataset.apply(batching.filter_irregular_batches(batch_size))
dataset = dataset.map(
lambda images, labels: (tf.transpose(images, [1, 2, 3, 0]), labels),
num_parallel_calls=8)
# For XLA, we must use fixed shapes. Because we repeat the source training
# dataset indefinitely, this is not a dangerous operation.
#
# When evaluating, prevent accidentally evaluating the same image twice by
# dropping the final batch if it is less than a full batch size. As long as
# this validation is done with consistent batch size, exactly the same
# images will be used.
def set_shapes(images, labels):
    """Fix the batch dimension of `images` and `labels` to `batch_size`.

    The current static shape of each tensor is merged with a target shape
    and the result is written back onto the tensor with `set_shape`.

    Args:
        images: Tensor merged with `[None, None, None, batch_size]`, i.e.
            `batch_size` sits in the last of four dimensions.
        labels: Tensor merged with `[batch_size]`.

    Returns:
        The same `(images, labels)` pair that was passed in.
    """
    # `batch_size` comes from the enclosing scope.
    targets = (
        (images, [None, None, None, batch_size]),
        (labels, [batch_size]),
    )
    for tensor, dims in targets:
        known_shape = tensor.get_shape()
        tensor.set_shape(known_shape.merge_with(tf.TensorShape(dims)))
    return images, labels
if self.is_training:
dataset = dataset.map(set_shapes)dataset = dataset.prefetch(32) // Prefetch overlaps in-feed with training
return dataset // Must return the dataset and not tensors for high perf!