lt = self.lazy_tensor
if isinstance(self.diag_tensor, KroneckerProductDiagLazyTensor):
if all(isinstance(tdiag, ConstantDiagLazyTensor) for tdiag in dlt.lazy_tensors):
sub_evals, sub_evecs = [], []
for lt_, dlt_ in zip(lt.lazy_tensors, dlt.lazy_tensors):
evals_, evecs_ = lt_.symeig(eigenvectors=True)
sub_evals.append(DiagLazyTensor(evals_ / dlt_.diag_values))
sub_evecs.append(evecs_)
After Change
if isinstance(self.diag_tensor, KroneckerProductDiagLazyTensor):
if all(isinstance(tdiag, ConstantDiagLazyTensor) for tdiag in dlt.lazy_tensors):
evals_p_i, evecs = _constant_kpadlt_constructor(lt, dlt)
evals_p_i_root = DiagLazyTensor(evals_p_i.diag().sqrt())
// here we need to scale the eigenvectors by the constants as
// A = D^{1/2} Q (\kron a_i^{-1} \Lambda_i + I) Q^\top D^{1/2}
// so that we compute
// L = D^{1/2} Q (\kron a_i^{-1} \Lambda_i + I)^{1/2}