x = x.squeeze_(dim=1).squeeze_(dim=-1)
if self.num_atoms == 1 and self.out_features == 1:
// Re-add the trailing dimension so a scalar critic output has shape (B, 1) instead of (B,)
x = x.unsqueeze_(dim=1)
return x
class PolicyHead(nn.Module):
After Change
x = [z.squeeze_(dim=1).squeeze_(dim=-1) for z in x]
if self.num_atoms == 1 and self.out_features == 1:
// Re-add the trailing dimension so each scalar critic output has shape (B, 1) instead of (B,)
x = [z.unsqueeze_(dim=1)for z in x]
// B x num_heads x num_outputs x num_atoms (discrete)
// B x num_heads x num_atoms (continuous)