72c304fa7cac16ed19d8bc75a017f17c8073dd2f,parlai/core/torch_generator_agent.py,TorchGeneratorAgent,compute_loss,#TorchGeneratorAgent#Any#Any#,613

Before Change


        // save loss to metrics
        notnull = batch.label_vec.ne(self.NULL_IDX)
        target_tokens = notnull.long().sum().item()
        correct = ((batch.label_vec == preds) * notnull).sum().item()
        self.metrics["correct_tokens"] += correct
        self.metrics["nll_loss"] += loss.item()
        self.metrics["num_tokens"] += target_tokens
        loss /= target_tokens  // average loss per token
        if return_output:
            return (loss, model_output)
        else:

After Change



        self.record_local_metric("loss", AverageMetric.many(loss, target_tokens))
        self.record_local_metric("ppl", PPLMetric.many(loss, target_tokens))
        self.record_local_metric(
            "token_acc", AverageMetric.many(correct, target_tokens)
        )
        // actually do backwards loss
        loss = loss.sum()
        loss /= target_tokens.sum()  // average loss per token
        if return_output:
In pattern: SUPERPATTERN

Frequency: 3

Non-data size: 14

Instances


Project Name: facebookresearch/ParlAI
Commit Name: 72c304fa7cac16ed19d8bc75a017f17c8073dd2f
Time: 2020-02-13
Author: roller@fb.com
File Name: parlai/core/torch_generator_agent.py
Class Name: TorchGeneratorAgent
Method Name: compute_loss


Project Name: facebookresearch/ParlAI
Commit Name: 72c304fa7cac16ed19d8bc75a017f17c8073dd2f
Time: 2020-02-13
Author: roller@fb.com
File Name: parlai/core/torch_generator_agent.py
Class Name: TorchGeneratorAgent
Method Name: compute_loss


Project Name: facebookresearch/ParlAI
Commit Name: 72c304fa7cac16ed19d8bc75a017f17c8073dd2f
Time: 2020-02-13
Author: roller@fb.com
File Name: parlai/core/torch_classifier_agent.py
Class Name: TorchClassifierAgent
Method Name: eval_step


Project Name: facebookresearch/ParlAI
Commit Name: 72c304fa7cac16ed19d8bc75a017f17c8073dd2f
Time: 2020-02-13
Author: roller@fb.com
File Name: parlai/core/torch_classifier_agent.py
Class Name: TorchClassifierAgent
Method Name: train_step