)
// TODO: Figure out how to cache the decomposition for use in later solves
Lambda, S = KDinv.symeig(eigenvectors=True)
LambdaI = DiagLazyTensor(Lambda + 1)
tmp_term = S.matmul(LambdaI.inv_matmul(S._transpose_nonbatch().matmul(rhs)))
res = lt._solve(rhs - tmp_term, preconditioner=preconditioner, num_tridiag=num_tridiag)
return res.to(rhs_dtype)
After Change
// K^{-1} b - K^{-1} (S (Lambda + I)^{-1} S^T b).
// Each sub-matrix D_i^{-1} has constant diagonal, so we may scale the eigenvalues of the
// eigendecomposition of K_i by its inverse to get an eigendecomposition of K_i D_i^{-1}.
sub_evals, sub_evecs = [], []
for lt_, dlt_ in zip(lt.lazy_tensors, dlt.lazy_tensors):
evals_, evecs_ = lt_.symeig(eigenvectors=True)
sub_evals.append(DiagLazyTensor(evals_ / dlt_.diag_values))