051bd416d3c41002f6d58b9dd71516a27243d178,thumt/launcher/ensemble_translator.py,,main,#Any#,148

Before Change


    # Write to file
    with open(args.output, "w") as outfile:
        for output in restored_outputs:
            decoded = [vocab[idx] for idx in output]
            decoded = " ".join(decoded)
            idx = decoded.find(params_list[0].eos)

            if idx >= 0:
                output = decoded[:idx].strip()
            else:
                output = decoded.strip()

            outfile.write("%s\n" % output)
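The old path decodes every token id to a string first, joins them, and then truncates at the first occurrence of the EOS string with str.find. A minimal standalone sketch of that behavior, using hypothetical vocabulary and EOS values rather than THUMT's actual ones:

    # Hypothetical values for illustration only.
    vocab = ["hello", "world", "<eos>", "<pad>"]
    eos = "<eos>"
    output = [0, 1, 2, 3]  # token ids produced by the decoder

    decoded = " ".join(vocab[idx] for idx in output)
    pos = decoded.find(eos)  # string-level search for the EOS marker
    print(decoded[:pos].strip() if pos >= 0 else decoded.strip())
    # -> "hello world"

This works, but it always decodes the full sequence first and relies on a string search, which would also match EOS appearing as a substring of a longer token.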

After Change


    params_list = [
        merge_parameters(params, model_cls.get_parameters())
        for params, model_cls in zip(params_list, model_cls_list)
    ]
    params_list = [
        import_params(args.checkpoints[i], args.models[i], params_list[i])
        for i in range(len(args.checkpoints))
    ]
    params_list = [
        override_parameters(params_list[i], args)
        for i in range(len(model_cls_list))
    ]

    # Build Graph
    with tf.Graph().as_default():
        model_var_lists = []

        # Load checkpoints
        for i, checkpoint in enumerate(args.checkpoints):
            print("Loading %s" % checkpoint)
            var_list = tf.train.list_variables(checkpoint)
            values = {}
            reader = tf.train.load_checkpoint(checkpoint)

            for (name, shape) in var_list:
                # Keep only variables that belong to this model's scope and
                # skip the "losses_avg" statistics, which are not needed for
                # inference.
                if not name.startswith(model_cls_list[i].get_name()):
                    continue

                if name.find("losses_avg") >= 0:
                    continue

                tensor = reader.get_tensor(name)
                values[name] = tensor

            model_var_lists.append(values)

        # Build models
        model_fns = []

        for i in range(len(args.checkpoints)):
            name = model_cls_list[i].get_name()
            model = model_cls_list[i](params_list[i], name + "_%d" % i)
            model_fn = model.get_inference_func()
            model_fns.append(model_fn)

        params = params_list[0]
        # Read input file
        sorted_keys, sorted_inputs = dataset.sort_input_file(args.input)
        # Build input queue
        features = dataset.get_inference_input(sorted_inputs, params)
        predictions = search.create_inference_graph(model_fns, features,
                                                    params)

        # Map each checkpoint's saved values onto the uniquely scoped copy
        # of its model ("<name>_<i>") built above.
        assign_ops = []

        all_var_list = tf.trainable_variables()

        for i in range(len(args.checkpoints)):
            un_init_var_list = []
            name = model_cls_list[i].get_name()

            for v in all_var_list:
                if v.name.startswith(name + "_%d" % i):
                    un_init_var_list.append(v)

            ops = set_variables(un_init_var_list, model_var_lists[i],
                                name + "_%d" % i)
            assign_ops.extend(ops)

        assign_op = tf.group(*assign_ops)

        sess_creator = tf.train.ChiefSessionCreator(
            config=session_config(params)
        )

        results = []

        # Create session
        with tf.train.MonitoredSession(session_creator=sess_creator) as sess:
            # Restore variables
            sess.run(assign_op)

            while not sess.should_stop():
                results.append(sess.run(predictions))
                message = "Finished batch %d" % len(results)
                tf.logging.log(tf.logging.INFO, message)

        # Convert to plain text
        vocab = params.vocabulary["target"]
        outputs = []

        for result in results:
            outputs.append(result.tolist())

        outputs = list(itertools.chain(*outputs))

        restored_outputs = []

        for index in range(len(sorted_inputs)):
            restored_outputs.append(outputs[sorted_keys[index]])

        # Write to file
        with open(args.output, "w") as outfile:
            for output in restored_outputs:
                decoded = []
                for idx in output:
                    if idx == params.mapping["target"][params.eos]:
                        break
                    decoded.append(vocab[idx])

                decoded = " ".join(decoded)
                outfile.write("%s\n" % decoded)
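The new path compares token ids against the EOS id (looked up once via params.mapping["target"][params.eos]) and stops before decoding anything past it, so truncation happens at the token level. A minimal sketch of the same idea, again with hypothetical values:

    # Hypothetical values for illustration only.
    vocab = ["hello", "world", "<eos>", "<pad>"]
    eos_id = 2  # stands in for params.mapping["target"][params.eos]
    output = [0, 1, 2, 3]

    decoded = []
    for idx in output:
        if idx == eos_id:  # stop at the first EOS token id
            break
        decoded.append(vocab[idx])
    print(" ".join(decoded))
    # -> "hello world"

Matching on the id avoids the join-then-search pass and cannot be fooled by EOS showing up as a substring of a longer token.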
In pattern: SUPERPATTERN

Frequency: 3

Non-data size: 13

Instances


Project Name: THUNLP-MT/THUMT
Commit Name: 051bd416d3c41002f6d58b9dd71516a27243d178
Time: 2017-11-11
Author: playinf@stu.xmu.edu.cn
File Name: thumt/launcher/ensemble_translator.py
Class Name:
Method Name: main


Project Name: shibing624/pycorrector
Commit Name: 4e144c9f842d7415d8be5bdbb5912d88ae32cced
Time: 2018-04-16
Author: 507153809@qq.com
File Name: pycorrector/seq2seq/corpus_reader.py
Class Name: CGEDReader
Method Name: read_tokens


Project Name: shibing624/pycorrector
Commit Name: 4e144c9f842d7415d8be5bdbb5912d88ae32cced
Time: 2018-04-16
Author: 507153809@qq.com
File Name: pycorrector/seq2seq/corpus_reader.py
Class Name: CGEDReader
Method Name: read_samples_by_string