ifisinstance(x, tf.SparseTensor):
// This means reduce_instance_dims=False.// TODO(b/112656428): Support SparseTensors with rank other than 2.if x.get_shape().ndims != 2:
raise NotImplementedError(
"Mean and var only support SparseTensors with rank 2") col_count, col_indices = x.dense_shape[1], x.indices[:, 1]
x_sum = tf.math.unsorted_segment_sum(x.values, col_indices, col_count)
x_mean = tf.where(tf.math.greater(x_count, 0),
x_sum / x_count,