d4255c9c4d04cf7f09881b272535cfdc155957a7,agent.py,Agent,learn,#Agent#Any#,51

Before Change


    u[(l < (self.atoms - 1)) * (l == u)] += 1

    // Distribute probability of Tz
    m = states.data.new(self.batch_size, self.atoms).zero_()
    offset = torch.linspace(0, ((self.batch_size - 1) * self.atoms), self.batch_size).unsqueeze(1).expand(self.batch_size, self.atoms).type_as(actions)
    m.view(-1).index_add_(0, (l + offset).view(-1), (pns_a * (u.float() - b)).view(-1))  // m_l = m_l + p(s_t+n, a*)(u - b)
    m.view(-1).index_add_(0, (u + offset).view(-1), (pns_a * (b - l.float())).view(-1))  // m_u = m_u + p(s_t+n, a*)(b - l)

After Change


    ps = self.online_net(states)  // Probabilities p(s_t, ·; θonline)
    ps_a = ps[range(self.batch_size), actions]  // p(s_t, a_t; θonline)

    with torch.no_grad():
      // Calculate nth next state probabilities
      // TODO: Add back below but prevent inplace operation?
      // self.online_net.reset_noise()  // Sample new noise for action selection
      pns = self.online_net(next_states)  // Probabilities p(s_t+n, ·; θonline)
      dns = self.support.expand_as(pns) * pns  // Distribution d_t+n = (z, p(s_t+n, ·; θonline))
      argmax_indices_ns = dns.sum(2).max(1)[1]  // Perform argmax action selection using online network: argmax_a[(z, p(s_t+n, a; θonline))]
      self.target_net.reset_noise()  // Sample new target net noise
      pns = self.target_net(next_states)  // Probabilities p(s_t+n, ·; θtarget)
      pns_a = pns[range(self.batch_size), argmax_indices_ns]  // Double-Q probabilities p(s_t+n, argmax_a[(z, p(s_t+n, a; θonline))]; θtarget)

      // Compute Tz (Bellman operator T applied to z)
      Tz = returns.unsqueeze(1) + nonterminals * (self.discount ** self.n) * self.support.unsqueeze(0)  // Tz = R^n + (γ^n)z (accounting for terminal states)
      Tz = Tz.clamp(min=self.Vmin, max=self.Vmax)  // Clamp between supported values
      // Compute L2 projection of Tz onto fixed support z
      b = (Tz - self.Vmin) / self.delta_z  // b = (Tz - Vmin) / Δz
      l, u = b.floor().long(), b.ceil().long()
      // Fix disappearing probability mass when l = b = u (b is int)
      l[(u > 0) * (l == u)] -= 1
      u[(l < (self.atoms - 1)) * (l == u)] += 1

      // Distribute probability of Tz
      m = states.new_zeros(self.batch_size, self.atoms)
      offset = torch.linspace(0, ((self.batch_size - 1) * self.atoms), self.batch_size).unsqueeze(1).expand(self.batch_size, self.atoms).to(actions)
      m.view(-1).index_add_(0, (l + offset).view(-1), (pns_a * (u.float() - b)).view(-1))  // m_l = m_l + p(s_t+n, a*)(u - b)
      m.view(-1).index_add_(0, (u + offset).view(-1), (pns_a * (b - l.float())).view(-1))  // m_u = m_u + p(s_t+n, a*)(b - l)

    ps_a = ps_a.clamp(min=1e-3)  // Clamp for numerical stability in log
    loss = -torch.sum(m * ps_a.log(), 1)  // Cross-entropy loss (minimises DKL(m||p(s_t, a_t)))
    self.online_net.zero_grad()
    (weights * loss).mean().backward()  // Importance weight losses
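[Image placeholder: the alt text "Italian Trulli" appears to be leftover template text and is not part of the code change]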
In pattern: SUPERPATTERN

Frequency: 4

Non-data size: 3

Instances


Project Name: Kaixhin/Rainbow
Commit Name: d4255c9c4d04cf7f09881b272535cfdc155957a7
Time: 2018-04-28
Author: design@kaixhin.com
File Name: agent.py
Class Name: Agent
Method Name: learn


Project Name: pytorch/audio
Commit Name: 3bd4db86630b75bbbfb6c5c0a1a85603097bf9b2
Time: 2019-01-04
Author: david@da3.net
File Name: torchaudio/transforms.py
Class Name: SPECTROGRAM
Method Name: __call__


Project Name: zhirongw/lemniscate.pytorch
Commit Name: 4441480fde64e42a9c4af205bf2ab8003511172e
Time: 2018-07-26
Author: xavibrowu@gmail.com
File Name: test.py
Class Name:
Method Name: kNN


Project Name: pytorch/pytorch
Commit Name: dfb7520c47290eb93b63cffad54ff9c9811a934b
Time: 2020-12-22
Author: zou3519@gmail.com
File Name: torch/testing/_internal/common_nn.py
Class Name: NewModuleTest
Method Name: _do_test