4bd09281f05185e5625112941b7ed11cf4e6cad5,catalyst/rl/agent/network.py,StateActionNet,_forward_lama,#StateActionNet#Any#Any#,127

Before Change



    def _forward_lama(self, state, action):
        state_ = state
        if len(state_.shape) < 3:
            state_ = state_.unsqueeze(1)

        if isinstance(self.observation_net, nn.Module):
            batch_size, history_len, feature_size = state_.shape
            state_ = state_.view(-1, feature_size)
            state_ = self.observation_net(state_)
            state_ = state_.view(batch_size, history_len, -1)

        state_ = self.aggregation_net(state_)

        // @TODO: add option to collapse observations based on action
        action_ = action.view(action.shape[0], -1)

After Change


        x = self.main_net(x)
        return x

    def _forward_lama(self, state, action):
        state_ = self._process_state(state, self.observation_net)
        state_ = self.aggregation_net(state_)
        action_ = self.action_net(action)
        x = torch.cat((state_, action_), dim=1)
        x = self.main_net(x)
In pattern: SUPERPATTERN

Frequency: 4

Non-data size: 28

Instances


Project Name: Scitator/catalyst
Commit Name: 4bd09281f05185e5625112941b7ed11cf4e6cad5
Time: 2019-09-05
Author: scitator@gmail.com
File Name: catalyst/rl/agent/network.py
Class Name: StateActionNet
Method Name: _forward_lama


Project Name: catalyst-team/catalyst
Commit Name: 4bd09281f05185e5625112941b7ed11cf4e6cad5
Time: 2019-09-05
Author: scitator@gmail.com
File Name: catalyst/rl/agent/network.py
Class Name: StateNet
Method Name: _forward_lama


Project Name: Scitator/catalyst
Commit Name: 4bd09281f05185e5625112941b7ed11cf4e6cad5
Time: 2019-09-05
Author: scitator@gmail.com
File Name: catalyst/rl/agent/network.py
Class Name: StateActionNet
Method Name: _forward_lama