merge_parameters(params, model_cls.get_parameters())
for params, model_cls in zip(params_list, model_cls_list)
]
# Run every checkpoint, together with its model name and current
# parameter set, through import_params.
params_list = [
    import_params(ckpt_path, args.models[idx], params_list[idx])
    for idx, ckpt_path in enumerate(args.checkpoints)
]

# Apply override_parameters with the command-line arguments to one
# parameter set per model class.
model_count = len(model_cls_list)
params_list = [
    override_parameters(params_list[idx], args) for idx in range(model_count)
]
# Build the graph
with tf.Graph().as_default():
    # Load checkpoints: for every checkpoint keep the tensors whose names
    # start with the name of the corresponding model class.
    model_var_lists = []

    for model_idx, ckpt_path in enumerate(args.checkpoints):
        print("Loading %s" % ckpt_path)
        ckpt_vars = tf.train.list_variables(ckpt_path)
        restored = {}
        ckpt_reader = tf.train.load_checkpoint(ckpt_path)

        for var_name, _ in ckpt_vars:
            owned = var_name.startswith(model_cls_list[model_idx].get_name())
            # Names containing "losses_avg" are never loaded.
            if owned and "losses_avg" not in var_name:
                restored[var_name] = ckpt_reader.get_tensor(var_name)

        model_var_lists.append(restored)

    # Build models: one instance per checkpoint, each constructed with
    # "<model name>_<index>" as its second argument.
    model_fns = []

    for model_idx in range(len(args.checkpoints)):
        model_cls = model_cls_list[model_idx]
        scoped_name = model_cls.get_name() + "_%d" % model_idx
        instance = model_cls(params_list[model_idx], scoped_name)
        model_fns.append(instance.get_inference_func())

    # The first parameter set drives input processing, search and the
    # session configuration.
    params = params_list[0]

    # Read input file
    sorted_keys, sorted_inputs = dataset.sort_input_file(args.input)

    # Build input queue
    features = dataset.get_inference_input(sorted_inputs, params)
    predictions = search.create_inference_graph(
        model_fns, features, params
    )

    # Create the ops that assign the loaded checkpoint values to the
    # trainable variables of each model.
    assign_ops = []
    all_var_list = tf.trainable_variables()

    for model_idx in range(len(args.checkpoints)):
        prefix = model_cls_list[model_idx].get_name() + "_%d" % model_idx
        # NOTE(review): a plain prefix test also selects variables of
        # "<name>_1x" when the prefix is "<name>_1"; whether that matters
        # depends on how set_variables matches names -- confirm.
        un_init_var_list = [
            var for var in all_var_list if var.name.startswith(prefix)
        ]
        assign_ops.extend(
            set_variables(un_init_var_list, model_var_lists[model_idx], prefix)
        )

    assign_op = tf.group(*assign_ops)

    sess_creator = tf.train.ChiefSessionCreator(config=session_config(params))

    results = []

    # Create session
    with tf.train.MonitoredSession(session_creator=sess_creator) as sess:
        # Restore variables
        sess.run(assign_op)

        # Collect predictions until the session reports it should stop.
        while not sess.should_stop():
            batch_predictions = sess.run(predictions)
            results.append(batch_predictions)
            tf.logging.log(
                tf.logging.INFO, "Finished batch %d" % len(results)
            )
# Convert to plain text
vocab = params.vocabulary["target"]

# Flatten the per-batch results into one list with an entry per sequence.
outputs = list(
    itertools.chain.from_iterable(batch.tolist() for batch in results)
)

# Entry i is taken from position sorted_keys[i] of the flattened outputs
# (presumably undoing the reordering done by sort_input_file -- confirm).
restored_outputs = [
    outputs[sorted_keys[position]] for position in range(len(sorted_inputs))
]
# Write to file: one line per output, ids mapped through the vocabulary
# and joined with single spaces.
with open(args.output, "w") as outfile:
    for token_ids in restored_outputs:
        tokens = []
        for token_id in token_ids:
            # Stop at the first end-of-sequence id; it is not written out.
            if token_id == params.mapping["target"][params.eos]:
                break
            tokens.append(vocab[token_id])

        outfile.write("%s\n" % " ".join(tokens))