rhs = rhs.unsqueeze(0)
rhs_size = list(rhs.size())
rhs_size[0] = self.batch_size()
rhs = rhs.expand(*rhs_size)
res = self.base_lazy_tensor._matmul(rhs)
if self.sum_batch_size is not None:
res = res.view(-1, self.sum_batch_size, res.size(1), res.size(2))
res = res.sum(1)
else:
res = res.sum(0)
if isvector:
res = res.squeeze(-1)
return res
After Change
rhs = rhs.unsqueeze(0)
rhs = rhs.expand(self.base_lazy_tensor.size(0), *tuple(rhs.size())[1:])
else:
rhs = rhs.unsqueeze(1)
rhs = rhs.expand(rhs.size(0), self.num_blocks, *tuple(rhs.size())[2:])
rhs = rhs.contiguous().view(-1, rhs.size(-2), rhs.size(-1))
res = self.base_lazy_tensor._matmul(rhs)
if self.num_blocks is not None:
res = res.view(-1, self.num_blocks, res.size(1), res.size(2))
res = res.sum(-3)
if isvector:
res = res.squeeze(-1)