d7148f390db12a3f7e4b43ab584648b2e4a1882d,pyro/distributions/spanning_tree.py,SpanningTree,log_partition_function,#SpanningTree#,109
Before Change
// See https://en.wikipedia.org/wiki/Kirchhoff%27s_theorem
V = self.num_vertices
grid = make_complete_graph(V)
shift = self.edge_logits.detach().max()
edge_probs = (self.edge_logits - shift).exp()
adjacency = torch.zeros(V, V, dtype=edge_probs.dtype)
adjacency[grid[0], grid[1]] = edge_probs
adjacency[grid[1], grid[0]] = edge_probs
laplacian = adjacency.sum(-1).diag() - adjacency
truncated = laplacian[:-1, :-1]
try:
import gpytorch
log_det = gpytorch.lazy.NonLazyTensor(truncated).logdet()
except ImportError:
log_det = torch.cholesky(truncated).diag().log().sum() * 2
return log_det + shift * (V - 1)
def log_prob(self, edges):
if self._validate_args:
self.validate_edges(edges)
After Change
// See https://en.wikipedia.org/wiki/Kirchhoff%27s_theorem
V = self.num_vertices
v1, v2 = make_complete_graph(V).unbind(0)
logits = self.edge_logits.new_full((V, V), -math.inf)
logits[v1, v2] = self.edge_logits
logits[v2, v1] = self.edge_logits
log_diag = logits.logsumexp(-1)
// Numerically stabilize so that laplacian has 1"s on the diagonal.
shift = 0.5 * log_diag
laplacian = torch.eye(V) - (logits - shift - shift[:, None]).exp()
truncated = laplacian[:-1, :-1]
try:
import gpytorch
log_det = gpytorch.lazy.NonLazyTensor(truncated).logdet()
except ImportError:
log_det = torch.cholesky(truncated).diag().log().sum() * 2
return log_det + log_diag[:-1].sum()
def log_prob(self, edges):
if self._validate_args:
self.validate_edges(edges)
In pattern: SUPERPATTERN
Frequency: 3
Non-data size: 6
Instances
Project Name: uber/pyro
Commit Name: d7148f390db12a3f7e4b43ab584648b2e4a1882d
Time: 2020-11-07
Author: fritz.obermeyer@gmail.com
File Name: pyro/distributions/spanning_tree.py
Class Name: SpanningTree
Method Name: log_partition_function
Project Name: maciejkula/spotlight
Commit Name: bed4f2dd50ff54eb7629362c3d03a3b758e0745a
Time: 2017-08-09
Author: maciej.kula@gmail.com
File Name: spotlight/layers.py
Class Name: BloomEmbedding
Method Name: forward
Project Name: masa-su/pixyz
Commit Name: 8a0793f34b6c8f7b8e5164da424504ef1a4fb744
Time: 2018-11-04
Author: masa@weblab.t.u-tokyo.ac.jp
File Name: Tars/losses/adversarial_loss.py
Class Name: AdversarialJSDivergence
Method Name: g_loss