p += eps * d
self.mutator.reset()
loss = self.loss(self.model(trn_X), trn_y)
if e > 0:
    dalpha_pos = torch.autograd.grad(loss, self.mutator.parameters())  # dalpha { L_trn(w+) }
elif e < 0:
After Change
if norm < 1E-8:
    logger.warning("In computing hessian, norm is smaller than 1E-8, causing eps to be %.6f.", norm.item())

dalphas = []
for e in [eps, -2. * eps]:
    # w+ = w + eps*dw`, w- = w - eps*dw`
    # (the first pass shifts the weights to w+; the second applies -2*eps, moving them from w+ to w-)
    with torch.no_grad():
        for p, d in zip(self.model.parameters(), dw):
            p += e * d
    _, loss = self._logits_and_loss(trn_X, trn_y)
    dalphas.append(torch.autograd.grad(loss, self.mutator.parameters()))
dalpha_pos, dalpha_neg = dalphas  # dalpha { L_trn(w+) }, dalpha { L_trn(w-) }
hessian = [(p - n) / (2. * eps) for p, n in zip(dalpha_pos, dalpha_neg)]
return hessian
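Note: the hessian returned above is the central-difference approximation from second-order DARTS, (dalpha{L_trn(w+)} - dalpha{L_trn(w-)}) / (2*eps), obtained by perturbing the model weights in place rather than forming an explicit Hessian. The snippet below is a minimal, self-contained sketch of that trick under toy assumptions: the tensors w and alpha, loss_fn, and dalpha_at are illustrative stand-ins and are not part of this codebase.

import torch

# w stands in for the model weights, alpha for the architecture parameters
# exposed by the mutator; both are illustrative only.
w = torch.randn(5, requires_grad=True)
alpha = torch.randn(3, requires_grad=True)

def loss_fn(w, alpha):
    # Any differentiable scalar that couples w and alpha works for the sketch.
    return (w.sum() * alpha).pow(2).sum()

# dw plays the role of dw` = grad_w L_val(w`, alpha) in the trainer.
dw = torch.autograd.grad(loss_fn(w, alpha), w)[0]
eps = 0.01 / dw.norm()

def dalpha_at(shift):
    # Evaluate dalpha L_trn at (w + shift * dw), then restore the weights.
    with torch.no_grad():
        w.add_(shift * dw)
    grad = torch.autograd.grad(loss_fn(w, alpha), alpha)[0]
    with torch.no_grad():
        w.sub_(shift * dw)
    return grad

dalpha_pos = dalpha_at(eps)    # dalpha { L_trn(w+) }
dalpha_neg = dalpha_at(-eps)   # dalpha { L_trn(w-) }
hessian = (dalpha_pos - dalpha_neg) / (2. * eps)   # central difference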