reward = kwargs["reward"]
a_plus = torch.tensor(kwargs.get("a_plus", 1.0))
a_minus = torch.tensor(kwargs.get("a_minus", -1.0))
tc_plus = torch.tensor(kwargs.get("tc_plus", 20.0))
tc_minus = torch.tensor(kwargs.get("tc_minus", 20.0))
# Compute weight update based on the point eligibility value from the previous timestep.
self.connection.w += self.nu[0] * reward * self.eligibility
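For reference, a minimal standalone sketch of this reward-modulated update. The names, shapes, and constants below are illustrative assumptions, not the library's own code:

import torch

n_pre, n_post = 4, 3                      # hypothetical layer sizes
w = torch.zeros(n_pre, n_post)            # synaptic weights
eligibility = torch.rand(n_pre, n_post)   # point eligibility stored at the previous step
nu = 1e-2                                 # learning rate (nu[0] in the snippet above)
reward = 1.0                              # scalar reward signal

# The update scales the stored eligibility by the reward and the learning rate.
w += nu * reward * eligibility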
After Change
if not hasattr(self, "p_plus"):
self.p_plus = torch.zeros(self.source.n)
if not hasattr(self, "p_minus"):
self.p_minus = torch.zeros(self.target.n)
if not hasattr(self, "eligibility"):
self.eligibility = torch.zeros(*self.connection.w.shape)
# Reshape pre- and post-synaptic spikes.
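To illustrate how the lazily initialized traces above (p_plus, p_minus, eligibility) are typically combined, here is a hedged standalone sketch of one MSTDP-style update step in the spirit of Florian (2007). The spike tensors, constants, and exact combination are assumptions for illustration and may differ from the library's actual formulation:

import math
import torch

n_pre, n_post, dt = 4, 3, 1.0
a_plus, a_minus = 1.0, -1.0
tc_plus, tc_minus = 20.0, 20.0
nu, reward = 1e-2, 1.0

w = torch.zeros(n_pre, n_post)            # synaptic weights
p_plus = torch.zeros(n_pre)               # pre-synaptic trace
p_minus = torch.zeros(n_post)             # post-synaptic trace

pre_s = torch.bernoulli(torch.full((n_pre,), 0.1))    # hypothetical pre-synaptic spikes
post_s = torch.bernoulli(torch.full((n_post,), 0.1))  # hypothetical post-synaptic spikes

# Decay each trace and add contributions from the new spikes.
p_plus = p_plus * math.exp(-dt / tc_plus) + a_plus * pre_s
p_minus = p_minus * math.exp(-dt / tc_minus) + a_minus * post_s

# Eligibility pairs the pre-trace with post-spikes and the post-trace with pre-spikes.
eligibility = torch.outer(p_plus, post_s) + torch.outer(pre_s, p_minus)

# Reward-modulated weight update, as in the snippet above.
w += nu * reward * eligibility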