# Fold the current batch `x` into the running statistics kept on the layer:
# `self.count` is advanced by the batch size and `self.mean` is blended with
# the batch mean, with the blending denominator capped at `self.cap`.
pre_mean = self.mean

# compute this batch's stats (the sum is taken over axis 0, and the size of
# axis 0 is used as the batch size)
this_sum = tf.reduce_sum(x, 0)
this_bs = tf.cast(K.shape(x)[0], "float32")  # this batch size

# increase count and compute weights
new_count = self.count + this_bs
alpha = this_bs / K.minimum(new_count, self.cap)

# compute new mean. Note that once we reach self.cap (e.g. 1000), the
# denominator stops growing, so alpha stays at this_bs / self.cap and the
# "previous mean" matters less.
# NOTE(review): alpha exceeds 1 when a single batch is larger than self.cap --
# confirm callers never feed such a batch.
new_mean = pre_mean * (1 - alpha) + (this_sum / this_bs) * alpha

# hand the (variable, new value) pairs to the layer, passing x alongside them
updates = [(self.count, new_count), (self.mean, new_mean)]
self.add_update(updates, x)
After Change
self.add_update(updates, x)
// prep for broadcasting :(
p = tf.concat((K.reshape(this_bs_int, (1,)), K.shape(self.mean)), 0)
z = K.ones(p)
// the first few 1000 should not matter that much towards this cost
return K.minimum(1., new_count/self.cap) * (z * K.expand_dims(new_mean, 0))