d15a15df582f43d9bc65a434db0d64c9e276b724,python/magent/builtin/mx_model/dqn.py,DeepQNetwork,infer_action,#DeepQNetwork#Any#Any#Any#Any#,136

Before Change



        n = len(view)
        self._reset_bind_size(n)
        with self.ctx:
            data_batch = mx.io.DataBatch(data=[mx.nd.array(view), mx.nd.array(feature)])
            self.model.forward(data_batch, is_train=False)
            qvalue_batch = self.model.get_outputs()[0]
            best_actions = mx.nd.argmax(qvalue_batch, axis=1).asnumpy()

        random = np.random.randint(self.num_actions, size=(n,))
        cond = np.random.uniform(0, 1, size=(n,)) < eps
        ret = np.where(cond, random, best_actions)
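The last three lines of the before-change snippet perform vectorized epsilon-greedy selection: every agent independently takes a uniformly random action with probability `eps`, and otherwise takes the argmax of its Q-values. Below is a minimal NumPy-only sketch of that step, assuming the Q-values are already available as an `[n, num_actions]` array. The function name `epsilon_greedy` and the use of `np.random.Generator` are illustrative choices, not part of the MAgent code, which calls the legacy `np.random.randint` / `np.random.uniform` functions.

```python
import numpy as np


def epsilon_greedy(q_values: np.ndarray, eps: float, rng: np.random.Generator) -> np.ndarray:
    """Hypothetical helper: pick one action per row of q_values (shape [n, num_actions])."""
    n, num_actions = q_values.shape
    greedy = np.argmax(q_values, axis=1)           # best action per agent
    random = rng.integers(num_actions, size=n)     # uniform random action per agent
    explore = rng.uniform(0.0, 1.0, size=n) < eps  # True with probability eps
    # Per agent: random action where explore is True, greedy action otherwise.
    return np.where(explore, random, greedy)


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    q = np.array([[0.1, 0.9, 0.0], [0.5, 0.2, 0.3]])
    print(epsilon_greedy(q, eps=0.0, rng=rng))  # eps=0 is purely greedy: [1 0]
```

Drawing both the random actions and the coin flips for all `n` agents at once keeps the step free of Python-level loops, which matters when many agents act per environment step.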

After Change


            eps = 0

        n = len(view)
        if n < self.num_gpu:
            view = np.tile(view, (self.num_gpu, 1, 1, 1))
            feature = np.tile(feature, (self.num_gpu, 1))

        batch_size = min(len(view), self.infer_batch_size)
        self._reset_bind_size(batch_size)
        best_actions = []
        infer_iter = mx.io.NDArrayIter(data=[view, feature], batch_size=batch_size)
        for batch in infer_iter:
            self.model.forward(batch, is_train=False)
            qvalue_batch = self.model.get_outputs()[0]
            batch_action = mx.nd.argmax(qvalue_batch, axis=1)
            best_actions.append(batch_action)
        best_actions = np.array([x.asnumpy() for x in best_actions]).flatten()
        best_actions = best_actions[:n]

        random = np.random.randint(self.num_actions, size=(n,))
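The after-change version swaps the single forward pass for chunked inference. If there are fewer inputs than GPUs, it tiles them so that every device receives at least one sample. It then caps the bound batch size at `self.infer_batch_size`, runs one forward pass per batch, and truncates the collected actions back to `n`. The truncation is needed because `mx.io.NDArrayIter` pads the final batch by default (`last_batch_handle='pad'`) and because tiling adds replicated rows. The following is a hedged, NumPy-only sketch of that control flow. It assumes a hypothetical `q_fn(view, feature)` standing in for `self.model.forward` plus `get_outputs()[0]`, and a `num_devices` parameter standing in for `self.num_gpu`. Padding by wrapping indices around to the start of the data is modeled on NDArrayIter's default behavior; treat it as an assumption.

```python
from typing import Callable

import numpy as np


def batched_greedy_actions(
    view: np.ndarray,
    feature: np.ndarray,
    q_fn: Callable[[np.ndarray, np.ndarray], np.ndarray],  # hypothetical model stand-in
    infer_batch_size: int,
    num_devices: int = 1,
) -> np.ndarray:
    """Greedy actions for n inputs, computed in fixed-size (padded) batches."""
    n = len(view)
    # Replicate tiny inputs so every device gets at least one sample.
    if n < num_devices:
        view = np.tile(view, (num_devices,) + (1,) * (view.ndim - 1))
        feature = np.tile(feature, (num_devices,) + (1,) * (feature.ndim - 1))
    total = len(view)
    batch_size = min(total, infer_batch_size)
    actions = []
    for start in range(0, total, batch_size):
        # Wrap indices around so the last batch is padded to full size.
        idx = np.arange(start, start + batch_size) % total
        q = q_fn(view[idx], feature[idx])
        actions.append(np.argmax(q, axis=1))
    # Drop padded and replicated rows: exactly one action per original input.
    return np.concatenate(actions)[:n]


if __name__ == "__main__":
    def toy_q_fn(view: np.ndarray, feature: np.ndarray) -> np.ndarray:
        # Hypothetical model: Q-values are simply the feature vector.
        return feature

    view = np.zeros((5, 2, 2, 1))
    feature = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1], [0, 2, 1], [3, 0, 1]], dtype=float)
    print(batched_greedy_actions(view, feature, toy_q_fn, infer_batch_size=2))  # [0 1 2 1 0]
```

Keeping every batch the same size means the model only has to be bound once (the role of `self._reset_bind_size(batch_size)` in the commit). The cost is a few wasted rows in the last batch, which the final slice discards.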
In pattern: SUPERPATTERN

Frequency: 3

Non-data size: 7

Instances


Project Name: geek-ai/MAgent
Commit Name: d15a15df582f43d9bc65a434db0d64c9e276b724
Time: 2017-11-23
Author: zhenglianmin96@163.com
File Name: python/magent/builtin/mx_model/dqn.py
Class Name: DeepQNetwork
Method Name: infer_action


Project Name: facebookresearch/Horizon
Commit Name: 45499cc575d9e555fd17605104cc2452698ecbad
Time: 2018-10-30
Author: jjg@fb.com
File Name: ml/rl/test/gym/gym_predictor.py
Class Name: GymDQNPredictor
Method Name: policy


Project Name: facebookresearch/Horizon
Commit Name: 45499cc575d9e555fd17605104cc2452698ecbad
Time: 2018-10-30
Author: jjg@fb.com
File Name: ml/rl/test/gym/gym_predictor.py
Class Name: GymDQNPredictor
Method Name: predict