reward = kwargs["reward"]
a_plus = torch.tensor(kwargs.get("a_plus", 1.0))
a_minus = torch.tensor(kwargs.get("a_minus", -1.0))
tc_plus = torch.tensor(kwargs.get("tc_plus", 20.0))
tc_minus = torch.tensor(kwargs.get("tc_minus", 20.0))
# Compute weight update based on the point eligibility value of the past timestep.
After Change
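The visible difference from the old code: tc_plus and tc_minus are no longer parsed out of kwargs on every call; the updated body reads them from self.tc_plus / self.tc_minus, presumably set once in the rule's constructor.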
# Initialize eligibility, P^+, and P^-.
if not hasattr(self, "p_plus"):
self.p_plus = torch.zeros(self.source.n)
if not hasattr(self, "p_minus"):
self.p_minus = torch.zeros(self.target.n)
if not hasattr(self, "eligibility"):
self.eligibility = torch.zeros(*self.connection.w.shape)
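# (p_plus holds one trace per pre-synaptic neuron, p_minus one per
# post-synaptic neuron, and eligibility matches the weight matrix's shape.)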
# Reshape pre- and post-synaptic spikes.
source_s = self.source.s.view(-1).float()
target_s = self.target.s.view(-1).float()
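# (view(-1) flattens the spike tensors so they line up with the 1-D
# p_plus / p_minus trace buffers regardless of layer shape.)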
# Parse keyword arguments.
reward = kwargs["reward"]
a_plus = torch.tensor(kwargs.get("a_plus", 1.0))
a_minus = torch.tensor(kwargs.get("a_minus", -1.0))
# Compute weight update based on the point eligibility value of the past timestep.
self.connection.w += self.nu[0] * reward * self.eligibility
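# (Three-factor update: the accumulated STDP eligibility only turns into a
# weight change when multiplied by the reward signal for this timestep.)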
# Update P^+ and P^- values.
self.p_plus *= torch.exp(-self.connection.dt / self.tc_plus)
self.p_plus += a_plus * source_s
self.p_minus *= torch.exp(-self.connection.dt / self.tc_minus)
self.p_minus += a_minus * target_s
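The excerpt cuts off before the eligibility trace itself is refreshed for the next call. A minimal sketch of that final step, assuming the outer-product form of Florian-style MSTDP (updated pre-synaptic trace paired with the current post-synaptic spikes, and current pre-synaptic spikes paired with the post-synaptic trace); this line is a reconstruction, not part of the diff:

# Recalculate the point eligibility used by the next timestep's update (sketch).
self.eligibility = torch.outer(self.p_plus, target_s) + torch.outer(source_s, self.p_minus)

For context, a caller would pass the reward through the simulation loop, e.g. network.run(inputs={"X": spikes}, time=250, reward=1.0) in BindsNET-style code (names here are illustrative), so this update runs once per simulated timestep.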