# --- Before change (superseded; kept for reference) ---
# # call the normalisation function
# if is_training or use_local_stats:
#     with tf.control_dependencies(
#             [update_moving_mean, update_moving_variance]):
#         outputs = tf.nn.batch_normalization(
#             inputs, mean, variance,
#             beta, gamma, self.eps, name="batch_norm")
# --- After change ---
regularizer=self.regularizers["gamma"],
dtype=tf.float32, trainable=True)
collections = [tf.GraphKeys.MOVING_AVERAGE_VARIABLES,
tf.GraphKeys.GLOBAL_VARIABLES]
moving_mean = tf.get_variable(
"moving_mean",
shape=params_shape,
initializer=self.initializers["moving_mean"],
dtype=tf.float32, trainable=False, collections=collections)
moving_variance = tf.get_variable(
"moving_variance",
shape=params_shape,
initializer=self.initializers["moving_variance"],
dtype=tf.float32, trainable=False, collections=collections)
// mean and var
mean, variance = tf.nn.moments(inputs, axes)
update_moving_mean = moving_averages.assign_moving_average(
moving_mean, mean, self.moving_decay).op
update_moving_variance = moving_averages.assign_moving_average(
moving_variance, variance, self.moving_decay).op
tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, update_moving_mean)
tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, update_moving_variance)
// call the normalisation function
if is_training or use_local_stats:
// with tf.control_dependencies(
// [update_moving_mean, update_moving_variance]):
outputs = tf.nn.batch_normalization(
inputs, mean, variance,
beta, gamma, self.eps, name="batch_norm")
else:
outputs = tf.nn.batch_normalization(
inputs, moving_mean, moving_variance,
beta, gamma, self.eps, name="batch_norm")
outputs.set_shape(inputs.get_shape())
return outputs
# NOTE: Regularizers are not currently supported for fused batch norm.