d4255c9c4d04cf7f09881b272535cfdc155957a7,agent.py,Agent,learn,#Agent#Any#,51
Before Change
u[(l < (self.atoms - 1)) * (l == u)] += 1
m = states.data.new(self.batch_size, self.atoms).zero_()
offset = torch.linspace(0, ((self.batch_size - 1) * self.atoms), self.batch_size).unsqueeze(1).expand(self.batch_size, self.atoms).type_as(actions)
m.view(-1).index_add_(0, (l + offset).view(-1), (pns_a * (u.float() - b)).view(-1))
m.view(-1).index_add_(0, (u + offset).view(-1), (pns_a * (b - l.float())).view(-1))
After Change
ps = self.online_net(states)
ps_a = ps[range(self.batch_size), actions]
with torch.no_grad():
pns = self.online_net(next_states)
dns = self.support.expand_as(pns) * pns
argmax_indices_ns = dns.sum(2).max(1)[1]
self.target_net.reset_noise()
pns = self.target_net(next_states)
pns_a = pns[range(self.batch_size), argmax_indices_ns]
Tz = returns.unsqueeze(1) + nonterminals * (self.discount ** self.n) * self.support.unsqueeze(0)
Tz = Tz.clamp(min=self.Vmin, max=self.Vmax)
b = (Tz - self.Vmin) / self.delta_z
l, u = b.floor().long(), b.ceil().long()
l[(u > 0) * (l == u)] -= 1
u[(l < (self.atoms - 1)) * (l == u)] += 1
m = states.new_zeros(self.batch_size, self.atoms)
offset = torch.linspace(0, ((self.batch_size - 1) * self.atoms), self.batch_size).unsqueeze(1).expand(self.batch_size, self.atoms).to(actions)
m.view(-1).index_add_(0, (l + offset).view(-1), (pns_a * (u.float() - b)).view(-1))
m.view(-1).index_add_(0, (u + offset).view(-1), (pns_a * (b - l.float())).view(-1))
ps_a = ps_a.clamp(min=1e-3)
loss = -torch.sum(m * ps_a.log(), 1)
self.online_net.zero_grad()
(weights * loss).mean().backward()

In pattern: SUPERPATTERN
Frequency: 4
Non-data size: 3
Instances
Project Name: Kaixhin/Rainbow
Commit Name: d4255c9c4d04cf7f09881b272535cfdc155957a7
Time: 2018-04-28
Author: design@kaixhin.com
File Name: agent.py
Class Name: Agent
Method Name: learn
Project Name: pytorch/audio
Commit Name: 3bd4db86630b75bbbfb6c5c0a1a85603097bf9b2
Time: 2019-01-04
Author: david@da3.net
File Name: torchaudio/transforms.py
Class Name: SPECTROGRAM
Method Name: __call__
Project Name: zhirongw/lemniscate.pytorch
Commit Name: 4441480fde64e42a9c4af205bf2ab8003511172e
Time: 2018-07-26
Author: xavibrowu@gmail.com
File Name: test.py
Class Name:
Method Name: kNN
Project Name: pytorch/pytorch
Commit Name: dfb7520c47290eb93b63cffad54ff9c9811a934b
Time: 2020-12-22
Author: zou3519@gmail.com
File Name: torch/testing/_internal/common_nn.py
Class Name: NewModuleTest
Method Name: _do_test