index = [slice(None, None, None) for _ in range(self.ndimension())]
for t, size, rhs in zip(self.lazy_tensors, self.cat_dim_sizes, rhs_):
index[self.cat_dim] = slice(curr_idx, curr_idx + size, None)
res_list.append(t._matmul(rhs[index]))
curr_idx += size
# copy result back to output device
res_list = [x.to(output_device) for x in res_list]
# After change: results are moved back to the output device before summation
res_list = [x.to(output_device) for x in res_list]
res = torch.sum(torch.stack(res_list), dim=0)
else:
output_shape = _matmul_broadcast_shape(self.shape, rhs.shape)
rhs = rhs.expand(*output_shape[:-2], *rhs.shape[-2:])
curr_idx = 0
res_list = []
for t, size in zip(self.lazy_tensors, self.cat_dim_sizes):