# NOTE(review): this chunk begins inside an if/else whose condition is not
# visible in this fragment (presumably it selects raw-text vs TFRecord input
# based on params.record — confirm against the full file). The "//" markers
# were not valid Python comments; rewritten as "#".
# Build input queue from raw parallel text files (source, target).
src_inputs = get_dataset(params.input[0])
tgt_inputs = get_dataset(params.input[1])
inputs = [src_inputs, tgt_inputs]
features = dataset.get_training_input(inputs, params)
else:
# Otherwise read pre-packed record shards matching "*train*" under params.record.
features = record.get_input_features(
os.path.join(params.record, "*train*"), "train", params
)
# Build model: initializer/regularizer are configured from hyper-parameters
# and handed to the model's training function.
initializer = get_initializer(params)
# Combined L1/L2 weight-decay regularizer (tf.contrib, TF1-era API).
regularizer = tf.contrib.layers.l1_l2_regularizer(
scale_l1=params.scale_l1, scale_l2=params.scale_l2)
model = model_cls(params)
# Create global step — consumed by the LR-decay schedule and by
# opt.minimize(...) further down.
global_step = tf.train.get_or_create_global_step()
training_func = model.get_training_func(initializer, regularizer)
loss = training_func(features)
# Add the regularization penalties collected in the graph to the task loss.
loss = loss + tf.losses.get_regularization_loss()
# Print parameters: log every trainable variable's name and shape, then the
# total number of trainable scalars.
all_weights = {v.name: v for v in tf.trainable_variables()}
total_size = 0

# sorted() accepts the dict directly (iterates its keys); the old
# sorted(list(...)) built a throwaway list.
for v_name in sorted(all_weights):
    v = all_weights[v_name]
    # v.name[:-2] presumably strips the ":0" output-index suffix for
    # display — TODO confirm all variable names carry that suffix.
    tf.logging.info("%s\tshape %s", v.name[:-2].ljust(80),
                    str(v.shape).ljust(20))
    # int(np.prod(...)) keeps total_size an int; the previous
    # np.prod(np.array(...)).tolist() returned a float (1.0) for scalar
    # variables with shape [], silently promoting total_size to float.
    total_size += int(np.prod(v.shape.as_list()))

tf.logging.info("Total trainable variables size: %d", total_size)
# Learning-rate schedule: compute the decayed rate from the global step,
# coerce it to a float32 tensor, and expose it as a TensorBoard summary.
learning_rate = get_learning_rate_decay(params.learning_rate,
global_step, params)
learning_rate = tf.convert_to_tensor(learning_rate, dtype=tf.float32)
tf.summary.scalar("learning_rate", learning_rate)
# Create optimizer. Every supported choice is wrapped in Horovod's
# DistributedOptimizer for data-parallel gradient averaging; unsupported
# names fail fast with RuntimeError.
if params.optimizer == "Adam":
    opt = tf.train.AdamOptimizer(learning_rate,
                                 beta1=params.adam_beta1,
                                 beta2=params.adam_beta2,
                                 epsilon=params.adam_epsilon)
elif params.optimizer == "LazyAdam":
    # tf.contrib variant of Adam (TF1-era API).
    opt = tf.contrib.opt.LazyAdamOptimizer(learning_rate,
                                           beta1=params.adam_beta1,
                                           beta2=params.adam_beta2,
                                           epsilon=params.adam_epsilon)
else:
    raise RuntimeError("Optimizer %s not supported" % params.optimizer)

# Wrap exactly once, hoisted out of the branches — the duplicate
# hvd.DistributedOptimizer call in each branch was redundant, and the
# unsupported case has already raised by this point.
opt = hvd.DistributedOptimizer(opt)

train_op = opt.minimize(loss, global_step=global_step)
restore_op = restore_variables(args.checkpoint)
// Validation
if params.validation and params.references[0]:
# NOTE(review): stray "After Change" marker left behind by a diff/dataset
# tool — the lines that follow belong to a different, truncated revision of
# this file, not to the control flow above.
override_parameters(params, args)
// Export all parameters and model specific parameters
if hvd.rank() == 0:
export_params(params.output, "params.json", params)
export_params(
params.output,