# Flatten the ground truth (and the optional weight map) to
# (batch_size, -1), with the reshape ops placed on the CPU.
with tf.device("/cpu:0"):
    # Static batch size, read from the first dimension of ground_truth.
    # NOTE(review): `.value` is the TF1 `Dimension` API; it is None when the
    # batch dimension is not statically known, which would likely make the
    # reshapes below fail -- confirm the batch size is always static here.
    batch_size = ground_truth.get_shape()[0].value
    # Collapse every non-batch dimension into a single trailing axis.
    ground_truth = tf.reshape(ground_truth, [batch_size, -1])
    if weight_map is not None:
        # The weight map is optional; when given, flatten it the same way.
        weight_map = tf.reshape(weight_map, [batch_size, -1])
After Change
# Bring this sample's ground truth (and optional weight map) into the same
# spatial layout as pred_b: reshape to spatial_shape + [n_channels], where
# n_channels is inferred by tf.reshape from the -1. When pred_b is already
# flattened to (n_voxels, num_classes) this gives (n_voxels, n_channels).
# A single-channel result has its trailing axis squeezed away, leaving just
# spatial_shape (e.g. (n_voxels,)).
spatial_shape = pred_b.get_shape().as_list()[:-1]
ref_shape = spatial_shape + [-1]

def _to_ref_shape(tensor):
    """Reshape ``tensor`` to ``ref_shape`` and drop a trailing size-1 axis.

    NOTE(review): the squeeze is decided from the *static* shape; if the
    trailing dimension is unknown at graph-construction time (None), no
    squeeze happens even when its runtime size is 1 -- confirm callers
    always provide statically-shaped inputs.
    """
    reshaped = tf.reshape(tensor, ref_shape)
    if reshaped.get_shape().as_list()[-1] == 1:
        reshaped = tf.squeeze(reshaped, axis=-1)
    return reshaped

ground_truth_b = _to_ref_shape(ground_truth[b_ind])
# The weight map is optional; weight_b stays None when it is absent.
if weight_map is not None:
    weight_b = _to_ref_shape(weight_map[b_ind])
else:
    weight_b = None