else:
    log_prior = torch.zeros_like(log_likelihood)
    for _, prior, params_and_tfs, prior_tf in self.named_priors():
        params = [p if ptf is None else ptf(p) for p, ptf in params_and_tfs]
        lp_arg = params[0] if prior_tf is None else prior_tf(*params)
        log_prior.add_(prior.log_prob(lp_arg).sum())
    return log_likelihood, kl_divergence, log_prior.div(self.num_data)
After Change
else:
    log_prior = torch.zeros_like(log_likelihood)
    for _, prior, closure in self.named_priors():
        log_prior.add_(prior.log_prob(closure()).sum())
    return log_likelihood, kl_divergence, log_prior.div(self.num_data)
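The change moves the parameter/transform bookkeeping out of the loop: each registered prior now comes with a closure that returns the already-transformed value to score, so the caller only evaluates prior.log_prob(closure()). Below is a minimal, self-contained sketch of that pattern; ToyModule, its register_prior signature, and the lengthscale example are hypothetical illustrations, not the library's actual implementation.

    import torch
    from torch.distributions import Gamma

    class ToyModule:
        """Hypothetical module that stores each prior with a closure
        returning the (transformed) value the prior should score."""

        def __init__(self):
            self.raw_lengthscale = torch.nn.Parameter(torch.zeros(1))
            self._priors = {}

        def register_prior(self, name, prior, closure):
            # The closure captures any parameter transform, so callers no
            # longer need to track (param, transform) pairs separately.
            self._priors[name] = (prior, closure)

        def named_priors(self):
            for name, (prior, closure) in self._priors.items():
                yield name, prior, closure

    module = ToyModule()
    module.register_prior(
        "lengthscale_prior",
        Gamma(torch.tensor(2.0), torch.tensor(1.0)),
        closure=lambda: torch.nn.functional.softplus(module.raw_lengthscale),
    )

    # Same accumulation pattern as the "after" code above.
    log_prior = torch.zeros(())
    for _, prior, closure in module.named_priors():
        log_prior = log_prior + prior.log_prob(closure()).sum()
    print(log_prior)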