// batch_dims and diag
res = kernel(a, b, batch_dims=(0, 2)).diag()
actual = torch.cat([actual[i].diag().unsqueeze(0) for i in range(actual.size(0))])
self.assertLess(torch.norm(res - actual), 1e-5)
def test_ard_separate_batch(self):
After Change
// batch_dims
double_batch_a = scaled_a.unsqueeze(0).transpose(0, -1)
double_batch_b = scaled_b.unsqueeze(0).transpose(0, -1)
actual = double_batch_a.transpose(-1, -2).unsqueeze(-1) - double_batch_b.transpose(-1, -2).unsqueeze(-2)
actual = actual.pow(2).mul_(-0.5).exp().view(3, 2, 2, 2)
res = kernel(a, b, batch_dims=(0, 2)).evaluate()
self.assertLess(torch.norm(res - actual), 1e-5)
// batch_dims and diag
res = kernel(a, b, batch_dims=(0, 2)).diag()
actual = actual.diagonal(dim1=-2, dim2=-1)
self.assertLess(torch.norm(res - actual), 1e-5)
def test_ard_separate_batch(self):
a = torch.tensor([[[1, 2, 3], [2, 4, 0]], [[-1, 1, 2], [2, 1, 4]]], dtype=torch.float)