    res = torch.sum(torch.stack(res_list), dim=0)
else:
    while rhs.ndimension() < self.ndimension():
        rhs = rhs.unsqueeze(0)
    curr_idx = 0
    res_list = []
    index = [slice(None, None, None) for _ in range(self.ndimension())]
    for t, size, rhs in zip(self.lazy_tensors, self.cat_dim_sizes, rhs_):
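The `while`/`unsqueeze` loop above only pads `rhs` with leading singleton dimensions until it matches `self.ndimension()`; it never expands those dimensions to the operator's actual batch shape, which is what the rewrite below addresses. A minimal sketch of that behavior (shapes are illustrative, not from the source):

import torch

rhs = torch.randn(4, 3)   # stands in for an (n x k) right-hand side
ndim = 4                  # stands in for self.ndimension() with two batch dims
while rhs.ndimension() < ndim:
    rhs = rhs.unsqueeze(0)
print(rhs.shape)          # torch.Size([1, 1, 4, 3]) -- singleton, not broadcast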
After Change
    res = torch.sum(torch.stack(res_list), dim=0)
else:
    output_shape = _matmul_broadcast_shape(self.shape, rhs.shape)
    rhs = rhs.expand(*output_shape[:-2], *rhs.shape[-2:])
    curr_idx = 0
    res_list = []
    for t, size in zip(self.lazy_tensors, self.cat_dim_sizes):
        sub_rhs = rhs.narrow(self.cat_dim, curr_idx, size)
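The rewrite computes the broadcast output shape up front, expands `rhs` to the full batch shape, and then slices a per-block piece with `narrow` along the concatenation dimension. A minimal self-contained sketch of the same pattern, with `matmul_broadcast_shape` as a hypothetical stand-in for gpytorch's `_matmul_broadcast_shape` and all shapes assumed for illustration:

import torch

def matmul_broadcast_shape(lhs_shape, rhs_shape):
    # Hypothetical stand-in for _matmul_broadcast_shape: broadcast the batch
    # dims, then append (rows of lhs) x (cols of rhs).
    batch = torch.broadcast_shapes(lhs_shape[:-2], rhs_shape[:-2])
    return torch.Size([*batch, lhs_shape[-2], rhs_shape[-1]])

lhs_shape = torch.Size([2, 6, 5])     # assumed shape of the concatenated operator
rhs = torch.randn(5, 3)               # plain 2D right-hand side

output_shape = matmul_broadcast_shape(lhs_shape, rhs.shape)   # (2, 6, 3)
rhs = rhs.expand(*output_shape[:-2], *rhs.shape[-2:])         # (2, 5, 3)

# Slice a per-block piece along a concatenation dim (here -2), as the loop does.
curr_idx, size = 0, 2
sub_rhs = rhs.narrow(-2, curr_idx, size)                      # (2, 2, 3)

Because `expand` materializes the broadcast batch extent before the loop, every block sees the same batch dimensions, so `narrow` along `self.cat_dim` stays valid even when the concatenation dimension is a batch dimension.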