def check_overflow(self, master_params):
self.is_overflow = False
for p in generate_params(master_params):
if p.grad is not None:
cpu_sum = float(p.grad.float().sum())
if (
cpu_sum == float("inf")
or cpu_sum == -float("inf")
or cpu_sum != cpu_sum
After Change
self.is_overflow = False
for p in generate_params(params):
self.check_overflow_(p.grad)
if self.is_overflow:
break
def update_scale(self):
Adjust the loss scale according to the overflow situation.
Once an overflow has happened, we decrease the scale by scale_factor.