# Walk the batches in order; batch_id is the zero-based position that
# enumerate() assigns to each batch.
for batch_id, batch in enumerate(batches):
    # Hook invoked before each batch is handed to the model.
    pre_batch()
    # train_batch() must return exactly two items: the loss and an iterable
    # of per-batch metric values.
    loss, metric_data = model.train_batch(batch)
    backprop(loss, timer)
    # Kept inside the loop: the stats are recorded per batch, using this
    # iteration's batch_id and batch.
    if report_metric:
        # metric_data is spread as positional arguments, and the mapping
        # returned by batch_context(batch) is spread as keyword arguments.
        metric_reporter.add_batch_stats(
            batch_id, *metric_data, **metric_reporter.batch_context(batch)
        )
After Change
pre_batch()
with time_utils.time("model.train_batch"):
loss, metric_data = model.train_batch(batch)
with time_utils.time("backprop"):
backprop(loss)
if report_metric:
with time_utils.time("add metrics"):
metric_reporter.add_batch_stats(
batch_id, *metric_data, **metric_reporter.batch_context(batch)