303217b34070dc47a86622b62764098999b0d7f5,gpytorch/lazy/lazy_tensor.py,LazyTensor,_quad_form_derivative,#LazyTensor#Any#Any#,378
Before Change
grads = torch.autograd.grad(loss, args, allow_unused=True)
for i, arg in enumerate(args):
if toggled[i]:
arg.requires_grad = False
return grads
After Change
from collections import deque
args = tuple(self.representation())
args_with_grads = tuple(arg for arg in args if arg.requires_grad)
// Easy case: if we don't require any gradients, then just return!
if not len(args_with_grads):
return tuple(None for _ in args)
// Normal case: we'll use the autograd to get us a derivative
with torch.autograd.enable_grad():
loss = (left_vecs * self._matmul(right_vecs)).sum()
loss.requires_grad_(True)
actual_grads = deque(torch.autograd.grad(loss, args_with_grads, allow_unused=True))
// Now make sure that the object we return has one entry for every item in args
grads = []
for arg in args:
if arg.requires_grad:
grads.append(actual_grads.popleft())
else:
grads.append(None)
args_with_grads = tuple(arg for arg in args if arg.requires_grad)
return grads
def _preconditioner(self):
In pattern: SUPERPATTERN
Frequency: 4
Non-data size: 10
Instances
Project Name: cornellius-gp/gpytorch
Commit Name: 303217b34070dc47a86622b62764098999b0d7f5
Time: 2018-12-12
Author: gpleiss@gmail.com
File Name: gpytorch/lazy/lazy_tensor.py
Class Name: LazyTensor
Method Name: _quad_form_derivative
Project Name: Theano/Theano
Commit Name: 289c3bd43be7cc0ca14bb505611f1f84e0e53c4a
Time: 2016-10-11
Author: gvtulder@gmail.com
File Name: theano/tensor/nnet/tests/test_abstract_conv.py
Class Name: BaseTestConv2d
Method Name: get_output_shape
Project Name: mathics/Mathics
Commit Name: 988d33a3e088759c570f143fdb2ab22c54c0f520
Time: 2016-09-12
Author: Bernhard.Liebl@gmx.org
File Name: mathics/builtin/numeric.py
Class Name: Fold
Method Name: fold
Project Name: deepgram/kur
Commit Name: 35ed48386992d824973d8ed39cfa299614b7cd34
Time: 2017-02-28
Author: ajsyp@syptech.net
File Name: kur/loggers/binary_logger.py
Class Name: BinaryLogger
Method Name: load_statistic