60dfcf812eca79017dced46e1189245c050a3fd6,dl/callbacks.py,OptimizerCallback,on_batch_end,#OptimizerCallback#Any#,393
Before Change
    # value.zero_grad()
    # @TODO: check this one
    if len(state._optimizer) > 0:
        for key, value in state.loss.items():
            value.backward()
        if (self.accumulation_counter + 1) \
                % self.accumulation_steps == 0:
            self.grad_step(state._optimizer)
            state.model.zero_grad()
            self.accumulation_counter = 0
else:
    state.model.zero_grad()
    if len(state._optimizer) > 0:
        assert len(state._optimizer) == 1, \
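The Before Change logic accumulates gradients over several batches and calls the optimizer only once every `accumulation_steps` batches. Below is a minimal pure-Python sketch of that accumulate-then-step pattern, assuming a single scalar parameter and plain SGD. `ScalarParam`, `AccumulatingStepper`, `lr` and `steps_taken` are hypothetical names, not Catalyst API. The snippet also does not show where Catalyst increments `accumulation_counter`, so incrementing it on batches that do not step is an assumption of this sketch.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class ScalarParam:
    """Hypothetical stand-in for a model parameter and its accumulated gradient."""
    value: float
    grad: float = 0.0


@dataclass
class AccumulatingStepper:
    """Sketch of accumulate-then-step; names are illustrative, not Catalyst's API."""
    param: ScalarParam
    lr: float
    accumulation_steps: int
    accumulation_counter: int = 0
    steps_taken: List[int] = field(default_factory=list)

    def on_batch_end(self, batch_idx: int, grad: float) -> None:
        # Analogue of value.backward(): add this batch's gradient to the running sum
        self.param.grad += grad
        if (self.accumulation_counter + 1) % self.accumulation_steps == 0:
            # Analogue of self.grad_step(...): one SGD update with the summed gradient
            self.param.value -= self.lr * self.param.grad
            # Analogue of state.model.zero_grad(): clear the accumulated gradient
            self.param.grad = 0.0
            self.accumulation_counter = 0
            self.steps_taken.append(batch_idx)
        else:
            # Assumption: count batches that have not triggered a step yet
            self.accumulation_counter += 1


# Usage: with accumulation_steps=2 the parameter is updated on every second batch
stepper = AccumulatingStepper(ScalarParam(value=1.0), lr=0.1, accumulation_steps=2)
for i, g in enumerate([1.0, 1.0, 2.0, 2.0]):
    stepper.on_batch_end(i, g)
print(stepper.steps_taken)  # [1, 3]
```

Summing the gradients of N equal-size batches whose losses are per-batch means gives N times the gradient of the combined batch, so each update behaves like one large-batch step with the learning rate scaled by N.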
After Change
else:
    model = state.model
    model.zero_grad()
    optimizer = state.get_key(
        key="optimizer",
        inner_key=self.optimizer_key)
    loss = state.get_key(
        key="loss",
        inner_key=self.optimizer_key)
    scaled_loss = self.fp16_grad_scale * loss.float()
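In the After Change, the fp16 branch fetches the optimizer and loss for `self.optimizer_key` through `state.get_key(...)`, casts the loss to fp32 with `.float()`, and multiplies it by `self.fp16_grad_scale`. The sketch below reproduces both steps in plain Python, assuming `get_key` means "read attribute `key`, then index it with `inner_key`" (the source does not show its body) and using static loss scaling on a toy squared-error loss. `State`, `scaled_loss_and_grad` and the `"main"` key are hypothetical. With static loss scaling, backpropagating the scaled loss multiplies every gradient by the same factor, so the gradients are typically divided by it again before the optimizer step.

```python
from typing import Any, Dict, Optional, Tuple


class State:
    """Hypothetical stand-in for the runner state, holding per-key dicts."""

    def __init__(self, optimizer: Dict[str, Any], loss: Dict[str, float]) -> None:
        self.optimizer = optimizer
        self.loss = loss

    def get_key(self, key: str, inner_key: Optional[str] = None) -> Any:
        # Assumed semantics: fetch attribute `key`, then index it by `inner_key` if given
        value = getattr(self, key)
        return value if inner_key is None else value[inner_key]


def scaled_loss_and_grad(w: float, x: float, y: float,
                         grad_scale: float) -> Tuple[float, float]:
    """Static loss scaling on the toy loss (w * x - y) ** 2, differentiated w.r.t. w."""
    residual = w * x - y
    loss = residual ** 2
    # Analogue of `self.fp16_grad_scale * loss.float()`
    scaled_loss = grad_scale * float(loss)
    # Backpropagating the scaled loss scales the gradient by the same factor
    scaled_grad = grad_scale * 2.0 * residual * x
    # Unscale before the optimizer step to recover the true gradient
    grad = scaled_grad / grad_scale
    return scaled_loss, grad


# Usage: look up the loss for one optimizer key, then scale it
optimizer_key = "main"
state = State(optimizer={optimizer_key: "sgd"}, loss={optimizer_key: 0.25})
loss = state.get_key(key="loss", inner_key=optimizer_key)
scaled_loss = 128.0 * float(loss)  # 32.0
print(scaled_loss_and_grad(w=2.0, x=3.0, y=5.0, grad_scale=128.0))  # (128.0, 6.0)
```

A fixed scale keeps small fp16 gradients from underflowing to zero during backpropagation; dividing by the same factor before the step leaves the update itself unchanged.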
In pattern: SUPERPATTERN
Frequency: 3
Non-data size: 4
Instances
Project Name: Scitator/catalyst
Commit Name: 60dfcf812eca79017dced46e1189245c050a3fd6
Time: 2018-12-10
Author: scitator@gmail.com
File Name: dl/callbacks.py
Class Name: OptimizerCallback
Method Name: on_batch_end
Project Name: Scitator/catalyst
Commit Name: 60dfcf812eca79017dced46e1189245c050a3fd6
Time: 2018-12-10
Author: scitator@gmail.com
File Name: dl/callbacks.py
Class Name: SchedulerCallback
Method Name: step
Project Name: Scitator/catalyst
Commit Name: 60dfcf812eca79017dced46e1189245c050a3fd6
Time: 2018-12-10
Author: scitator@gmail.com
File Name: dl/callbacks.py
Class Name: OptimizerCallback
Method Name: on_epoch_start