99d4a1e02349f433508fdf95516d2e12cb8b98bb,pyro/distributions/hmm.py,GaussianHMM,rsample,#GaussianHMM#Any#,372

Before Change


        batch_shape = self.batch_shape
        time_shape = self.event_shape[:1]
        init = self.initial_dist.expand(batch_shape).rsample(sample_shape)
        trans = self.transition_dist.expand(batch_shape + time_shape).rsample(sample_shape)
        obs = self.observation_dist.expand(batch_shape + time_shape).rsample(sample_shape)
        mat = self.transition_matrix.expand(batch_shape + time_shape + (self.hidden_dim, self.hidden_dim))
        z = _linear_integrate(init, mat, trans)
        return (z.unsqueeze(-2) @ self.observation_matrix).squeeze(-2) + obs

    def filter(self, value):
        
        Compute posterior over final state given a sequence of observations.

After Change


    def rsample(self, sample_shape=torch.Size()):
        """
        Draw a reparameterized sample of the observation sequence.

        Sampling happens in two stages:

        1. A hidden-state sample ``z`` is drawn with
           ``_sequential_gaussian_filter_sample``. Its inputs are the initial
           factor and the transition factor plus the observation factor's
           hidden-state marginal.
        2. The observation factor is conditioned on ``z`` and sampled.

        :param sample_shape: Leading sample shape; coerced to ``torch.Size``.
        :returns: The sample drawn from the observation factor conditioned on
            the hidden-state sample.
        """
        sample_shape = torch.Size(sample_shape)
        obs_dim = self.obs_dim
        hidden_dim = self.hidden_dim

        # Drop the trailing obs_dim coordinates of the observation factor, then
        # left-pad by hidden_dim so the result can be added to self._trans.
        # NOTE(review): presumably this aligns the marginal with the "current
        # hidden state" half of the transition factor -- confirm layout.
        hidden_marginal = self._obs.marginalize(right=obs_dim)
        padded_marginal = hidden_marginal.event_pad(left=hidden_dim)
        z = _sequential_gaussian_filter_sample(
            self._init, self._trans + padded_marginal, sample_shape)

        # Reorder the observation factor's event coordinates so the obs block
        # (indices hidden_dim .. hidden_dim+obs_dim-1) comes first and the
        # hidden block (indices 0 .. hidden_dim-1) comes last.
        # NOTE(review): assumes .condition() conditions on the trailing
        # coordinates, so this lets z fix the hidden block -- confirm.
        obs_index = torch.arange(hidden_dim, hidden_dim + obs_dim, device=z.device)
        hidden_index = torch.arange(hidden_dim, device=z.device)
        reordered = self._obs.event_permute(torch.cat([obs_index, hidden_index]))
        return reordered.condition(z).rsample()

    def filter(self, value):
        
        Compute posterior over final state given a sequence of observations.
Italian Trulli
In pattern: SUPERPATTERN

Frequency: 3

Non-data size: 11

Instances


Project Name: uber/pyro
Commit Name: 99d4a1e02349f433508fdf95516d2e12cb8b98bb
Time: 2020-01-29
Author: fritzo@uber.com
File Name: pyro/distributions/hmm.py
Class Name: GaussianHMM
Method Name: rsample


Project Name: stanfordnlp/stanza
Commit Name: be91aa08bda8873d839a77932049e39b0ee11577
Time: 2018-10-17
Author: qipeng@users.noreply.github.com
File Name: models/common/biaffine.py
Class Name: BiaffineScorer
Method Name: forward


Project Name: stanfordnlp/stanza
Commit Name: be91aa08bda8873d839a77932049e39b0ee11577
Time: 2018-10-17
Author: qipeng@users.noreply.github.com
File Name: models/common/biaffine.py
Class Name: PairwiseBiaffineScorer
Method Name: forward