a6d07af248a7594b8dfedbf8368ddac3f901f3ec,niftynet/layer/loss_segmentation.py,,dice_plus_xent_loss,#Any#Any#Any#,252
Before Change
:return: the loss (cross_entropy + Dice)
if weight_map is not None:
raise NotImplementedError
prediction = tf.cast(prediction, tf.float32)
loss_xent = cross_entropy(prediction, ground_truth)
softmax_of_logits = tf.nn.softmax(prediction, axis=-1)
After Change
one_hot = labels_to_one_hot(ground_truth, num_classes=num_classes)
softmax_of_logits = tf.nn.softmax(prediction)
if weight_map is not None:
weight_map_nclasses = tf.tile(tf.expand_dims(tf.reshape(weight_map, [-1]), 1), [1, num_classes])
dice_numerator = 2.0 * tf.sparse_reduce_sum(weight_map_nclasses * one_hot * softmax_of_logits,
reduction_axes=[0])
dice_denominator = tf.reduce_sum(weight_map_nclasses * softmax_of_logits,
reduction_indices=[0]) + \
tf.sparse_reduce_sum(one_hot * weight_map_nclasses, reduction_axes=[0])
else:
dice_numerator = 2.0 * tf.sparse_reduce_sum(one_hot * softmax_of_logits,reduction_axes=[0])
dice_denominator = tf.reduce_sum(softmax_of_logits, reduction_indices=[0]) + \
tf.sparse_reduce_sum(one_hot, reduction_axes=[0])
// dice_numerator = -2.0 * tf.sparse_reduce_sum(one_hot * softmax_of_logits, reduction_axes=[0])
// dice_denominator = tf.reduce_sum(softmax_of_logits, reduction_indices=[0]) + \
// tf.sparse_reduce_sum(one_hot, reduction_axes=[0])
epsilon = 0.00001
loss_dice = -(dice_numerator + epsilon) / (dice_denominator + epsilon)
dice_numerator = tf.Print(dice_denominator, [dice_numerator, dice_denominator, loss_dice])
In pattern: SUPERPATTERN
Frequency: 3
Non-data size: 5
Instances Project Name: NifTK/NiftyNet
Commit Name: a6d07af248a7594b8dfedbf8368ddac3f901f3ec
Time: 2018-12-13
Author: z.eaton-rosen@ucl.ac.uk
File Name: niftynet/layer/loss_segmentation.py
Class Name:
Method Name: dice_plus_xent_loss
Project Name: NifTK/NiftyNet
Commit Name: 29d9f7d43b66da4c25686134ff0366f72934a728
Time: 2018-12-13
Author: z.eaton-rosen@ucl.ac.uk
File Name: niftynet/layer/loss_segmentation.py
Class Name:
Method Name: dice_plus_xent_loss
Project Name: scikit-learn-contrib/DESlib
Commit Name: 250e3b3a31f3f50c691878db75af3ed24de448de
Time: 2021-04-07
Author: rafaelmenelau@gmail.com
File Name: deslib/base.py
Class Name: BaseDS
Method Name: fit