class SumLazyTensor(LazyTensor):
    """Lazy tensor composed of several component lazy tensors.

    The components are kept, after broadcasting to a common shape, in the
    ``lazy_tensors`` attribute.  The class name suggests the components are
    meant to be summed; the arithmetic itself is not part of the constructor.
    """

    def __init__(self, *lazy_tensors, **kwargs):
        """Wrap, broadcast and store the component tensors.

        Args:
            *lazy_tensors: The components. Each one is passed through
                ``lazify`` before use.
            **kwargs: Forwarded unchanged to the parent constructor.

        Raises:
            TypeError: If ``lazify`` raises ``TypeError`` for any component.
        """
        # Normalise every argument through `lazify`; a TypeError raised there
        # is re-raised with a message that names this class.
        try:
            components = [lazify(item) for item in lazy_tensors]
        except TypeError:
            raise TypeError("All arguments of a SumLazyTensor should be LazyTensors or Tensors")

        # Work out one shape that all components' `.shape` values broadcast
        # to, then expand each component to that shape.
        broadcast_shape = _mul_broadcast_shape(*(component.shape for component in components))
        expanded = tuple(component.expand(broadcast_shape) for component in components)

        # The parent constructor receives the expanded components, and the
        # same tuple is exposed on the instance.
        super(SumLazyTensor, self).__init__(*expanded, **kwargs)
        self.lazy_tensors = expanded