for k, metric in enumerate(gan.metrics.keys()):
if metric== "gradient_penalty":
print("--", gan.session.run(gan.metrics[metric]))
if math.isnan(gan.session.run(gan.metrics[metric])):
return None
tf.reset_default_graph()
gan.session.close()
return sum_metrics
After Change
if i % 300 == 0:
for k, metric in enumerate(gan.metrics.keys()):
metric_value = gan.session.run(gan.metrics[metric])
print("--", metric, metric_value)
if math.isnan(metric_value) or math.isinf(metric_value):
print("Breaking due to invalid metric")
return None