9ddaacb9e73ac061c08403e16ac1a4d5364dca9b,mushroom_rl/algorithms/value/dqn/rainbow.py,RainbowNetwork,forward,#RainbowNetwork#Any#Any#Any#,37

Before Change


        a_pv = self._pv(features)
        a_pa = [self._pa[i](features) for i in range(self._n_output)]
        a_pa = torch.stack(a_pa, dim=0)
        mean_a_pa = a_pa.mean(0)
        softmax = [F.softmax(a_pv + a_pa[i] - mean_a_pa, -1) for i in range(self._n_output)]
        softmax = torch.stack(softmax, dim=1)

        if not get_distribution:
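The following minimal sketch makes the tensor layout of the pre-change computation explicit. It assumes that `_pv` is a linear value head and that each `_pa[i]` is a linear advantage head emitting one logit per support atom. The layer types, the sizes and the `n_atoms` name are illustrative assumptions, not taken from mushroom_rl.

The sketch walks through the steps in order:

1. Advantages are stacked along a leading action axis.
2. They are averaged over that axis.
3. Each action's distribution is normalized by its own softmax call.
4. The results are stacked into a `(batch, n_output, n_atoms)` tensor.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Hypothetical sizes, chosen only for illustration.
batch, n_features, n_output, n_atoms = 4, 8, 3, 51

features = torch.randn(batch, n_features)
# Assumed heads: one value stream and one advantage stream per action.
pv = nn.Linear(n_features, n_atoms)
pa = nn.ModuleList([nn.Linear(n_features, n_atoms) for _ in range(n_output)])

a_pv = pv(features)                                                    # (batch, n_atoms)
a_pa = torch.stack([pa[i](features) for i in range(n_output)], dim=0)  # (n_output, batch, n_atoms)
mean_a_pa = a_pa.mean(0)                                               # (batch, n_atoms): mean over actions

# One softmax per action over the atom axis, then stack along a new action axis.
softmax = torch.stack(
    [F.softmax(a_pv + a_pa[i] - mean_a_pa, -1) for i in range(n_output)],
    dim=1,
)                                                                      # (batch, n_output, n_atoms)
```

Each `softmax[b, i]` is then a probability distribution over the atoms for action `i` of sample `b`.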

After Change


        a_pv = self._pv(features)
        a_pa = [self._pa[i](features) for i in range(self._n_output)]
        a_pa = torch.stack(a_pa, dim=1)
        a_pv = a_pv.unsqueeze(1).repeat(1, self._n_output, 1)
        mean_a_pa = a_pa.mean(1, keepdim=True).repeat(1, self._n_output, 1)
        softmax = F.softmax(a_pv + a_pa - mean_a_pa, dim=-1)

        if not get_distribution:
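The after-change version replaces the per-action loop with a single batched softmax. The sketch below checks that both formulations produce the same distributions. Random tensors stand in for the outputs of `self._pv` and `self._pa[i]`, and the sizes are hypothetical. The sketch also includes `after_broadcast`, a variant that relies on broadcasting instead of `.repeat`. That variant is an illustration, not part of the commit.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch, n_output, n_atoms = 4, 3, 51  # hypothetical sizes

a_pv = torch.randn(batch, n_atoms)                              # stand-in for self._pv(features)
heads = [torch.randn(batch, n_atoms) for _ in range(n_output)]  # stand-ins for self._pa[i](features)


def before(a_pv, heads):
    # Loop version: stack on dim 0, one softmax per action, stack on dim 1.
    a_pa = torch.stack(heads, dim=0)                            # (n_output, batch, n_atoms)
    mean_a_pa = a_pa.mean(0)
    out = [F.softmax(a_pv + a_pa[i] - mean_a_pa, -1) for i in range(len(heads))]
    return torch.stack(out, dim=1)                              # (batch, n_output, n_atoms)


def after(a_pv, heads):
    # Vectorized version as in the commit: explicit repeat to (batch, n_output, n_atoms).
    n = len(heads)
    a_pa = torch.stack(heads, dim=1)                            # (batch, n_output, n_atoms)
    a_pv = a_pv.unsqueeze(1).repeat(1, n, 1)
    mean_a_pa = a_pa.mean(1, keepdim=True).repeat(1, n, 1)
    return F.softmax(a_pv + a_pa - mean_a_pa, dim=-1)


def after_broadcast(a_pv, heads):
    # Illustrative variant: size-1 dims broadcast, so no repeated copies are built.
    a_pa = torch.stack(heads, dim=1)
    return F.softmax(a_pv.unsqueeze(1) + a_pa - a_pa.mean(1, keepdim=True), dim=-1)


b = before(a_pv, heads)
a = after(a_pv, heads)
c = after_broadcast(a_pv, heads)
print(torch.allclose(b, a), torch.allclose(a, c))
```

With `keepdim=True` the mean already has shape `(batch, 1, n_atoms)`. Because of that, broadcasting gives the same result without materializing the repeated tensors.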
In pattern: SUPERPATTERN

Frequency: 3

Non-data size: 2

Instances


Project Name: AIRLab-POLIMI/mushroom
Commit Name: 9ddaacb9e73ac061c08403e16ac1a4d5364dca9b
Time: 2021-02-10
Author: carlo.deramo@gmail.com
File Name: mushroom_rl/algorithms/value/dqn/rainbow.py
Class Name: RainbowNetwork
Method Name: forward


Project Name: Kaixhin/Rainbow
Commit Name: feab638ee8fc9879a61f7c1f056a4ee66be8518a
Time: 2017-11-05
Author: design@kaixhin.com
File Name: model.py
Class Name: DQN
Method Name: forward


Project Name: DagnyT/hardnet
Commit Name: 24a5450e4ca094ae4edbda26f8f29ae012721779
Time: 2017-07-25
Author: ducha.aiki@gmail.com
File Name: Losses.py
Class Name:
Method Name: loss_HardNet