42dc18f2a13441eefdfceed905843a3d19b27072,pyro/distributions/one_one_matching.py,OneOneMatching,log_partition_function,#OneOneMatching#,100
Before Change
for _ in range(self.bp_iters):
s = 1 / (p @ d)
d = 1 / (s @ p)
b = s[:, None] * d * p
// Evaluate the Bethe free energy.
b = b.clamp(min=finfo.tiny ** 0.5)
b_ = (1 - b).clamp(min=finfo.tiny ** 0.5)
free_energy = (b * (b.log() - p.log())).sum() - (b_ * b_.log()).sum()
return shift.sum() - free_energy
def log_prob(self, value):
if self._validate_args:
self._validate_sample(value)
After Change
shift = self.logits.max(1, True).values
shift.data.clamp_(min=finfo.min, max=finfo.max)
logits = self.logits - shift
d = logits.logsumexp(0)
for _ in range(self.bp_iters):
s = (logits - d).logsumexp(-1, True)
d = (logits - s).logsumexp(0)
b = (logits - (d + s)).exp()
def log(x):
return x.clamp(min=finfo.tiny).log()
// Evaluate the Bethe free energy.
b_ = (1 - b).clamp(min=0)
logits = logits.clamp(min=-1 / finfo.eps)
free_energy = (b * (log(b) - logits)).sum() - (b_ * log(b_)).sum()
log_Z = shift.sum() - free_energy
assert torch.isfinite(log_Z)
return log_Z
def log_prob(self, value):
if self._validate_args:
self._validate_sample(value)
In pattern: SUPERPATTERN
Frequency: 3
Non-data size: 9
Instances
Project Name: uber/pyro
Commit Name: 42dc18f2a13441eefdfceed905843a3d19b27072
Time: 2020-12-07
Author: fritz.obermeyer@gmail.com
File Name: pyro/distributions/one_one_matching.py
Class Name: OneOneMatching
Method Name: log_partition_function
Project Name: mittagessen/kraken
Commit Name: 90092a23db57abba13808dbf96100e07fc4797ac
Time: 2018-05-14
Author: mittagessen@l.unchti.me
File Name: kraken/lib/ctc.py
Class Name: _CTC
Method Name: backward
Project Name: uber/pyro
Commit Name: 42dc18f2a13441eefdfceed905843a3d19b27072
Time: 2020-12-07
Author: fritz.obermeyer@gmail.com
File Name: pyro/distributions/one_two_matching.py
Class Name: OneTwoMatching
Method Name: log_partition_function