d4255c9c4d04cf7f09881b272535cfdc155957a7,agent.py,Agent,learn,#Agent#Any#,51

Before Change


    dns = self.support.expand_as(pns) * pns  // Distribution d_t+n = (z, p(s_t+n, ·; θonline))
    argmax_indices_ns = dns.sum(2).max(1)[1]  // Perform argmax action selection using online network: argmax_a[(z, p(s_t+n, a; θonline))]
    self.target_net.reset_noise()  // Sample new target net noise
    pns = self.target_net(next_states).data  // Probabilities p(s_t+n, ·; θtarget)
    pns_a = pns[range(self.batch_size), argmax_indices_ns]  // Double-Q probabilities p(s_t+n, argmax_a[(z, p(s_t+n, a; θonline))]; θtarget)

    // Compute Tz (Bellman operator T applied to z)
    Tz = returns.unsqueeze(1) + nonterminals * (self.discount ** self.n) * self.support.unsqueeze(0)  // Tz = R^n + (γ^n)z (accounting for terminal states)
    Tz = Tz.clamp(min=self.Vmin, max=self.Vmax)  // Clamp between supported values
    // Compute L2 projection of Tz onto fixed support z
    b = (Tz - self.Vmin) / self.delta_z  // b = (Tz - Vmin) / Δz
    l, u = b.floor().long(), b.ceil().long()
    // Fix disappearing probability mass when l = b = u (b is int)
    l[(u > 0) * (l == u)] -= 1
    u[(l < (self.atoms - 1)) * (l == u)] += 1

    // Distribute probability of Tz
    m = states.data.new(self.batch_size, self.atoms).zero_()
    offset = torch.linspace(0, ((self.batch_size - 1) * self.atoms), self.batch_size).unsqueeze(1).expand(self.batch_size, self.atoms).type_as(actions)
    m.view(-1).index_add_(0, (l + offset).view(-1), (pns_a * (u.float() - b)).view(-1))  // m_l = m_l + p(s_t+n, a*)(u - b)
    m.view(-1).index_add_(0, (u + offset).view(-1), (pns_a * (b - l.float())).view(-1))  // m_u = m_u + p(s_t+n, a*)(b - l)

    ps_a = ps_a.clamp(min=1e-3)  // Clamp for numerical stability in log
    loss = -torch.sum(Variable(m) * ps_a.log(), 1)  // Cross-entropy loss (minimises DKL(m||p(s_t, a_t)))

After Change


      dns = self.support.expand_as(pns) * pns  // Distribution d_t+n = (z, p(s_t+n, ·; θonline))
      argmax_indices_ns = dns.sum(2).max(1)[1]  // Perform argmax action selection using online network: argmax_a[(z, p(s_t+n, a; θonline))]
      self.target_net.reset_noise()  // Sample new target net noise
      pns = self.target_net(next_states)  // Probabilities p(s_t+n, ·; θtarget)
      pns_a = pns[range(self.batch_size), argmax_indices_ns]  // Double-Q probabilities p(s_t+n, argmax_a[(z, p(s_t+n, a; θonline))]; θtarget)

      // Compute Tz (Bellman operator T applied to z)
      Tz = returns.unsqueeze(1) + nonterminals * (self.discount ** self.n) * self.support.unsqueeze(0)  // Tz = R^n + (γ^n)z (accounting for terminal states)
      Tz = Tz.clamp(min=self.Vmin, max=self.Vmax)  // Clamp between supported values
      // Compute L2 projection of Tz onto fixed support z
      b = (Tz - self.Vmin) / self.delta_z  // b = (Tz - Vmin) / Δz
      l, u = b.floor().long(), b.ceil().long()
      // Fix disappearing probability mass when l = b = u (b is int)
      l[(u > 0) * (l == u)] -= 1
      u[(l < (self.atoms - 1)) * (l == u)] += 1

      // Distribute probability of Tz
      m = states.new_zeros(self.batch_size, self.atoms)
      offset = torch.linspace(0, ((self.batch_size - 1) * self.atoms), self.batch_size).unsqueeze(1).expand(self.batch_size, self.atoms).to(actions)
      m.view(-1).index_add_(0, (l + offset).view(-1), (pns_a * (u.float() - b)).view(-1))  // m_l = m_l + p(s_t+n, a*)(u - b)
      m.view(-1).index_add_(0, (u + offset).view(-1), (pns_a * (b - l.float())).view(-1))  // m_u = m_u + p(s_t+n, a*)(b - l)

    ps_a = ps_a.clamp(min=1e-3)  // Clamp for numerical stability in log
    loss = -torch.sum(m * ps_a.log(), 1)  // Cross-entropy loss (minimises DKL(m||p(s_t, a_t)))
[Image: Italian Trulli]
In pattern: SUPERPATTERN

Frequency: 3

Non-data size: 4

Instances


Project Name: Kaixhin/Rainbow
Commit Name: d4255c9c4d04cf7f09881b272535cfdc155957a7
Time: 2018-04-28
Author: design@kaixhin.com
File Name: agent.py
Class Name: Agent
Method Name: learn


Project Name: OpenNMT/OpenNMT-py
Commit Name: 26a7e209d9895000b638d2b4acacdd6c25269369
Time: 2018-02-02
Author: s.gehrmann@outlook.com
File Name: onmt/io/TextDataset.py
Class Name: TextDataset
Method Name: collapse_copy_scores


Project Name: mariogeiger/se3cnn
Commit Name: 84727d59e2d228aca53f501cf077fc70d6254afe
Time: 2019-09-25
Author: lapchevsky.k@gmail.com
File Name: se3cnn/point/operations.py
Class Name: PeriodicConvolution
Method Name: forward