# Parse the fetched records into input tensors for the model function.
dataset = dataset.map(_dataset_parser, num_parallel_calls=64)
# Keep a buffer of `batch_size` parsed elements ready ahead of batching.
dataset = dataset.prefetch(batch_size)
# Drop a trailing partial batch so every batch has exactly `batch_size`
# elements, giving the pipeline a static batch dimension (presumably
# required by the TPU input path -- confirm if batch size can vary).
dataset = dataset.batch(batch_size, drop_remainder=True)
def _process_example(images, score_targets, box_targets, source_ids,
                     image_info, boxes, is_crowds, areas, classes,
                     cropped_gt_masks):
    """Processes one batch of data into the structure the model expects.

    Args:
      images: Batched image tensor. In TRAIN mode, when
        `params["transpose_input"]` is set, it is transposed with
        permutation [1, 2, 3, 0] (batch dimension moved last).
      score_targets: Indexable by level for every level in
        [params["min_level"], params["max_level"]].
      box_targets: Indexed by level the same way as `score_targets`.
      source_ids: Passed through unchanged.
      image_info: Passed through unchanged.
      boxes: Groundtruth tensor, concatenated along axis 2.
      is_crowds: Groundtruth tensor, concatenated along axis 2.
      areas: Groundtruth tensor, concatenated along axis 2.
      classes: Groundtruth tensor, concatenated along axis 2.
        (The four groundtruth tensors are presumably shaped
        (batch, instances, k) -- TODO confirm.)
      cropped_gt_masks: Passed through into `labels`; unused in PREDICT mode.

    Returns:
      In PREDICT mode, a features dict with keys "images", "image_info",
      "groundtruth_data" and "source_ids". Otherwise an `(images, labels)`
      tuple. `labels` holds "score_targets_<level>" and "box_targets_<level>"
      for each level, plus "source_ids", "groundtruth_data", "image_info"
      and "cropped_gt_masks".
    """
    # Transpose images for TPU performance. Given the batch size, the batch
    # dimension (N) goes to either the minor dimension ((H, W, C, N) when
    # N > C) or the second-minor ((H, W, N, C) when N < C). Here we assume
    # N is 4 or 8 and C is 3, so we use (H, W, C, N).
    if (params["transpose_input"] and
            self._mode == tf.estimator.ModeKeys.TRAIN):
        images = tf.transpose(images, [1, 2, 3, 0])

    # Concatenate the groundtruth annotations into a single tensor; both the
    # prediction features and the training labels carry it.
    groundtruth_data = tf.concat([boxes, is_crowds, areas, classes], axis=2)

    if self._mode == tf.estimator.ModeKeys.PREDICT:
        # Prediction needs no per-level targets, so `labels` is not built.
        return dict(
            images=images,
            image_info=image_info,
            groundtruth_data=groundtruth_data,
            source_ids=source_ids)

    labels = {}
    for level in range(params["min_level"], params["max_level"] + 1):
        labels["score_targets_%d" % level] = score_targets[level]
        labels["box_targets_%d" % level] = box_targets[level]
    labels["source_ids"] = source_ids
    labels["groundtruth_data"] = groundtruth_data
    labels["image_info"] = image_info
    labels["cropped_gt_masks"] = cropped_gt_masks
    return images, labels
# Convert each batch into the (features, labels) structure for the model.
dataset = dataset.map(_process_example)
# Optionally cap the pipeline length. This is applied before `prefetch` so
# the prefetch buffer never produces elements that `take` would discard.
# NOTE(review): `dataset` was already batched earlier in this pipeline, so
# this limits the number of *batches*, not individual examples -- confirm
# that matches the intended meaning of `_num_examples`.
if self._num_examples > 0:
    dataset = dataset.take(self._num_examples)
# Prefetch last so input production overlaps with model execution.
dataset = dataset.prefetch(tf.contrib.data.AUTOTUNE)