# Build pieces for extending the existing root factor L with rows/columns for
# the fantasy points (see the "new root Z" comment below).
# NOTE(review): this fragment appears reordered -- `m` is read on the
# `upper_right` line but only assigned by `m, n = L.shape[-2:]` further down;
# in runnable code that assignment must come first.
# Cross block: fant_train_covar right-multiplied by L_inverse.
lower_left = fant_train_covar.matmul(L_inverse)
# Cholesky factor of (fant_fant_covar - lower_left @ lower_left^T); presumably
# the Schur complement of the training block -- TODO confirm the L_inverse
# convention (L^{-1} vs L^{-T}) against the caller.
# No jitter is added here, so the factorization errors out if that matrix is
# not numerically positive definite.
# NOTE(review): torch.cholesky is deprecated; torch.linalg.cholesky replaces it.
schur_root = torch.cholesky(fant_fant_covar - lower_left.matmul(lower_left.transpose(-2, -1)))
# Zero block of shape (m, schur_root.size(-1)) on L's device and dtype.
# NOTE(review): no batch dimensions are included -- confirm L is unbatched or
# that this block is expanded/broadcast when the root is assembled.
upper_right = torch.zeros(m, schur_root.size(-1), device=L.device, dtype=L.dtype)
// Form new root Z = [L 0; lower_left schur_root]
# Size of schur_root's second-to-last dim (the fantasy count, judging by name).
num_fant = schur_root.size(-2)
# Row/column counts of the existing root factor L (its last two dims).
m, n = L.shape[-2:]
After Change
# Factor the assembled root with QR and derive a new covariance cache from it.
# NOTE(review): `new_root` is not defined in this view; presumably it is the
# block matrix Z = [L 0; lower_left schur_root] -- confirm. The `try:` that
# pairs with `except RuntimeError` below is also outside this view.
# NOTE(review): torch.qr is deprecated in favor of torch.linalg.qr.
Q, R = torch.qr(new_root)
# Diagonal of R as a view; the jitter assignment below writes through it into R.
Rdiag = torch.diagonal(R, dim1=-2, dim2=-1)
// if R is almost singular, add jitter (Rdiag is a view, so this works)
# Every diagonal entry with magnitude below 1e-6 is overwritten with +1e-6,
# so the sign of a tiny negative entry is not preserved.
zeroish = Rdiag.abs() < 1e-6
if torch.any(zeroish):
Rdiag[zeroish] = 1e-6
# Upper-triangular solve R X = Q^T (torch.triangular_solve takes the right-hand
# side first and returns a tuple; [0] is the solution), i.e. X = R^{-1} Q^T.
# NOTE(review): torch.triangular_solve is deprecated; see
# torch.linalg.solve_triangular (argument order is (A, B) there).
new_covar_cache = torch.triangular_solve(Q.transpose(-2, -1), R)[0]
except RuntimeError as e:
// TODO: Deprecate once batch QR supported in latest torch stable
# Special-cases only the "A should be 2 dimensional" error, presumably raised by
# torch.qr on batched input in older torch (per the TODO above); the body of
# this check continues beyond this view.
if "invalid argument 1: A should be 2 dimensional" not in e.args[0]: