b9b54d4f30ff65cf1c54dc0cf90c938b48c44f90,reagent/models/dqn.py,DQNBase,forward,#DQNBase#Any#,41

Before Change



    def forward(self, input) -> rlt.AllActionQValues:
        """Return per-action Q-values for ``input``.

        The raw output of ``self.dist(input)`` is used directly when
        ``self.quantiles`` is not greater than 1. Otherwise the raw values
        are regrouped as ``(-1, self.action_dim, self.quantiles)`` and
        averaged over the last (quantile) axis, giving one value per action.
        """
        raw_values = self.dist(input).q_values
        if self.quantiles > 1:
            # NOTE(review): presumably the quantile estimates of each action
            # are laid out contiguously in raw_values — confirm against dist().
            grouped = raw_values.reshape((-1, self.action_dim, self.quantiles))
            return rlt.AllActionQValues(q_values=grouped.mean(dim=2))
        return rlt.AllActionQValues(q_values=raw_values)

    def dist(self, input: rlt.PreprocessedState) -> rlt.AllActionQValues:
        if self.feature_extractor is not None:

After Change


    def input_prototype(self):
        """Build an example input containing one random row of ``state_dim`` features."""
        sample = torch.randn(1, self.state_dim)
        return rlt.PreprocessedState.from_tensor(sample)

    def forward(self, input: rlt.PreprocessedState):
        """Score every action by passing the state's float features through ``self.fc``."""
        features = input.state.float_features
        return rlt.AllActionQValues(q_values=self.fc(features))


class _DistributedDataParallelFullyConnectedDQN(ModelBase):
Italian Trulli
In pattern: SUPERPATTERN

Frequency: 3

Non-data size: 9

Instances


Project Name: facebookresearch/Horizon
Commit Name: b9b54d4f30ff65cf1c54dc0cf90c938b48c44f90
Time: 2020-04-28
Author: kittipat@fb.com
File Name: reagent/models/dqn.py
Class Name: DQNBase
Method Name: forward


Project Name: dmlc/dgl
Commit Name: 650f6ee1e0b3c2888a2c6d7db9c3d159cae5a583
Time: 2019-08-27
Author: expye@outlook.com
File Name: examples/pytorch/gin/gin.py
Class Name: GINLayer
Method Name: forward


Project Name: chainer/chainerrl
Commit Name: c5f155ad65520d229470377acb7a0014600a9388
Time: 2017-05-28
Author: kataoka@preferred.jp
File Name: chainerrl/policies/deterministic_policy.py
Class Name: ContinuousDeterministicPolicy
Method Name: __call__