logging_output = self.task.aggregate_logging_outputs(
logging_outputs, self.get_criterion()
)
sample_size = self.task.grad_denom(sample_sizes, self.get_criterion())
if not all(k in logging_output for k in ["ntokens", "nsentences"]):
raise Exception(
(
After Change
logging_output = self.task.aggregate_logging_outputs(
logging_outputs, self.get_criterion()
)
sample_size = sum(sample_sizes)
if not all(k in logging_output for k in ["ntokens", "nsentences"]):
raise Exception(
(