if grad.is_sparse:
    raise RuntimeError("Adam does not support sparse gradients, please consider SparseAdam instead")
amsgrad = group["amsgrad"]

p_data_fp32 = p.data.float()

state = self.state[p]

# State initialization
if len(state) == 0:
    state["step"] = 0
    # Exponential moving average of gradient values
    state["exp_avg"] = torch.zeros_like(p_data_fp32)
    # Exponential moving average of squared gradient values
    state["exp_avg_sq"] = torch.zeros_like(p_data_fp32)
    if amsgrad:
        # Maintains max of all exp. moving avg. of sq. grad. values
        state["max_exp_avg_sq"] = torch.zeros_like(p_data_fp32)
else:
    state["exp_avg"] = state["exp_avg"].type_as(p_data_fp32)
    state["exp_avg_sq"] = state["exp_avg_sq"].type_as(p_data_fp32)
    if amsgrad:
        state["max_exp_avg_sq"] = state["max_exp_avg_sq"].type_as(p_data_fp32)

exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
if amsgrad:
After Change
amsgrad = group["amsgrad"]

# Keep an FP32 working copy only for half-precision params; FP32 params are used in place.
p_data_fp32 = p.data
if p.data.dtype in {torch.float16, torch.bfloat16}:
    p_data_fp32 = p_data_fp32.float()

state = self.state[p]

# State initialization
if len(state) == 0:
    state["step"] = 0
    # Exponential moving average of gradient values
    state["exp_avg"] = torch.zeros_like(p_data_fp32)
    # Exponential moving average of squared gradient values
    state["exp_avg_sq"] = torch.zeros_like(p_data_fp32)
    if amsgrad:
        # Maintains max of all exp. moving avg. of sq. grad. values
        state["max_exp_avg_sq"] = torch.zeros_like(p_data_fp32)
else:
    state["exp_avg"] = state["exp_avg"].to(p_data_fp32)
    state["exp_avg_sq"] = state["exp_avg_sq"].to(p_data_fp32)
    if amsgrad:
        state["max_exp_avg_sq"] = state["max_exp_avg_sq"].to(p_data_fp32)

exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
if amsgrad:
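The effect of the change is that fp16 and bfloat16 parameters still get an FP32 working copy for the update, while FP32 parameters are used directly with no extra allocation; existing state tensors are realigned with `.to(p_data_fp32)`, which matches both the dtype and the device of the reference tensor. Below is a minimal standalone sketch (not the optimizer code itself; the loop and tensor shapes are illustrative) of that dtype handling:

# Sketch of the "After Change" dtype handling, under the assumption that the
# goal is: upcast half-precision params to an FP32 copy, leave FP32 params alone,
# and bring state tensors onto the same dtype/device as the working copy.
import torch

for dtype in (torch.float16, torch.bfloat16, torch.float32):
    p = torch.zeros(4, dtype=dtype)

    # Mirrors the updated logic: upcast only when the param is half precision.
    p_data_fp32 = p
    if p.dtype in {torch.float16, torch.bfloat16}:
        p_data_fp32 = p_data_fp32.float()

    # State in a mismatched dtype (e.g. restored from a checkpoint) is converted
    # with .to(tensor), which matches both the dtype and device of its argument.
    exp_avg = torch.zeros(4, dtype=torch.float16).to(p_data_fp32)

    print(dtype, "-> working copy:", p_data_fp32.dtype,
          "| state:", exp_avg.dtype,
          "| shares storage with p:", p_data_fp32.data_ptr() == p.data_ptr())

For fp16/bf16 the working copy is a new FP32 tensor, while for fp32 it aliases the parameter itself, which is exactly the allocation the unconditional `p.data.float()` in the old code could not avoid.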