a57047efae5c52ae11f4211800fd299600d52e5b,examples/trpo_gym.py,,update_params,#Any#,92

Before Change


def update_params(batch):
    states = Tensor(batch.state)
    actions = ActionTensor(batch.action)
    rewards = Tensor(batch.reward)
    masks = Tensor(batch.mask)
    values = value_net(Variable(states, volatile=True)).data

    # get advantage estimation from the trajectories

After Change


def update_params(batch):
    states = torch.from_numpy(np.stack(batch.state))
    actions = torch.from_numpy(np.stack(batch.action))
    rewards = torch.from_numpy(np.stack(batch.reward))
    masks = torch.from_numpy(np.stack(batch.mask).astype(np.float64))
    if use_gpu:
        states, actions, rewards, masks = states.cuda(), actions.cuda(), rewards.cuda(), masks.cuda()
    values = value_net(Variable(states, volatile=True)).data

    # get advantage estimation from the trajectories
    advantages, returns = estimate_advantages(rewards, masks, values, args.gamma, args.tau, use_gpu)
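
A minimal NumPy sketch of the two steps the After Change version relies on. First, it stacks each per-step field of the batch into one contiguous array and casts the 0/1 masks to float64 so they match the float64 rewards. Second, it computes advantages from rewards, masks and values. It rests on several assumptions. `Transition`, `stack_batch` and `estimate_advantages_gae` are hypothetical names. This record does not show the body of `estimate_advantages`, so treating it as generalized advantage estimation (GAE), with `tau` playing the role of λ, is also an assumption. Finally, `torch.from_numpy` wraps arrays like these without copying and keeps their float64 dtype, which is why the mask cast matters.

```python
from collections import namedtuple

import numpy as np

# Hypothetical transition layout; assumed to mirror the batch fields used in update_params.
Transition = namedtuple("Transition", ["state", "action", "mask", "reward"])


def stack_batch(batch):
    """Stack per-step fields into arrays, mirroring the np.stack calls above."""
    states = np.stack(batch.state)                   # (T, obs_dim)
    actions = np.stack(batch.action)                 # (T, act_dim)
    rewards = np.stack(batch.reward)                 # (T,), float64 from Python floats
    masks = np.stack(batch.mask).astype(np.float64)  # 0/1 ints -> float64, like rewards
    return states, actions, rewards, masks


def estimate_advantages_gae(rewards, masks, values, gamma, tau):
    """Assumed GAE over a flat batch; masks[t] == 0 marks an episode's last step.

    A zero mask stops both the bootstrap from the next value and the running
    advantage, so episodes in the same batch do not leak into each other.
    The final step in the batch bootstraps from 0 (a simplification).
    """
    advantages = np.zeros(len(rewards))
    prev_value = 0.0
    prev_advantage = 0.0
    for t in reversed(range(len(rewards))):
        delta = rewards[t] + gamma * prev_value * masks[t] - values[t]
        advantages[t] = delta + gamma * tau * prev_advantage * masks[t]
        prev_value = values[t]
        prev_advantage = advantages[t]
    returns = values + advantages  # targets for the value network
    return advantages, returns
```

For example, a three-step episode with rewards of 1.0, masks `(1, 1, 0)`, zero values, `gamma=0.99` and `tau=0.95` gives an advantage of 1.0 on the last step and 1 + 0.9405 × 1.0 = 1.9405 on the step before it.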
In pattern: SUPERPATTERN

Frequency: 3

Non-data size: 18

Instances


Project Name: lcswillems/torch-rl
Commit Name: a57047efae5c52ae11f4211800fd299600d52e5b
Time: 2017-11-28
Author: khrylx@gmail.com
File Name: examples/trpo_gym.py
Class Name:
Method Name: update_params


Project Name: lcswillems/torch-rl
Commit Name: d13e413bd79dfe516dad31fe0730d183c1239293
Time: 2017-11-28
Author: khrylx@gmail.com
File Name: examples/ppo_gym.py
Class Name:
Method Name: update_params


Project Name: lcswillems/torch-rl
Commit Name: d13e413bd79dfe516dad31fe0730d183c1239293
Time: 2017-11-28
Author: khrylx@gmail.com
File Name: examples/a2c_gym.py
Class Name:
Method Name: update_params