value = util.to_variable(value)
value = value.view(self.length_batch, self.length_variates)
ret = (self._rate.log() * value) - self._rate - (value + 1).lgamma()
if self.length_batch == 1:
ret = ret.squeeze(0)
return ret
@property
def mean(self):
After Change
value = value.view(self.length_batch, self.length_variates)
print("valuesize", value.size())
lp = (self._rate.log() * value) - self._rate - (value + 1).lgamma()
if lp.dim() == 2 and self.length_variates > 1:
lp = util.safe_torch_sum(lp, dim=1)
if self.length_batch > 1 and self.length_variates > 1:
lp = lp.unsqueeze(1)