if config.g_ema_decay is not None:
    # Build weight-update ops for g_vars only when a decay is configured.
    decay2 = config.g_ema_decay
    # NOTE(review): tf.zeros_like returns a constant all-zeros tensor, not a
    # variable holding earlier weight values. With pv == 0 the update below
    # reduces to v := (1 - decay2) * v, so every run of self.g_ema shrinks
    # the weights toward zero. That does not look like the moving average
    # the name g_ema suggests. Tracking past values presumably needs
    # non-trainable tf.Variable copies refreshed by a separate assign op.
    # Confirm the intended design before relying on this op.
    pg_vars = [tf.zeros_like(v) for v in g_vars]
    self.pg_vars = pg_vars
    self.g_vars = g_vars
    # Per-variable update: v := decay2 * pv + (1 - decay2) * v.
    # tf.assign requires each v to be a variable (assumed for g_vars).
    g_emas = [
        tf.assign(v, decay2 * pv + (1.0 - decay2) * v)
        for v, pv in zip(g_vars, pg_vars)
    ]
    # Bundle all per-variable assignments into a single runnable op.
    self.g_ema = tf.group(g_emas)
After Change
print("DECAY", decay)
return tf.assign(v, v*(1-decay)+pastv*decay)
# Grouped op that applies _ema to each (current, past) pair drawn from
# allvars and self.past_weights; zip stops at the shorter of the two lists.
# NOTE(review): _ema is defined outside this block; presumably it returns an
# assign op blending each current variable with its stored past value —
# confirm against its definition.
self.assign_ema = tf.group([_ema(a,b) for a,b in zip(allvars, self.past_weights)])
# Grouped op that copies each current value in allvars into the matching
# self.past_weights entry (tf.assign(ref=b, value=a)), i.e. snapshots the
# current weights so a later assign_ema can blend toward them.
self.assign_past_weights = tf.group([tf.assign(b,a) for a,b in zip(allvars, self.past_weights)])
# Store g_loss and d_loss on the instance for later use.
self.g_loss = g_loss
self.d_loss = d_loss