# Sparse gradients are rejected up front; the update below only handles
# dense tensors.
if grad.is_sparse:
    raise RuntimeError("Adamax does not support sparse gradients")
# Always cast the parameter to float32 for the update, whatever its
# original dtype is.
p_data_fp32 = p.data.float()
state = self.state[p]
# State initialization
if len(state) == 0:
    # First time this parameter is seen: start the step counter at zero and
    # allocate zero-filled buffers shaped/typed like the fp32 parameter.
    # NOTE(review): exp_avg / exp_inf are presumably Adamax's running mean
    # and infinity-norm estimates -- confirm against the update code.
    state["step"] = 0
    state["exp_avg"] = torch.zeros_like(p_data_fp32)
    state["exp_inf"] = torch.zeros_like(p_data_fp32)
else:
    # Existing state: make sure exp_avg has the same tensor type as the
    # fp32 parameter before it is used.
    state["exp_avg"] = state["exp_avg"].type_as(p_data_fp32)
After Change
raise RuntimeError("Adamax does not support sparse gradients")
# Only half-precision parameters (fp16 / bf16) are upcast to float32.
# For any other dtype p_data_fp32 is the same tensor as p.data, not a copy.
p_data_fp32 = p.data
if p.data.dtype in {torch.float16, torch.bfloat16}:
    p_data_fp32 = p_data_fp32.float()
state = self.state[p]
# State initialization
if len(state) == 0:
    # First time this parameter is seen: start the step counter at zero and
    # allocate zero-filled buffers matching the working parameter tensor.
    # NOTE(review): exp_avg / exp_inf are presumably Adamax's running mean
    # and infinity-norm estimates -- confirm against the update code.
    state["step"] = 0
    state["exp_avg"] = torch.zeros_like(p_data_fp32)
    state["exp_inf"] = torch.zeros_like(p_data_fp32)
else:
    # Existing state: Tensor.to(other) converts exp_avg to the dtype and
    # device of the working parameter tensor.
    state["exp_avg"] = state["exp_avg"].to(p_data_fp32)