31bf9a3cbbf0739beb745389ee64f6b9200f6d53,pytext/trainers/trainer.py,Trainer,run_step,#Trainer#Any#Any#Any#Any#,415
Before Change
model = state.model
self.zero_grads(state)
for i, (batch_id, (inputs, targets, context)) in enumerate(samples):
if cuda.DISTRIBUTED_WORLD_SIZE > 1:
// Whenever *samples* contains more than one mini-batch, we
// want to accumulate gradients locally and only call
// all-reduce in the last backwards pass.
if i < sample_size - 1:
// sync gradients in the last sample backward
model.accumulate_gradients(True)
else:
model.accumulate_gradients(False)
// pass context to model to use in forward call if needed
model.contextualize(context)
with timing.time("model.forward"):
logits = model(*inputs)
After Change
model = state.model
self.zero_grads(state)
for idx, (batch_id, (inputs, targets, context)) in enumerate(samples):
with maybe_no_sync(model, idx, sample_size):
// pass context to model to use in forward call if needed
model.contextualize(context)
with timing.time("model.forward"):
logits = model(*inputs)
with timing.time("compute loss"):
loss = precision.maybe_float(
model.get_loss(logits, targets, context)
)
if BatchContext.IGNORE_LOSS in context:
loss *= 0
elif sample_size > 1:
// gradients averaged per batch and accumulated across samples.
// divide sample_size to let gradients averaged per example
loss = loss / sample_size
self.backprop(state, loss)
if report_metric:
with timing.time("get pred"):
preds, scores = model.get_pred(
logits, targets, context, state.stage, *inputs
In pattern: SUPERPATTERN
Frequency: 4
Non-data size: 6
Instances
Project Name: facebookresearch/pytext
Commit Name: 31bf9a3cbbf0739beb745389ee64f6b9200f6d53
Time: 2019-07-10
Author: chenyangyu@instagram.com
File Name: pytext/trainers/trainer.py
Class Name: Trainer
Method Name: run_step
Project Name: pytorch/fairseq
Commit Name: b625d53d029afbd52a508bfcf8f7cd6d45aa0f79
Time: 2019-06-20
Author: myleott@fb.com
File Name: fairseq/trainer.py
Class Name: Trainer
Method Name: train_step
Project Name: facebookresearch/pytext
Commit Name: 31bf9a3cbbf0739beb745389ee64f6b9200f6d53
Time: 2019-07-10
Author: chenyangyu@instagram.com
File Name: pytext/trainers/trainer.py
Class Name: TaskTrainer
Method Name: run_step