def __init__(self, base_lazy_tensor, constant):
if torch.is_tensor(constant):
if constant.ndimension() > 1:
raise RuntimeError(
"Got a constant with %d dimensions - expected a 0D or 1D tensor" % constant.ndimension()
)
elif constant.numel() > 1:
if not (base_lazy_tensor.ndimension() == 3 and base_lazy_tensor.size(0) == constant.numel()):
numel = constant.numel()
raise RuntimeError(
"A constant with size %d expedts a 3D lazy var. with batch size %d. "
"Got a %dD lazy var. with size %s"
% (numel, numel, base_lazy_tensor.ndimension(), repr(base_lazy_tensor.size()))
After Change
// Make sure that the constant can be expanded to the appropriate size
try:
constant = orig_constant.expand(base_lazy_tensor.batch_shape)
except RuntimeError:
raise RuntimeError(
"ConstantMulLazyTensor of size {} received an invalid constant of size {}.".format(
base_lazy_tensor.shape, orig_constant.shape