# Alternating (Sinkhorn-style) rescaling of a positive matrix, carried out
# directly in probability space.
# NOTE(review): `self`, `shift` and `finfo` are bound before this excerpt;
# `finfo` is presumably torch.finfo for the logits dtype -- confirm.
# NOTE(review): indentation was lost in this excerpt, so the extent of the
# loop body below must be confirmed against the original file.
# Exponentiate the shifted logits and floor every entry at sqrt(finfo.tiny),
# presumably so the reciprocals below never see an exact zero.
p = (self.logits - shift).exp().clamp(min=finfo.tiny ** 0.5)
# Initial column scale: with it, each column of d * p sums to 2.
d = 2 / p.sum(0)
for _ in range(self.bp_iters):
# Row scale: makes each row of s[:, None] * d * p sum to 1.
s = 1 / (p @ d)
# Column scale: makes each column of s[:, None] * d * p sum to 2.
d = 2 / (s @ p)
# Rescaled matrix. `s` is only assigned inside the loop, so within this
# excerpt it is unbound if self.bp_iters is 0.
b = s[:, None] * d * p
// Evaluate the Bethe free energy, adapting [4] Eqn 4 to one-two
// matchings.
# Floor the result at sqrt(finfo.tiny), presumably to keep later logarithms
# finite -- the consumers of `b` are outside this fragment.
b = b.clamp(min=finfo.tiny ** 0.5)
After Change
# Log-space alternating (Sinkhorn-style) normalization followed by a Bethe
# free energy estimate; the value returned is `log_Z`.
# NOTE(review): `self`, `shift` and `finfo` are bound before this excerpt;
# `finfo` is presumably torch.finfo for the logits dtype -- confirm.
# NOTE(review): indentation was lost in this excerpt, so the extent of the
# loop body and of the nested `log` helper must be confirmed against the
# original file.
logits = self.logits - shift
# Initial column offset: exp(logits - d) has columns (axis 0) summing to 2.
d = logits.logsumexp(0) - math.log(2)
for _ in range(self.bp_iters):
# Row offset; the positional True is keepdim, so `s` keeps a trailing axis
# and broadcasts against `logits`. Rows of exp(logits - d - s) then sum to 1.
s = (logits - d).logsumexp(-1, True)
# Column offset: columns of exp(logits - s - d) sum to 2.
d = (logits - s).logsumexp(0) - math.log(2)
# Normalized matrix in probability space.
# NOTE(review): `s` is only assigned inside the loop; unless it is bound
# before this excerpt, self.bp_iters == 0 would raise NameError here.
b = (logits - (d + s)).exp()
def log(x):
# Logarithm with the argument floored at finfo.tiny, so a zero input gives
# log(finfo.tiny) instead of -inf.
return x.clamp(min=finfo.tiny).log()
// Evaluate the Bethe free energy, adapting [4] Eqn 4 to one-two
// matchings.
# Complement of b, floored at 0 in case rounding pushes an entry of b above 1.
b_ = (1 - b).clamp(min=0)
# Negated sum of b-weighted logits. Logits are floored at -1 / finfo.eps,
# presumably so that a -inf logit times a zero entry of b cannot produce NaN.
internal_energy = -(b * logits.clamp(min=-1 / finfo.eps)).sum()
// Let h be the entropy of matching each destin to one source.
# Halving b makes each column sum to 1, given the column sums of 2 above.
z = b / 2
h = -(z * log(z)).sum(0)
// Then h2 is the entropy of the distribution mapping each destin to an
// unordered pair of sources, equal to the log of the binomial
// coefficient: perplexity * (perplexity - 1) / 2.
# With perplexity = exp(h): log(exp(h) * (exp(h) - 1) / 2)
#   = h + log(expm1(h)) - log(2).
h2 = h + log(h.expm1()) - math.log(2)
free_energy = internal_energy - h2.sum() - (b_ * log(b_)).sum()
# Add back the total of `shift`, presumably to undo the subtraction applied
# to the logits at the top of this fragment -- confirm against how `shift`
# is defined.
log_Z = shift.sum() - free_energy
# NOTE(review): assert statements are removed under `python -O`, so this
# finiteness check is not enforced in optimized runs.
assert torch.isfinite(log_Z)
return log_Z
def log_prob(self, value):