N, mu, var = _get_moments(self.ops, X)
self.nr_upd += 1
alpha = self.ops.xp.asarray([0.01], dtype="float32")
// I'm not sure this is the best thing to do --
// Here we make a running estimate of the mean and variance.
// Should we consider a sample to be the instance, or the batch?
diff = X - self.m
incr = (1-alpha) * diff
self.m += incr.mean(axis=0)
self.v += (diff * incr).mean(axis=0)
self.v *= alpha
Xhat = _forward(self.ops, X, mu, var)
// Batch "renormalization"
if self.nr_upd >= 7500:
Xhat *= var / (self.v+1e-08)
Xhat += (mu - self.m) / (self.v+1e-08)
y, backprop_rescale = self._begin_update_scale_shift(Xhat)
def finish_update(dy, sgd=None):
After Change
return y
def begin_update(self, X, drop=0.):
if drop is None:
return self.predict(X), None
assert X.dtype == "float32"
X, backprop_child = self.child.begin_update(X, drop=0.)
N, mu, var = _get_moments(self.ops, X)
var += self.eps