# Sinkhorn-style belief propagation in linear space: alternately rescale the
# rows and columns of p so rows sum to 1 and columns sum to 2.
# NOTE(review): ``shift`` and ``finfo`` are defined outside this fragment;
# presumably ``shift`` is a per-row max of self.logits and ``finfo`` is
# torch.finfo(self.logits.dtype) — confirm against the enclosing method.
# Floor every entry at sqrt(finfo.tiny) so p is strictly positive and the
# sums divided by below are never zero.
# NOTE(review): this linear-space form can still under/overflow when the
# logits are widely spread; a log-space (logsumexp) formulation avoids that.
p = (self.logits - shift).exp().clamp(min=finfo.tiny ** 0.5)
# Column scalings, initialized so every column of d * p sums to 2.
d = 2 / p.sum(0)
for _ in range(self.bp_iters):
    # Row scalings: afterwards every row of s[:, None] * d * p sums to 1.
    s = 1 / (p @ d)
    # Column scalings: afterwards every column of s[:, None] * d * p sums to 2.
    d = 2 / (s @ p)
# Approximate marginals from the final scalings.
# NOTE(review): ``s`` is unbound if self.bp_iters is 0 (the loop never runs).
b = s[:, None] * d * p
# Evaluate the Bethe free energy, adapting [4] Eqn 4 to one-two
# matchings.
After Change
# Sinkhorn-style belief propagation carried out in log space: rows of
# exp(logits - s - d) are driven to sum to 1 and columns to sum to 2,
# without ever exponentiating the raw logits.
# Subtract each row's maximum logit. The in-place clamp through ``.data``
# keeps the shift finite (so an all -inf row does not produce NaN from
# -inf - -inf) and is not recorded by autograd.
shift = self.logits.max(dim=1, keepdim=True).values
shift.data.clamp_(min=finfo.min, max=finfo.max)
logits = self.logits - shift
# Log column scalings, initialized so every column of exp(logits - d) sums to 2.
d = logits.logsumexp(dim=0) - math.log(2)
for _ in range(self.bp_iters):
    # Log row scalings (kept as a column vector): afterwards every row of
    # exp(logits - s - d) sums to 1.
    s = (logits - d).logsumexp(dim=-1, keepdim=True)
    # Log column scalings: afterwards every column sums to 2.
    d = (logits - s).logsumexp(dim=0) - math.log(2)
# Approximate marginals from the final log scalings.
# NOTE(review): ``s`` is unbound if self.bp_iters is 0 (the loop never runs).
b = (logits - (s + d)).exp()
def log(x):