metrics = self.validate_batch(batch, batch_info)
if"loss" in metrics:
losses.update(
metrics["loss"], n=metrics.get("num_samples", 1))
if"num_correct" in metrics:
total_correct += metrics["num_correct"]
After Change
metrics = self.validate_batch(batch, batch_info)
metric_meters.update(metrics, n=metrics.pop(NUM_SAMPLES, 1))
returnmetric_meters.summary()def validate_batch(self, batch, batch_info):
Calculates the loss and accuracy over a given batch.