metrics = self.validate_batch(batch, batch_info)
if "loss" in metrics:
losses.update(
metrics["loss"], n=metrics.get("num_samples", 1))
if "num_correct" in metrics:
total_correct += metrics["num_correct"]
After Change
metrics = self.validate_batch(batch, batch_info)
metric_meters.update(metrics, n=metrics.pop(NUM_SAMPLES, 1))
return metric_meters.summary()
def validate_batch(self, batch, batch_info):
Calculates the loss and accuracy over a given batch.