aab3902d4a7d55f5a86058854adc36b8a12c873f,catalyst/dl/callbacks/base.py,OptimizerCallback,on_batch_end,#OptimizerCallback#Any#,202

Before Change


            return

        self._accumulation_counter += 1
        if not self.fp16:
            model = state.model
            optimizer = state.get_key(
                key="optimizer", inner_key=self.optimizer_key
            )
            loss.backward()

            if (self._accumulation_counter + 1) % self.accumulation_steps == 0:
                self.grad_step(
                    optimizer=optimizer,
                    optimizer_wd=self._optimizer_wd,
                    grad_clip_fn=self.grad_clip_fn
                )
                model.zero_grad()
                self._accumulation_counter = 0
        else:
            model = state.model
            model.zero_grad()
            optimizer = state.get_key(
                key="optimizer", inner_key=self.optimizer_key
            )
            loss = state.get_key(key="loss", inner_key=self.optimizer_key)
            scaled_loss = self.fp16_grad_scale * loss.float()
            scaled_loss.backward()

            master_params = list(optimizer.param_groups[0]["params"])
            model_params = list(
                filter(lambda p: p.requires_grad, model.parameters())
            )
            copy_grads(source=model_params, target=master_params)
            for param in master_params:
                param.grad.data.mul_(1. / self.fp16_grad_scale)
            self.grad_step(
                optimizer=optimizer,
                optimizer_wd=self._optimizer_wd,
                grad_clip_fn=self.grad_clip_fn
            )
            copy_params(source=master_params, target=model_params)
            torch.cuda.synchronize()

    def on_epoch_end(self, state):
        optimizer = state.get_key(
            key="optimizer", inner_key=self.optimizer_key
        )

After Change



    def on_batch_end(self, state):
        loss = state.get_key(key="loss", inner_key=self.loss_key)
        if isinstance(loss, dict):
            loss = list(loss.values())
        if isinstance(loss, list):
            loss = torch.mean(torch.stack(loss))

        if self.prefix is not None:
[Image omitted]
In pattern: SUPERPATTERN

Frequency: 3

Non-data size: 7

Instances


Project Name: catalyst-team/catalyst
Commit Name: aab3902d4a7d55f5a86058854adc36b8a12c873f
Time: 2019-05-20
Author: ekhvedchenya@gmail.com
File Name: catalyst/dl/callbacks/base.py
Class Name: OptimizerCallback
Method Name: on_batch_end


Project Name: pysb/pysb
Commit Name: 590e29dd3dca1f837ab70c67f3a628c7ee032ac9
Time: 2017-05-01
Author: alubbock@users.noreply.github.com
File Name: pysb/simulator/cupsoda.py
Class Name: CupSodaSimulator
Method Name: _get_cmatrix