raise RuntimeError("Input must be a single-element tensor")
val = diag.squeeze()[0]
return torch.eye(*input.size()).type_as(input).mul_(val).add_(input)
def backward(self, grad_output):
input_grad = None
After Change
raise RuntimeError("Input must be a single-element tensor")
val = diag.squeeze()[0]
diag_mat = torch.eye(input.size(-2), input.size(-1)).type_as(input)
if input.ndimension() == 3:
diag_mat = diag_mat.unsqueeze(0).expand_as(input)
return diag_mat.mul_(val).add_(input)
def backward(self, grad_output):
input_grad = None