1e9c3ee592be5e11dcce932a73009488d6f85474,ch17/02_imag.py,,,#,78

Before Change


                tb_tracker.track("total_reward", m_reward, step_idx)
                tb_tracker.track("total_steps", m_steps, step_idx)

            obs_v = Variable(torch.from_numpy(mb_obs))
            obs_next_v = Variable(torch.from_numpy(mb_obs_next))
            actions_t = torch.LongTensor(mb_actions.tolist())
            rewards_v = Variable(torch.from_numpy(mb_rewards))
            if args.cuda:
                obs_v = obs_v.cuda()
                actions_t = actions_t.cuda()
                obs_next_v = obs_next_v.cuda()
                rewards_v = rewards_v.cuda()

            optimizer.zero_grad()
            out_obs_next_v, out_reward_v = net_em(obs_v.float()/255, actions_t)
            loss_obs_v = F.mse_loss(out_obs_next_v, obs_next_v)
            loss_rew_v = F.mse_loss(out_reward_v, rewards_v)

After Change



            optimizer.zero_grad()
            out_obs_next_v, out_reward_v = net_em(obs_v.float()/255, actions_t)
            loss_obs_v = F.mse_loss(out_obs_next_v.squeeze(-1), obs_next_v)
            loss_rew_v = F.mse_loss(out_reward_v.squeeze(-1), rewards_v)
            loss_total_v = OBS_WEIGHT * loss_obs_v + REWARD_WEIGHT * loss_rew_v
            loss_total_v.backward()
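
Note on the change: the added `.squeeze(-1)` calls drop a trailing size-1 dimension before the loss is computed. The sketch below is not taken from the repository. It uses hypothetical tensors (`pred`, `target`) to show why that can matter, assuming a network head that returns shape (N, 1) while the targets have shape (N,). In recent PyTorch releases `F.mse_loss` warns about the mismatch and broadcasts both tensors to (N, N), so the loss averages over every prediction/target pair rather than matching pairs; behavior in the PyTorch version used at the time of the commit may have differed.

```python
import warnings

import torch
import torch.nn.functional as F

# Hypothetical stand-ins for a reward head's output and its targets.
pred = torch.tensor([[1.0], [2.0], [3.0]])  # shape (3, 1)
target = torch.tensor([1.0, 2.0, 3.0])      # shape (3,)

# Without squeeze: (3, 1) against (3,) broadcasts to (3, 3), so every
# prediction is compared with every target. PyTorch emits a UserWarning
# about the size mismatch, silenced here only to keep the output clean.
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    loss_broadcast = F.mse_loss(pred, target)

# With squeeze(-1): shapes are both (3,), so the comparison is elementwise.
loss_squeezed = F.mse_loss(pred.squeeze(-1), target)

print("broadcast loss:", loss_broadcast.item())
print("squeezed loss: ", loss_squeezed.item())
```

The predictions here equal the targets exactly, so the squeezed loss is zero, while the broadcast version reports a nonzero loss coming purely from the mismatched pairs.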
In pattern: SUPERPATTERN

Frequency: 3

Non-data size: 14

Instances


Project Name: PacktPublishing/Deep-Reinforcement-Learning-Hands-On
Commit Name: 1e9c3ee592be5e11dcce932a73009488d6f85474
Time: 2018-04-29
Author: max.lapan@gmail.com
File Name: ch17/02_imag.py
Class Name:
Method Name:


Project Name: PacktPublishing/Deep-Reinforcement-Learning-Hands-On
Commit Name: 171e9e18a10f2daea090bc6f4815db41072d66b6
Time: 2018-04-27
Author: max.lapan@gmail.com
File Name: ch10/03_pong_a2c_rollouts.py
Class Name:
Method Name:


Project Name: PacktPublishing/Deep-Reinforcement-Learning-Hands-On
Commit Name: 1e9c3ee592be5e11dcce932a73009488d6f85474
Time: 2018-04-29
Author: max.lapan@gmail.com
File Name: ch17/lib/common.py
Class Name:
Method Name: train_a2c