# l2 = ||(pg(z0) - g(z0))||2
prev_l2_loss = sess.run(self.prev_l2_loss, prev_feed_dict)
# pg(z0) = g(z)
self.prev_g = sess.run(self.update_prev_sample, feed_dict)
# z0 = z
self.prev_zs = zs
# optimize(l2, gl, dl)
feed_dict[self.prev_l2_loss] = prev_l2_loss
_, *metric_values = sess.run([self.optimizer] + self.output_variables(metrics), feed_dict)
if ((self.current_step % (self.config.constraint_every or 100)) == 0):
if self.config.weight_constraint:
sess.run(self.update_weight_constraints, feed_dict)
sess.run(self.assign_ema)
sess.run(self.assign_past_weights)
else:
self.hist[1]+=1
fitness_decay = config.fitness_decay or 0.99
self.min_fitness = self.min_fitness + (1.00-fitness_decay)*(fitness-self.min_fitness)
if(config.train_d_on_fitness_failure):
metric_values = sess.run([self.d_optimizer]+self.output_variables(metrics), feed_dict)[1:]
else:
metric_values = sess.run(self.output_variables(metrics), feed_dict)
self.current_step-=1
else:
if ((self.current_step % (self.config.constraint_every or 100)) == 0):
if self.config.weight_constraint:
print("Updating constraints")
sess.run(self.update_weight_constraints, feed_dict)
# standard
gl, dl, *metric_values = sess.run([self.g_loss, self.d_loss, self.optimizer] + self.output_variables(metrics), feed_dict)[1:]
sess.run(self.assign_ema)
sess.run(self.assign_past_weights)
if(gl == 0 or dl == 0):
self.steps_since_fit=0
self.mix_threshold_reached = True
print("Zero, lne?")
return
self.steps_since_fit=0
if config.g_ema_decay is not None:
feed_dict = {}
for p,pvalue in zip(self.pg_vars, prev):
feed_dict[p]=pvalue
_ = sess.run(self.g_ema, feed_dict)
if ((self.current_step % 10) == 0 and self.steps_since_fit == 0):
hist_output = " " + "".join(["G"+str(i)+":"+str(v)+" "for i, v in enumerate(self.hist)])
After Change
def required(self):
    """Return the names this component declares as required.

    Presumably these are configuration keys checked by the caller --
    TODO confirm against the base class that consumes this list.

    Returns:
        A new list ``["trainer", "learn_rate"]`` on every call, so callers
        may mutate the result without affecting later calls.
    """
    return ["trainer", "learn_rate"]
def _step(self, feed_dict):
gan = self.gan
sess = gan.session
config = self.config
loss = self.loss or gan.loss
metrics = loss.metrics
if self.current_step == 0 and self.steps_since_fit == 0:
sess.run(self.assign_past_weights)
sess.run(self.update_prev_sample)
if config.fitness_test is not None:
self.steps_since_fit+=1
if config.fitness_failure_threshold and self.steps_since_fit > (config.fitness_failure_threshold or 1000):
print("Fitness achieved.", self.hist[0], self.min_fitness)