a_pa = [self._pa[i](features) for i in range(self._n_output)]  # one advantage head per output
a_pa = torch.stack(a_pa, dim=0)  # (n_output, batch, out_dim)
mean_a_pa = a_pa.mean(0)  # mean advantage across the heads
# dueling-style aggregation, one softmax call per head
softmax = [F.softmax(a_pv + a_pa[i] - mean_a_pa, -1) for i in range(self._n_output)]
softmax = torch.stack(softmax, dim=1)  # (batch, n_output, out_dim)
if not get_distribution:
    q = torch.empty(softmax.shape[:-1])
    for i in range(softmax.shape[0]):
After Change
a_pv = self._pv(features)  # shared stream
a_pa = [self._pa[i](features) for i in range(self._n_output)]  # one advantage head per output
a_pa = torch.stack(a_pa, dim=1)  # (batch, n_output, out_dim)
a_pv = a_pv.unsqueeze(1).repeat(1, self._n_output, 1)  # broadcast the shared stream across the heads
mean_a_pa = a_pa.mean(1, keepdim=True).repeat(1, self._n_output, 1)  # mean advantage across the heads
softmax = F.softmax(a_pv + a_pa - mean_a_pa, dim=-1)  # single batched softmax over all heads
if not get_distribution:
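
Both versions compute the same (batch, n_output, out_dim) tensor; the change only replaces the per-head Python loop with one batched softmax. The standalone sketch below checks that equivalence. It is not taken from the original code: the stand-in linear layers pv/pa and the sizes batch, in_dim, n_output and out_dim are illustrative assumptions.

import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
batch, in_dim, n_output, out_dim = 4, 16, 3, 51  # assumed toy sizes

pv = nn.Linear(in_dim, out_dim)  # stand-in for self._pv
pa = nn.ModuleList([nn.Linear(in_dim, out_dim) for _ in range(n_output)])  # stand-in for self._pa
features = torch.randn(batch, in_dim)

# Old aggregation: one softmax per head, stacked afterwards.
a_pv = pv(features)
a_pa_old = torch.stack([pa[i](features) for i in range(n_output)], dim=0)
mean_old = a_pa_old.mean(0)
old = torch.stack([F.softmax(a_pv + a_pa_old[i] - mean_old, -1)
                   for i in range(n_output)], dim=1)

# New aggregation: a single batched softmax over all heads at once.
a_pa_new = torch.stack([pa[i](features) for i in range(n_output)], dim=1)
a_pv_rep = a_pv.unsqueeze(1).repeat(1, n_output, 1)
mean_new = a_pa_new.mean(1, keepdim=True).repeat(1, n_output, 1)
new = F.softmax(a_pv_rep + a_pa_new - mean_new, dim=-1)

assert torch.allclose(old, new, atol=1e-6)
print(new.shape)  # torch.Size([4, 3, 51])

Besides dropping the Python-level loop, the batched form runs one softmax over all heads instead of n_output separate calls.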