@cached(name="svd")
def _svd(self) -> Tuple["LazyTensor", Tensor, "LazyTensor"]:
    """Return an SVD obtained by shifting the wrapped operator's singular values.

    The shortcut is taken only when the diagonal term is constant, i.e. every
    entry of ``self._diag_tensor.diag()`` equals the first entry along the
    last dimension. In that case the singular vectors returned by
    ``self._lazy_tensor.svd()`` are reused unchanged and the diagonal values
    are added to its singular values.

    Returns:
        A ``(U, S, V)`` tuple where ``U`` and ``V`` come straight from
        ``self._lazy_tensor.svd()`` and ``S`` is that call's singular values
        plus the diagonal.
    """
    diag = self._diag_tensor.diag()
    # Constant-diagonal check: compare the diagonal against its first entry
    # expanded along the last dimension to the full diagonal shape.
    if torch.equal(diag, diag[..., :1].expand(diag.shape)):
        U, S_, V = self._lazy_tensor.svd()
        S = S_ + diag  # this assumes all diagonal entries are positive
        return U, S, V
    # NOTE(review): no return is visible here for a non-constant diagonal, so
    # this path yields None despite the annotated tuple return type. Confirm
    # whether a fallback was trimmed from this excerpt.
After Change
@cached(name="svd")
def _svd(self) -> Tuple["LazyTensor", Tensor, "LazyTensor"]:
    """Return an SVD obtained by shifting the wrapped operator's singular values.

    The shortcut is taken only when ``self._diag_tensor`` is a
    ``ConstantDiagLazyTensor``. In that case the singular vectors returned by
    ``self._lazy_tensor.svd()`` are reused unchanged and the diagonal values
    are added to its singular values. Checking the type avoids having to
    materialise and compare the diagonal entries just to decide whether the
    shortcut applies.

    Returns:
        A ``(U, S, V)`` tuple where ``U`` and ``V`` come straight from
        ``self._lazy_tensor.svd()`` and ``S`` is that call's singular values
        plus ``self._diag_tensor.diag()``.
    """
    if isinstance(self._diag_tensor, ConstantDiagLazyTensor):
        U, S_, V = self._lazy_tensor.svd()
        # this assumes all diagonal entries are positive
        S = S_ + self._diag_tensor.diag()
        return U, S, V
    # NOTE(review): no return is visible here for other diagonal types, so
    # this path yields None despite the annotated tuple return type. Confirm
    # whether a fallback was trimmed from this excerpt.