if isinstance(self.diag_tensor, KroneckerProductDiagLazyTensor):
if all(isinstance(tdiag, ConstantDiagLazyTensor) for tdiag in dlt.lazy_tensors):
evals_p_i, evecs = _constant_kpadlt_constructor(lt, dlt)
evals_p_i_inv_root = DiagLazyTensor(evals_p_i.diag().reciprocal().sqrt())
// here we need to scale the eigenvectors by the constants as
// A = D^{1/2} Q (\kron a_i^{-1} \Lambda_i + I) Q^\top D^{1/2}
// so that we compute
// L^{-1/2} = D^{1/2} Q (\kron a_i^{-1} \Lambda_i + I)^{1/2}
// = (\kron a_i^{1/2} Q_i)(\kron a_i^{-1} \Lambda_i + I)^{-1/2}
scaled_evecs_list = []
for evec_, dlt_ in zip(evecs.lazy_tensors, dlt.lazy_tensors):
scaled_evecs_list.append(evec_ * dlt_.diag_values.sqrt())
scaled_evecs = KroneckerProductLazyTensor(*scaled_evecs_list)return MatmulLazyTensor(scaled_evecs, evals_p_i_inv_root)
// again, we compute the root decomposition by pulling across the diagonals
dlt_sqrt, evals_p_i, evecs = _symmetrize_kpadlt_constructor(lt, dlt)
dlt_inv_root = dlt_sqrt.inverse()