f0eea66ce83fc59910c623aa4e50ddfdd1f3ae7b,gpytorch/lazy/root_lazy_variable.py,RootLazyVariable,_batch_get_indices,#RootLazyVariable#Any#Any#Any#,58
Before Change
return self
def _batch_get_indices(self, batch_indices, left_indices, right_indices):
    """
    Gather individual entries of the batched product ``root @ root^T``.

    Entry ``p`` of the result is the reduction over ``k`` of
    ``root[batch_indices[p], left_indices[p], k] * root^T[batch_indices[p], k, right_indices[p]]``.
    This is entry ``(left, right)`` of ``root @ root^T`` in the given batch,
    obtained without forming the full product.

    Args:
        batch_indices: batch index of each requested entry.
        left_indices: row index of each requested entry.
        right_indices: column index of each requested entry.
        NOTE(review): assumes the three index Variables are 1-D and of equal
        length; confirm against callers.

    Returns:
        The reduced values, presumably one per requested entry.
    """
    num_outer = batch_indices.size(0)
    num_inner = self.root.size(-1)

    # Reduction index k = 0, 1, ..., num_inner - 1, built with right_indices'
    # tensor type and wrapped as a Variable.
    k_range = Variable(right_indices.data.new(num_inner))
    torch.arange(0, num_inner, out=k_range.data)

    # root[b, i, k] for every (request, k) pair.
    root_entries = self.root._batch_get_indices(
        _outer_repeat(batch_indices, num_inner),
        _outer_repeat(left_indices, num_inner),
        _inner_repeat(k_range, num_outer),
    )
    # root^T[b, k, j] for every (request, k) pair.
    root_t_entries = self.root.transpose(-1, -2)._batch_get_indices(
        _outer_repeat(batch_indices, num_inner),
        _inner_repeat(k_range, num_outer),
        _outer_repeat(right_indices, num_inner),
    )

    # NOTE(review): the reshape to (-1, num_inner) assumes _outer_repeat repeats
    # each element num_inner times consecutively and _inner_repeat tiles k_range.
    # Helper semantics are not visible here.
    products = root_entries.view(-1, num_inner) * root_t_entries.view(-1, num_inner)
    return products.sum(-1)
def _get_indices(self, left_indices, right_indices):
outer_size = left_indices.size(0)
inner_size = self.root.size(-1)
After Change
return self
def _batch_get_indices(self, batch_indices, left_indices, right_indices):
    """
    Gather individual entries of the batched product ``root @ root^T``.

    Entry ``p`` of the result is the reduction over ``k`` of
    ``root[batch_indices[p], left_indices[p], k] * root^T[batch_indices[p], k, right_indices[p]]``.

    Heuristic: when more entries are requested than the tensor holds in
    total (``size(-1) * size(-2) * size(-3)``), the entries are read directly
    from ``self._evaluated`` instead of being reduced one by one.

    Args:
        batch_indices: batch index of each requested entry.
        left_indices: row index of each requested entry.
        right_indices: column index of each requested entry.
        NOTE(review): assumes the three index tensors are 1-D and of equal
        length; confirm against callers.

    Returns:
        The requested values, presumably one per requested entry.
    """
    requested = left_indices.numel()
    capacity = self.size(-1) * self.size(-2) * self.size(-3)
    if requested > capacity:
        # NOTE(review): `_evaluated` is defined elsewhere; presumably the dense
        # form of this lazy variable, which is cheaper to index at this size.
        return self._evaluated[batch_indices, left_indices, right_indices]

    num_outer = batch_indices.size(0)
    num_inner = self.root.size(-1)

    # Reduction index k = 0, 1, ..., num_inner - 1, allocated with
    # right_indices' tensor type.
    k_range = right_indices.new(num_inner)
    torch.arange(0, num_inner, out=k_range.data)

    # root[b, i, k] for every (request, k) pair.
    root_entries = self.root._batch_get_indices(
        _outer_repeat(batch_indices, num_inner),
        _outer_repeat(left_indices, num_inner),
        _inner_repeat(k_range, num_outer),
    )
    # root^T[b, k, j] for every (request, k) pair.
    root_t_entries = self.root.transpose(-1, -2)._batch_get_indices(
        _outer_repeat(batch_indices, num_inner),
        _inner_repeat(k_range, num_outer),
        _outer_repeat(right_indices, num_inner),
    )

    # NOTE(review): the reshape to (-1, num_inner) assumes _outer_repeat repeats
    # each element num_inner times consecutively and _inner_repeat tiles k_range.
    # Helper semantics are not visible here.
    products = root_entries.view(-1, num_inner) * root_t_entries.view(-1, num_inner)
    return products.sum(-1)
def _get_indices(self, left_indices, right_indices):
n_indices = left_indices.numel()
if n_indices > self.size(-1) * self.size(-2):
return self._evaluated[left_indices, right_indices]
In pattern: SUPERPATTERN
Frequency: 3
Non-data size: 31
Instances
Project Name: cornellius-gp/gpytorch
Commit Name: f0eea66ce83fc59910c623aa4e50ddfdd1f3ae7b
Time: 2018-08-06
Author: gpleiss@gmail.com
File Name: gpytorch/lazy/root_lazy_variable.py
Class Name: RootLazyVariable
Method Name: _batch_get_indices
Project Name: cornellius-gp/gpytorch
Commit Name: f0eea66ce83fc59910c623aa4e50ddfdd1f3ae7b
Time: 2018-08-06
Author: gpleiss@gmail.com
File Name: gpytorch/lazy/matmul_lazy_variable.py
Class Name: MatmulLazyVariable
Method Name: _get_indices
Project Name: cornellius-gp/gpytorch
Commit Name: f0eea66ce83fc59910c623aa4e50ddfdd1f3ae7b
Time: 2018-08-06
Author: gpleiss@gmail.com
File Name: gpytorch/lazy/root_lazy_variable.py
Class Name: RootLazyVariable
Method Name: _batch_get_indices
Project Name: cornellius-gp/gpytorch
Commit Name: f0eea66ce83fc59910c623aa4e50ddfdd1f3ae7b
Time: 2018-08-06
Author: gpleiss@gmail.com
File Name: gpytorch/lazy/matmul_lazy_variable.py
Class Name: MatmulLazyVariable
Method Name: _batch_get_indices