f0eea66ce83fc59910c623aa4e50ddfdd1f3ae7b,gpytorch/lazy/root_lazy_variable.py,RootLazyVariable,_batch_get_indices,#RootLazyVariable#Any#Any#Any#,58

Before Change


        return self

    def _batch_get_indices(self, batch_indices, left_indices, right_indices):
        """Return the requested entries for a batch of (batch, row, col) index triples.

        For position ``t`` the result is the sum, over the last dimension of
        ``self.root`` (size ``rank``), of the products of
        ``self.root`` entries at ``(batch_indices[t], left_indices[t], k)`` and
        ``self.root.transpose(-1, -2)`` entries at ``(batch_indices[t], k, right_indices[t])``.
        That is presumably entry ``(l, r)`` of ``root @ root^T`` in batch ``b``
        (TODO confirm against the class definition).

        Args:
            batch_indices: batch index of each requested entry; ``size(0)`` is used
                as the number of requested entries.
            left_indices: row index of each requested entry.
            right_indices: column index of each requested entry; also supplies the
                tensor type for the helper vector of root-column positions.

        Returns:
            One value per requested entry.
        """
        num_requested = batch_indices.size(0)
        rank = self.root.size(-1)

        # Positions 0 .. rank-1 along the root's last dimension, built with the
        # same tensor type as ``right_indices`` and wrapped as a Variable.
        col_positions = Variable(right_indices.data.new(rank))
        torch.arange(0, rank, out=col_positions.data)

        # NOTE(review): this pairing assumes _outer_repeat(x, n) repeats each
        # element of x n times consecutively and _inner_repeat(x, n) tiles x
        # n times. Under that assumption, every requested entry lines up with
        # every root column. TODO confirm.
        left_rows = self.root._batch_get_indices(
            _outer_repeat(batch_indices, rank),
            _outer_repeat(left_indices, rank),
            _inner_repeat(col_positions, num_requested),
        )
        root_transposed = self.root.transpose(-1, -2)
        right_rows = root_transposed._batch_get_indices(
            _outer_repeat(batch_indices, rank),
            _inner_repeat(col_positions, num_requested),
            _outer_repeat(right_indices, rank),
        )

        # Row t of each view holds the rank values gathered for entry t;
        # multiplying and summing across the row gives one value per entry.
        products = left_rows.view(-1, rank) * right_rows.view(-1, rank)
        return products.sum(-1)

    def _get_indices(self, left_indices, right_indices):
        outer_size = left_indices.size(0)
        inner_size = self.root.size(-1)

After Change


        return self

    def _batch_get_indices(self, batch_indices, left_indices, right_indices):
        """Return the requested entries for a batch of (batch, row, col) index triples.

        Sometimes ``left_indices`` has more elements than
        ``self.size(-1) * self.size(-2) * self.size(-3)``. In that case the
        entries are read straight out of ``self._evaluated``.

        Otherwise, for position ``t`` the result is the sum, over the last
        dimension of ``self.root`` (size ``rank``), of the products of
        ``self.root`` entries at ``(batch_indices[t], left_indices[t], k)`` and
        ``self.root.transpose(-1, -2)`` entries at ``(batch_indices[t], k, right_indices[t])``.
        That is presumably entry ``(l, r)`` of ``root @ root^T`` in batch ``b``
        (TODO confirm against the class definition).

        Args:
            batch_indices: batch index of each requested entry.
            left_indices: row index of each requested entry.
            right_indices: column index of each requested entry; also supplies the
                tensor type for the helper vector of root-column positions.

        Returns:
            One value per requested entry.
        """
        num_entries = left_indices.numel()
        full_numel = self.size(-1) * self.size(-2) * self.size(-3)
        # Guard: when more entries are requested than the full batched matrix
        # holds, index the evaluated matrix directly. This is presumably cheaper
        # than gathering `rank` values per entry (TODO confirm intent).
        if num_entries > full_numel:
            return self._evaluated[batch_indices, left_indices, right_indices]

        num_requested = batch_indices.size(0)
        rank = self.root.size(-1)

        # Positions 0 .. rank-1 along the root's last dimension, created with
        # the same tensor type as ``right_indices``.
        col_positions = right_indices.new(rank)
        torch.arange(0, rank, out=col_positions.data)

        # NOTE(review): this pairing assumes _outer_repeat(x, n) repeats each
        # element of x n times consecutively and _inner_repeat(x, n) tiles x
        # n times. Under that assumption, every requested entry lines up with
        # every root column. TODO confirm.
        left_rows = self.root._batch_get_indices(
            _outer_repeat(batch_indices, rank),
            _outer_repeat(left_indices, rank),
            _inner_repeat(col_positions, num_requested),
        )
        root_transposed = self.root.transpose(-1, -2)
        right_rows = root_transposed._batch_get_indices(
            _outer_repeat(batch_indices, rank),
            _inner_repeat(col_positions, num_requested),
            _outer_repeat(right_indices, rank),
        )

        # Row t of each view holds the rank values gathered for entry t;
        # multiplying and summing across the row gives one value per entry.
        products = left_rows.view(-1, rank) * right_rows.view(-1, rank)
        return products.sum(-1)

    def _get_indices(self, left_indices, right_indices):
        n_indices = left_indices.numel()
        if n_indices > self.size(-1) * self.size(-2):
            return self._evaluated[left_indices, right_indices]
Italian Trulli
In pattern: SUPERPATTERN

Frequency: 3

Non-data size: 31

Instances


Project Name: cornellius-gp/gpytorch
Commit Name: f0eea66ce83fc59910c623aa4e50ddfdd1f3ae7b
Time: 2018-08-06
Author: gpleiss@gmail.com
File Name: gpytorch/lazy/root_lazy_variable.py
Class Name: RootLazyVariable
Method Name: _batch_get_indices


Project Name: cornellius-gp/gpytorch
Commit Name: f0eea66ce83fc59910c623aa4e50ddfdd1f3ae7b
Time: 2018-08-06
Author: gpleiss@gmail.com
File Name: gpytorch/lazy/matmul_lazy_variable.py
Class Name: MatmulLazyVariable
Method Name: _get_indices


Project Name: cornellius-gp/gpytorch
Commit Name: f0eea66ce83fc59910c623aa4e50ddfdd1f3ae7b
Time: 2018-08-06
Author: gpleiss@gmail.com
File Name: gpytorch/lazy/root_lazy_variable.py
Class Name: RootLazyVariable
Method Name: _batch_get_indices


Project Name: cornellius-gp/gpytorch
Commit Name: f0eea66ce83fc59910c623aa4e50ddfdd1f3ae7b
Time: 2018-08-06
Author: gpleiss@gmail.com
File Name: gpytorch/lazy/matmul_lazy_variable.py
Class Name: MatmulLazyVariable
Method Name: _batch_get_indices