979b8c9efa551e8c948a4aca145367a2d87ac8d6,test/distributions/test_multitask_multivariate_normal.py,TestMultiTaskMultivariateNormal,test_multivariate_normal_correlated_sampels,#TestMultiTaskMultivariateNormal#Any#,112
Before Change
device = torch.device("cuda") if cuda else torch.device("cpu")
mean = torch.tensor([[0, 1], [2, 3]], dtype=torch.float, device=device)
variance = 1 + torch.arange(4, dtype=torch.float, device=device)
covmat = torch.diag(variance)
mtmvn = MultitaskMultivariateNormal(mean=mean, covariance_matrix=covmat)
base_samples = mtmvn.get_base_samples(torch.Size((3, 4)))
self.assertTrue(mtmvn.sample(base_samples=base_samples).shape == torch.Size([3, 4, 2, 2]))
base_samples = mtmvn.get_base_samples()
self.assertTrue(mtmvn.sample(base_samples=base_samples).shape == torch.Size([2, 2]))
After Change
device = torch.device("cuda") if cuda else torch.device("cpu")
for dtype in (torch.float, torch.double):
mean = torch.tensor([[0, 1], [2, 3]], dtype=dtype, device=device)
variance = torch.tensor([[1, 2], [3, 4]], dtype=dtype, device=device)
covmat = variance.view(-1).diag()
mtmvn = MultitaskMultivariateNormal(mean=mean, covariance_matrix=covmat)
base_samples = mtmvn.get_base_samples(torch.Size([3, 4]))
self.assertTrue(mtmvn.sample(base_samples=base_samples).shape == torch.Size([3, 4, 2, 2]))
base_samples = mtmvn.get_base_samples()
In pattern: SUPERPATTERN
Frequency: 3
Non-data size: 6
Instances
Project Name: cornellius-gp/gpytorch
Commit Name: 979b8c9efa551e8c948a4aca145367a2d87ac8d6
Time: 2019-02-26
Author: balandat@fb.com
File Name: test/distributions/test_multitask_multivariate_normal.py
Class Name: TestMultiTaskMultivariateNormal
Method Name: test_multivariate_normal_correlated_sampels
Project Name: cornellius-gp/gpytorch
Commit Name: 2ace85bfe963b8399cc8ec0cdd859a8e0ee31dbd
Time: 2018-12-11
Author: gpleiss@gmail.com
File Name: test/lazy/_lazy_tensor_test_case.py
Class Name: LazyTensorTestCase
Method Name: test_diag
Project Name: cornellius-gp/gpytorch
Commit Name: 979b8c9efa551e8c948a4aca145367a2d87ac8d6
Time: 2019-02-26
Author: balandat@fb.com
File Name: test/distributions/test_multitask_multivariate_normal.py
Class Name: TestMultiTaskMultivariateNormal
Method Name: test_multivariate_normal_batch_correlated_sampels