# TODO: Doesn't make much sense over training data, averaging with initial
# (bad) predictions.
auc, update_auc_op = tf.metrics.auc(labels, logits, curve="PR")
tf.add_to_collection("metric_ops", update_auc_op)
tf.add_to_collection("metrics", auc)
tf.summary.scalar("auc", auc)
After Change
def metrics(logits, labels):
    # TODO: AUC doesn't make much sense over training data, averaging with
    # initial (bad) predictions.
    normalized_logits = tf.sigmoid(logits)
    # Add one AUC metric per class, so we can see individual performance too.
    for cls in range(logits.shape[1]):
        auc, _ = tf.metrics.auc(
            labels[:, cls], normalized_logits[:, cls],
            curve="PR", name=f"iauc/{cls}",
            metrics_collections="metrics",
            updates_collections="metric_ops",
        )
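The point of the metrics_collections/updates_collections arguments is that the training loop can later fetch every per-class AUC tensor and its update op from the graph collections without threading them through function signatures. A minimal consumption sketch, assuming TF 1.x graph mode with a hypothetical 3-class model; the placeholders and the random batch below are illustrative and not part of the original change:

import numpy as np
import tensorflow as tf

logits = tf.placeholder(tf.float32, [None, 3])  # hypothetical 3-class model output
labels = tf.placeholder(tf.float32, [None, 3])  # hypothetical multi-hot labels
metrics(logits, labels)  # populates the "metrics" and "metric_ops" collections

batch_logits = np.random.randn(8, 3).astype(np.float32)   # fake batch for illustration
batch_labels = (np.random.rand(8, 3) > 0.5).astype(np.float32)

with tf.Session() as sess:
    sess.run(tf.local_variables_initializer())  # tf.metrics.auc keeps its running totals in local variables
    # Each run of the update ops folds one batch into the running AUC estimate.
    sess.run(tf.get_collection("metric_ops"),
             feed_dict={logits: batch_logits, labels: batch_labels})
    print(sess.run(tf.get_collection("metrics")))  # one PR-AUC value per class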