from collections import deque
args = tuple(self.representation())
args_with_grads = tuple(arg for arg in args if arg.requires_grad)
// Easy case: if we don"t require any gradients, then just return!
if not len(args_with_grads):
return tuple(None for _ in args)
// Normal case: we"ll use the autograd to get us a derivative
with torch.autograd.enable_grad():
loss = (left_vecs * self._matmul(right_vecs)).sum()
loss.requires_grad_(True)
actual_grads = deque(torch.autograd.grad(loss, args_with_grads, allow_unused=True))
// Now make sure that the object we return has one entry for every item in args
grads = []
for arg in args:
if arg.requires_grad:
grads.append(actual_grads.popleft())
else:
grads.append(None)