6b7ac72472c9d549dea45f35669e876c4c8ee2e9,arviz/data/io_pyro.py,PyroConverter,posterior_to_xarray,#PyroConverter#,50
Before Change
// Do not make pyro a requirement
from pyro.infer import EmpiricalMarginal
try: // Try pyro>=0.3 release syntax
data = {
name: utils.expand_dims(samples.enumerate_support().squeeze())
if self.posterior.num_chains == 1
else samples.enumerate_support().squeeze()
for name, samples in self.posterior.marginal(
sites=self.latent_vars
).empirical.items()
}
except AttributeError: // Use pyro<0.3 release syntax
data = {}
for var_name in self.latent_vars:
// pylint: disable=no-member
samples = EmpiricalMarginal(
self.posterior, sites=var_name
).get_samples_and_weights()[0]
samples = samples.numpy().squeeze()
data[var_name] = utils.expand_dims(samples)
return dict_to_dataset(data, library=self.pyro, coords=self.coords, dims=self.dims)
def observed_data_to_xarray(self):
Convert observed data to xarray.
After Change
def posterior_to_xarray(self):
Convert the posterior to an xarray dataset.
data = self.posterior.get_samples(group_by_chain=True)
data = {k: v.detach().cpu().numpy() for k, v in data.items()}
return dict_to_dataset(data, library=self.pyro, coords=self.coords, dims=self.dims)
@requires("posterior")
In pattern: SUPERPATTERN
Frequency: 3
Non-data size: 3
Instances Project Name: arviz-devs/arviz
Commit Name: 6b7ac72472c9d549dea45f35669e876c4c8ee2e9
Time: 2019-10-31
Author: fehiepsi@gmail.com
File Name: arviz/data/io_pyro.py
Class Name: PyroConverter
Method Name: posterior_to_xarray
Project Name: calico/basenji
Commit Name: 28f6dbec4bee2572fa7f94445d63cebb2de6dc9b
Time: 2019-09-27
Author: drk@calicolabs.com
File Name: bin/tfr_hdf5.py
Class Name:
Method Name: read_tfr
Project Name: ray-project/ray
Commit Name: 22ccc43670dac93eb7fe81520a84cf3979d05693
Time: 2020-04-06
Author: sven@anyscale.io
File Name: rllib/utils/test_utils.py
Class Name:
Method Name: check