# Collect per-column slices so they can be combined later:
#   axy_list -> ax[:, 0], ay[:, 0], ax[:, 1], ay[:, 1], ... (interleaved)
#   w_list   -> weights[:, 0], weights[:, 1], ...
# NOTE(review): assumes ax, ay and weights each have at least
# points1.shape[1] columns along dim 1 -- confirm against the caller.
w_list, axy_list = [], []
num_columns = points1.shape[1]
for i in range(num_columns):
    # Keep the ax/ay pair for column i adjacent in the output list.
    axy_list.extend((ax[:, i], ay[:, i]))
    w_list.append(weights[:, i])
After Change
else:
# Use the caller-provided weights.
assert len(weights.shape) == 2 and weights.shape == points1.shape[:2], weights.shape
w_diag = torch.diag_embed(weights.repeat(1, 2))
A = A.transpose(-2, -1) @ w_diag @ A
try:
U, S, V = torch.svd(A)