# Forward pass of batch (re)normalization: normalise X with the batch moments,
# maintain running estimates of the mean/variance, and after a warm-up period
# correct the output towards the running statistics.
N, mu, var = _get_moments(self.ops, X)
self.nr_upd += 1
# Weight given to the current batch in the running estimates; the running
# mean/variance keep (1 - alpha) of their previous value.
alpha = self.ops.xp.asarray([0.01], dtype="float32")
# I'm not sure this is the best thing to do --
# Here we make a running estimate of the mean and variance.
# Should we consider a sample be the instance, or the batch?
# Exponentially-weighted incremental mean/variance with the per-instance
# increments averaged over the batch. alpha must be the weight of the new
# data: weighting the increment by (1 - alpha) and decaying by alpha would make
# the estimate ~99% the current batch and shrink the variance to ~1% of it.
diff = X - self.m
incr = alpha * diff
self.m += incr.mean(axis=0)
self.v += (diff * incr).mean(axis=0)
self.v *= 1 - alpha
Xhat = _forward(self.ops, X, mu, var)
# Batch "renormalization": after a fixed warm-up of 7500 updates, apply
#     Xhat * r + d,  r = sigma_B / sigma,  d = (mu_B - mu) / sigma.
# NOTE(review): assumes _get_moments returns a per-feature *variance* (not a
# standard deviation), hence the square roots -- confirm against _forward.
if self.nr_upd >= 7500:
    running_std = self.ops.xp.sqrt(self.v + 1e-08)
    Xhat *= self.ops.xp.sqrt(var) / running_std
    Xhat += (mu - self.m) / running_std
y, backprop_rescale = self._begin_update_scale_shift(Xhat)
def finish_update(dy, sgd=None):
After Change
# Forward pass of batch renormalization: normalise X with the batch moments,
# then correct towards the running moments with
#     Xhat * r + d,  r = clip(sigma_B / sigma, 1/rmax, rmax),
#                    d = clip((mu_B - mu) / sigma, -dmax, dmax).
N, mu, var = _get_moments(self.ops, X)
var += self.eps
# r and d come from the running moments *before* this batch updates them.
# They are per-feature arrays, so clip element-wise: the builtin min()/max()
# need a single truth value from an array comparison and raise as soon as
# there is more than one feature.
# NOTE(review): var is treated as a per-feature variance (it is offset by
# self.eps), so the ratios use standard deviations -- confirm against _forward.
running_std = self.ops.xp.sqrt(self.v)
self.r = self.ops.xp.clip(
    self.ops.xp.sqrt(var) / running_std, 1. / self.rmax, self.rmax)
self.d = self.ops.xp.clip(
    (mu - self.m) / running_std, -self.dmax, self.dmax)
self.nr_upd += 1
# I'm not sure this is the best thing to do --
# should a sample be the instance, or the batch? Here each batch counts as one
# sample: the running moments are exponential moving averages of the batch
# moments. Tracking the variance of the individual inputs would instead need
# an incremental per-instance update (diff = X - m; incr = alpha * diff; ...).
self.m += self.alpha * (mu - self.m)
self.v += self.alpha * (var - self.v)
Xhat = _forward(self.ops, X, mu, var)
Xhat *= self.r
Xhat += self.d
y, backprop_rescale = self._begin_update_scale_shift(Xhat)
def finish_update(dy, sgd=None):