# Build the two constraint rows contributed by every point correspondence.
# Each of ax / ay concatenates nine terms along the last dimension.
ax = torch.cat([zeros, zeros, zeros, -x1, -y1, -ones, y2 * x1, y2 * y1, y2], dim=-1)
ay = torch.cat([x1, y1, ones, zeros, zeros, zeros, -x2 * x1, -x2 * y1, -x2], dim=-1)
# Interleave the rows per point -- (ax_0, ay_0, ax_1, ay_1, ...) -- without a
# Python-level loop: stacking on a new dim 2 and merging dims 1 and 2 gives
# the same ordering as appending ax[:, i], ay[:, i] for each i.
# NOTE(review): assumes ax, ay and weights all have points1.shape[1] entries
# along dim 1, which is what the former per-index loop iterated over.
A = torch.stack([ax, ay], dim=2).flatten(1, 2)
# Every point contributes two rows, so each weight is duplicated in place
# (w_0, w_0, w_1, w_1, ...) to stay aligned with the interleaved rows of A.
w = torch.repeat_interleave(weights, 2, dim=1)
# apply weights
w_diag = torch.diag_embed(w)
# Weighted normal matrix A^T W A.
A = A.transpose(-2, -1) @ w_diag @ A
After Change
else:
// We should use provided weights
# Validate the caller-supplied weights: they must be 2-D and match the first
# two dimensions of points1. On failure the offending shape is reported as
# the assertion message.
assert len(weights.shape) == 2 and weights.shape == points1.shape[:2], weights.shape
# Tile the weights twice along dim 1 and embed them as a batch of diagonal
# matrices (presumably one weight per constraint row of A -- confirm).
# NOTE(review): repeat(1, 2) produces [w_0..w_{N-1}, w_0..w_{N-1}], i.e. the
# weights are tiled, not interleaved. That lines up with A only if A's rows
# are ordered as all first-type rows followed by all second-type rows; if the
# rows are interleaved per point, repeat_interleave(2, dim=1) would be needed.
# How A is built is not visible in this fragment -- confirm the row ordering.
w_diag = torch.diag_embed(weights.repeat(1, 2))
# Weighted normal matrix A^T W A; overwrites A.
A = A.transpose(-2, -1) @ w_diag @ A
try:
U, S, V = torch.svd(A)