val_loss_mean = torch.stack([output["val_loss"]
for output in outputs]).mean()
val_acc_mean = torch.stack([output["num_correct"]
for output in outputs]).sum().float()
val_acc_mean /= (len(outputs) * self.batch_size)
return {"log": {"val_loss": val_loss_mean,
"val_acc": val_acc_mean,
After Change
Compute and log validation loss and accuracy at the epoch level.
# Epoch-level validation metrics: average the per-batch "val_loss" tensors and
# read the accuracy from the `valid_acc` metric object.
val_loss_mean = torch.stack([output["val_loss"] for output in outputs]).mean()
val_acc_mean = self.valid_acc.compute()
log_dict = {"val_loss": val_loss_mean, "val_acc": val_acc_mean}
self.log_dict(log_dict, prog_bar=True)
# NOTE(review): presumably logs "step" as the epoch index so the logger plots
# these metrics per epoch rather than per global step — confirm.
self.log_dict({"step": self.current_epoch})
def configure_optimizers(self):