d4255c9c4d04cf7f09881b272535cfdc155957a7,agent.py,Agent,learn,#Agent#Any#,51
Before Change
# Calculate nth next state probabilities
self.online_net.reset_noise()  # Sample new noise for action selection
pns = self.online_net(next_states).data  # Probabilities p(s_t+n, ·; θonline)
dns = self.support.expand_as(pns) * pns  # Distribution d_t+n = (z, p(s_t+n, ·; θonline))
argmax_indices_ns = dns.sum(2).max(1)[1]  # Perform argmax action selection using online network: argmax_a[(z, p(s_t+n, a; θonline))]
self.target_net.reset_noise()  # Sample new target net noise
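For context, the argmax line above recovers expected action values from the return distribution before picking the greedy action. A minimal standalone sketch of that step (the function name greedy_actions and the shapes are illustrative assumptions, not part of the repository):

import torch

def greedy_actions(probs, support):
  # probs: (batch, actions, atoms) action-value distribution; support: (atoms,) fixed atoms z
  q_values = (support.expand_as(probs) * probs).sum(2)  # Q(s, a) = Σ_i z_i · p_i(s, a)
  return q_values.max(1)[1]  # Greedy (argmax) action per batch element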
After Change
ps = self.online_net(states)  # Probabilities p(s_t, ·; θonline)
ps_a = ps[range(self.batch_size), actions]  # p(s_t, a_t; θonline)

with torch.no_grad():
  # Calculate nth next state probabilities
  # TODO: Add back below but prevent inplace operation?
  # self.online_net.reset_noise()  # Sample new noise for action selection
  pns = self.online_net(next_states)  # Probabilities p(s_t+n, ·; θonline)
  dns = self.support.expand_as(pns) * pns  # Distribution d_t+n = (z, p(s_t+n, ·; θonline))
  argmax_indices_ns = dns.sum(2).max(1)[1]  # Perform argmax action selection using online network: argmax_a[(z, p(s_t+n, a; θonline))]
  self.target_net.reset_noise()  # Sample new target net noise
  pns = self.target_net(next_states)  # Probabilities p(s_t+n, ·; θtarget)
  pns_a = pns[range(self.batch_size), argmax_indices_ns]  # Double-Q probabilities p(s_t+n, argmax_a[(z, p(s_t+n, a; θonline))]; θtarget)

  # Compute Tz (Bellman operator T applied to z)
  Tz = returns.unsqueeze(1) + nonterminals * (self.discount ** self.n) * self.support.unsqueeze(0)  # Tz = R^n + (γ^n)z (accounting for terminal states)
  Tz = Tz.clamp(min=self.Vmin, max=self.Vmax)  # Clamp between supported values
  # Compute L2 projection of Tz onto fixed support z
  b = (Tz - self.Vmin) / self.delta_z  # b = (Tz - Vmin) / Δz
  l, u = b.floor().long(), b.ceil().long()
  # Fix disappearing probability mass when l = b = u (b is int)
  l[(u > 0) * (l == u)] -= 1
  u[(l < (self.atoms - 1)) * (l == u)] += 1

  # Distribute probability of Tz
  m = states.new_zeros(self.batch_size, self.atoms)
  offset = torch.linspace(0, ((self.batch_size - 1) * self.atoms), self.batch_size).unsqueeze(1).expand(self.batch_size, self.atoms).to(actions)
  m.view(-1).index_add_(0, (l + offset).view(-1), (pns_a * (u.float() - b)).view(-1))  # m_l = m_l + p(s_t+n, a*)(u - b)
  m.view(-1).index_add_(0, (u + offset).view(-1), (pns_a * (b - l.float())).view(-1))  # m_u = m_u + p(s_t+n, a*)(b - l)

ps_a = ps_a.clamp(min=1e-3)  # Clamp for numerical stability in log
loss = -torch.sum(m * ps_a.log(), 1)  # Cross-entropy loss (minimises DKL(m||p(s_t, a_t)))
self.online_net.zero_grad()
(weights * loss).mean().backward()  # Backpropagate importance-weighted loss
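For reference, the projection inside the torch.no_grad() block can be read as a standalone routine: clamp the Bellman-updated atoms Tz to [Vmin, Vmax], locate each atom between its lower and upper neighbours on the fixed support, and split its probability mass between them. A minimal sketch under those assumptions (the function name project_distribution and its explicit arguments are illustrative, not the repository's API):

import torch

def project_distribution(pns_a, returns, nonterminals, support, Vmin, Vmax, delta_z, discount, n):
  # pns_a: (batch, atoms) target-net probabilities for the double-Q actions
  batch_size, atoms = pns_a.size()
  Tz = returns.unsqueeze(1) + nonterminals * (discount ** n) * support.unsqueeze(0)  # Tz = R^n + (γ^n)z
  Tz = Tz.clamp(min=Vmin, max=Vmax)  # Clamp to the supported range
  b = (Tz - Vmin) / delta_z  # Fractional index of each atom on the fixed support
  l, u = b.floor().long(), b.ceil().long()
  l[(u > 0) * (l == u)] -= 1  # Avoid losing mass when b lands exactly on an atom
  u[(l < (atoms - 1)) * (l == u)] += 1
  m = pns_a.new_zeros(batch_size, atoms)
  offset = (torch.arange(batch_size, device=pns_a.device) * atoms).unsqueeze(1)
  m.view(-1).index_add_(0, (l + offset).view(-1), (pns_a * (u.float() - b)).view(-1))  # m_l += p·(u - b)
  m.view(-1).index_add_(0, (u + offset).view(-1), (pns_a * (b - l.float())).view(-1))  # m_u += p·(b - l)
  return m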
In pattern: SUPERPATTERN
Frequency: 3
Non-data size: 5
Instances
Project Name: Kaixhin/Rainbow
Commit Name: d4255c9c4d04cf7f09881b272535cfdc155957a7
Time: 2018-04-28
Author: design@kaixhin.com
File Name: agent.py
Class Name: Agent
Method Name: learn
Project Name: Kaixhin/Rainbow
Commit Name: d4255c9c4d04cf7f09881b272535cfdc155957a7
Time: 2018-04-28
Author: design@kaixhin.com
File Name: agent.py
Class Name: Agent
Method Name: evaluate_q
Project Name: Kaixhin/Rainbow
Commit Name: d4255c9c4d04cf7f09881b272535cfdc155957a7
Time: 2018-04-28
Author: design@kaixhin.com
File Name: agent.py
Class Name: Agent
Method Name: act