99d4a1e02349f433508fdf95516d2e12cb8b98bb,pyro/distributions/hmm.py,GaussianHMM,rsample,#GaussianHMM#Any#,372

Before Change



    def rsample(self, sample_shape=torch.Size()):
        """
        Draw a reparameterized sample of an observation sequence.

        Independent noise is drawn, in this order:

        - from ``initial_dist``, expanded to ``batch_shape``;
        - from ``transition_dist``, expanded to ``batch_shape`` plus the time
          dimension ``event_shape[:1]``;
        - from ``observation_dist``, expanded the same way.

        The initial draw, the transition noise and the broadcast
        ``transition_matrix`` are combined by ``_linear_integrate`` into hidden
        states. That helper presumably runs the forward recursion
        ``z_t = z_{t-1} @ A_t + noise_t``; confirm against its definition.
        The hidden states are then projected through ``observation_matrix``
        and offset by the observation noise.

        :param torch.Size sample_shape: Leading sample dimensions to draw.
        :returns: A tensor whose leading dims are
            ``sample_shape + batch_shape + event_shape[:1]``. The trailing
            dim is assumed to be the observation dimension.
        :rtype: torch.Tensor
        """
        # Everything is broadcast over the batch. The per-step distributions
        # and matrices are additionally broadcast over the time axis.
        batch = self.batch_shape
        batch_and_time = batch + self.event_shape[:1]

        # Draw order (init, transition, observation) matters for
        # reproducibility under a fixed RNG seed, so it is kept unchanged.
        z0 = self.initial_dist.expand(batch).rsample(sample_shape)
        transition_noise = self.transition_dist.expand(batch_and_time).rsample(sample_shape)
        observation_noise = self.observation_dist.expand(batch_and_time).rsample(sample_shape)

        square = (self.hidden_dim, self.hidden_dim)
        transition_matrices = self.transition_matrix.expand(batch_and_time + square)
        hidden_states = _linear_integrate(z0, transition_matrices, transition_noise)

        # Multiply as a row vector: lift to (..., 1, hidden_dim), multiply by
        # the observation matrix, then drop the singleton row dimension.
        emitted = (hidden_states.unsqueeze(-2) @ self.observation_matrix).squeeze(-2)
        return emitted + observation_noise

    def filter(self, value):
        
        Compute posterior over final state given a sequence of observations.

After Change


        hidden_dim = self.hidden_dim
        obs = self._obs.marginalize(right=self.obs_dim).event_pad(left=self.hidden_dim)
        z = _sequential_gaussian_filter_sample(self._init, self._trans + obs, sample_shape)
        perm = torch.cat([torch.arange(hidden_dim, hidden_dim + obs_dim, device=z.device),
                          torch.arange(hidden_dim, device=z.device)])
        x = self._obs.event_permute(perm).condition(z).rsample()
        return x

    def filter(self, value):
        
        Compute posterior over final state given a sequence of observations.
Italian Trulli
In pattern: SUPERPATTERN

Frequency: 3

Non-data size: 8

Instances


Project Name: uber/pyro
Commit Name: 99d4a1e02349f433508fdf95516d2e12cb8b98bb
Time: 2020-01-29
Author: fritzo@uber.com
File Name: pyro/distributions/hmm.py
Class Name: GaussianHMM
Method Name: rsample


Project Name: OpenMined/PySyft
Commit Name: e50f57e068eea78df52fadc8c398c81088e7df2e
Time: 2020-07-29
Author: theo.leffyr@gmail.com
File Name: syft/frameworks/torch/mpc/fss.py
Class Name:
Method Name: bit_decomposition


Project Name: Kaixhin/Rainbow
Commit Name: fbc23881c651d69da1f7ba92bdab78009d33bf94
Time: 2020-08-20
Author: 32273096+Aladoro@users.noreply.github.com
File Name: memory.py
Class Name: SegmentTree
Method Name: _retrieve