99d4a1e02349f433508fdf95516d2e12cb8b98bb,pyro/distributions/hmm.py,GaussianHMM,rsample,#GaussianHMM#Any#,372
Before Change
def rsample(self, sample_shape=torch.Size()):
batch_shape = self.batch_shape
time_shape = self.event_shape[:1]
init = self.initial_dist.expand(batch_shape).rsample(sample_shape)
trans = self.transition_dist.expand(batch_shape + time_shape).rsample(sample_shape)
obs = self.observation_dist.expand(batch_shape + time_shape).rsample(sample_shape)
mat = self.transition_matrix.expand(batch_shape + time_shape + (self.hidden_dim, self.hidden_dim))
z = _linear_integrate(init, mat, trans)
return (z.unsqueeze(-2) @ self.observation_matrix).squeeze(-2) + obs
def filter(self, value):
Compute posterior over final state given a sequence of observations.
After Change
hidden_dim = self.hidden_dim
obs = self._obs.marginalize(right=self.obs_dim).event_pad(left=self.hidden_dim)
z = _sequential_gaussian_filter_sample(self._init, self._trans + obs, sample_shape)
perm = torch.cat([torch.arange(hidden_dim, hidden_dim + obs_dim, device=z.device),
torch.arange(hidden_dim, device=z.device)])
x = self._obs.event_permute(perm).condition(z).rsample()
return x
def filter(self, value):
Compute posterior over final state given a sequence of observations.
In pattern: SUPERPATTERN
Frequency: 3
Non-data size: 8
Instances
Project Name: uber/pyro
Commit Name: 99d4a1e02349f433508fdf95516d2e12cb8b98bb
Time: 2020-01-29
Author: fritzo@uber.com
File Name: pyro/distributions/hmm.py
Class Name: GaussianHMM
Method Name: rsample
Project Name: OpenMined/PySyft
Commit Name: e50f57e068eea78df52fadc8c398c81088e7df2e
Time: 2020-07-29
Author: theo.leffyr@gmail.com
File Name: syft/frameworks/torch/mpc/fss.py
Class Name:
Method Name: bit_decomposition
Project Name: Kaixhin/Rainbow
Commit Name: fbc23881c651d69da1f7ba92bdab78009d33bf94
Time: 2020-08-20
Author: 32273096+Aladoro@users.noreply.github.com
File Name: memory.py
Class Name: SegmentTree
Method Name: _retrieve