with tf.device("/cpu:0"):
batch_size = ground_truth.get_shape()[0].value
ground_truth = tf.reshape(ground_truth, [batch_size, -1])
if weight_map is not None:
weight_map = tf.reshape(weight_map, [batch_size, -1])
# assumes the same ground truth and weight map are shared across scales
# prediction should be a list for multi-scale losses;
# a single-scale ``prediction`` is converted to ``[prediction]``
if not isinstance(prediction, (list, tuple)):
prediction = [prediction]
data_loss = []
for ind, pred in enumerate(prediction):
# go through each scale
loss_batch = []
for b_ind, pred_b in enumerate(tf.unstack(pred, axis=0)):
# go through each image in the batch
pred_b = tf.reshape(pred_b, [-1, self._num_classes])
if self._softmax:
pred_b = tf.nn.softmax(
tf.cast(pred_b, dtype=tf.float32))
ground_truth_b = ground_truth[b_ind]
weight_b = None if weight_map is None else weight_map[b_ind]
loss_params = {
"prediction": pred_b,
"ground_truth": ground_truth_b,