// TODO: check stability of numerics here
svd_list = self._kron_svd()
noise = self._diag_tensor[0, 0]
v_matrix = _DiagKroneckerProdLazyTensor(DiagLazyTensor(svd_list[0].S), DiagLazyTensor(svd_list[1].S))
q_matrix = KroneckerProductLazyTensor(lazify(svd_list[0].U), lazify(svd_list[1].U))
for sub_ind in range(2, len(svd_list)):
v_matrix = KroneckerProductLazyTensor(v_matrix, DiagLazyTensor(svd_list[sub_ind].S))
q_matrix = KroneckerProductLazyTensor(q_matrix, DiagLazyTensor(svd_list[sub_ind].S))
// TODO: this could be a memory hog.
inv_mat = DiagLazyTensor(v_matrix.diag() + noise)
res = q_matrix.t().matmul(tensor)
res = inv_mat.inverse().matmul(res)
res = q_matrix.matmul(res)
return tensor.t().matmul(res).squeeze()
After Change
def inv_quad(self, tensor, reduce_inv_quad=True):
// TODO: check stability of numerics here
svd_list = self._kron_svd
v_matrix = KroneckerProductLazyTensor(*[DiagLazyTensor(svd_decomp.S) for svd_decomp in svd_list])
q_matrix = KroneckerProductLazyTensor(*[lazify(svd_decomp.U) for svd_decomp in svd_list])
// TODO: this could be a memory hog.
inv_mat = DiagLazyTensor(1.0 / (v_matrix.diag() + self._diag_tensor.diag()))