output_shape = self.compute_output_shape(input_shape)
self.v_thresh = k.variable(self._v_thresh)
self.mem = k.variable(self.init_membrane_potential(output_shape))
self.time = k.variable(self.dt)
// To save memory and computations, allocate only where needed:
if self.tau_refrac > 0:
self.refrac_until = k.zeros(output_shape)
if any({"spiketrains", "spikerates", "correlation", "spikecounts",
After Change
output_shape = self.compute_output_shape(input_shape)
if self.v_thresh is None:
self.v_thresh = tf.Variable(self._v_thresh, name="v_thresh",
trainable=False)
if self.mem is None:
self.mem = tf.Variable(self.init_membrane_potential(output_shape),
name="v_mem", trainable=False)
if self.time is None:
self.time = tf.Variable(self.dt, name="dt", trainable=False)
// To save memory and computations, allocate only where needed:
if self.tau_refrac > 0 and self.refrac_until is None:
self.refrac_until = tf.Variable(
tf.zeros(output_shape), name="refrac_until", trainable=False)