@torch.no_grad()
def _get_q_values(self, critic: CriticSpec, state: np.ndarray, device):
states = torch.Tensor(state).to(device).unsqueeze(0)
output = critic(states)
// We use the last head to perform actions
// This is the head corresponding to the largest gamma
if self.value_distribution == "categorical":
After Change
@torch.no_grad()
def _get_q_values(self, critic: CriticSpec, state: np.ndarray, device):
states = _state2device(state, device)
output = critic(states)
// We use the last head to perform actions
// This is the head corresponding to the largest gamma
if self.value_distribution == "categorical":