051bd416d3c41002f6d58b9dd71516a27243d178,thumt/launcher/translator.py,,main,#Any#,137

Before Change


    model_cls = models.get_model(args.model)
    params = default_parameters()
    params = merge_parameters(params, model_cls.get_parameters())
    params = override_parameters(params, args)

    # Build Graph
    with tf.Graph().as_default():
        # Read input file
        sorted_keys, sorted_inputs = dataset.sort_input_file(params.input)
        # Build input queue
        features = dataset.get_inference_input(sorted_inputs, params)

        # Build model
        model = model_cls(params)
        inference_fn = model.get_inference_func()
        predictions = search.create_inference_graph(inference_fn, features,
                                                    params)

        if not tf.gfile.Exists(params.path):
            raise ValueError("Path %s not found" % params.path)

        sess_creator = tf.train.ChiefSessionCreator(
            checkpoint_dir=params.path,
            config=session_config(params)
        )

        results = []

        with tf.train.MonitoredSession(session_creator=sess_creator) as sess:
            while not sess.should_stop():
                results.append(sess.run(predictions))
                message = "Finished batch %d" % len(results)
                tf.logging.log(tf.logging.INFO, message)

        # Convert to plain text
        vocab = params.vocabulary["target"]
        outputs = []

        for result in results:
            outputs.append(result.tolist())

        outputs = list(itertools.chain(*outputs))

        restored_outputs = []

        for index in range(len(sorted_inputs)):
            restored_outputs.append(outputs[sorted_keys[index]])

        # Write to file
        with open(params.output, "w") as outfile:
            for output in restored_outputs:
                decoded = [vocab[idx] for idx in output]
                decoded = " ".join(decoded)
                idx = decoded.find(params.eos)

                if idx >= 0:
                    output = decoded[:idx].strip()
                else:
                    output = decoded.strip()

                outfile.write("%s\n" % output)


if __name__ == "__main__":

After Change


    model_cls = models.get_model(args.model)
    params = default_parameters()
    params = merge_parameters(params, model_cls.get_parameters())
    params = import_params(args.path, args.model, params)
    override_parameters(params, args)

    # Build Graph
    with tf.Graph().as_default():
        # Read input file
        sorted_keys, sorted_inputs = dataset.sort_input_file(params.input)
        # Build input queue
        features = dataset.get_inference_input(sorted_inputs, params)

        # Build model
        model = model_cls(params)
        inference_fn = model.get_inference_func()
        predictions = search.create_inference_graph(inference_fn, features,
                                                    params)

        if not tf.gfile.Exists(params.path):
            raise ValueError("Path %s not found" % params.path)

        sess_creator = tf.train.ChiefSessionCreator(
            checkpoint_dir=params.path,
            config=session_config(params)
        )

        results = []

        with tf.train.MonitoredSession(session_creator=sess_creator) as sess:
            while not sess.should_stop():
                results.append(sess.run(predictions))
                message = "Finished batch %d" % len(results)
                tf.logging.log(tf.logging.INFO, message)

        # Convert to plain text
        vocab = params.vocabulary["target"]
        outputs = []

        for result in results:
            outputs.append(result.tolist())

        outputs = list(itertools.chain(*outputs))

        restored_outputs = []

        for index in range(len(sorted_inputs)):
            restored_outputs.append(outputs[sorted_keys[index]])

        # Write to file
        with open(params.output, "w") as outfile:
            for output in restored_outputs:
                decoded = []
                for idx in output:
                    if idx == params.mapping["target"][params.eos]:
                        break
                    decoded.append(vocab[idx])

                decoded = " ".join(decoded)
                outfile.write("%s\n" % decoded)
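
The change does two things. First, model hyperparameters are now restored from the checkpoint directory via import_params before command-line overrides are applied. Second, end-of-sentence handling moves from a textual search over the fully decoded string to an id comparison against params.mapping["target"][params.eos], so decoding stops at the first EOS token instead of joining the whole output and searching for the EOS text (which could mis-truncate if that text appeared inside another token). Below is a minimal standalone sketch of the two truncation strategies; the vocabulary, EOS id, and sample output are illustrative assumptions, not values taken from THUMT.

    # Minimal sketch contrasting the two EOS-truncation strategies.
    # "vocab", "eos_id" and "output" are hypothetical stand-ins, not THUMT data.
    vocab = ["<eos>", "hello", "world", "!"]
    eos_id = 0
    output = [1, 2, 3, 0, 1]  # decoder ids; everything after <eos> is padding

    # Before: decode every id, join, then search for the EOS *string*.
    decoded = " ".join(vocab[i] for i in output)
    pos = decoded.find("<eos>")
    before = decoded[:pos].strip() if pos >= 0 else decoded.strip()

    # After: stop at the EOS *id*, never decoding anything past it.
    tokens = []
    for i in output:
        if i == eos_id:
            break
        tokens.append(vocab[i])
    after = " ".join(tokens)

    print(before)  # hello world !
    print(after)   # hello world !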
In pattern: SUPERPATTERN

Frequency: 3

Non-data size: 8

Instances


Project Name: THUNLP-MT/THUMT
Commit Name: 051bd416d3c41002f6d58b9dd71516a27243d178
Time: 2017-11-11
Author: playinf@stu.xmu.edu.cn
File Name: thumt/launcher/translator.py
Class Name:
Method Name: main


Project Name: biocore/scikit-bio
Commit Name: a32b6ba661206b12841e0a7cf8abb16ab0782f0a
Time: 2014-08-25
Author: jai.rideout@gmail.com
File Name: skbio/io/dm.py
Class Name:
Method Name: _dm_to_matrix


Project Name: cesium-ml/cesium
Commit Name: 1cf1a4c00ba6404bb7387c722187a22357b2f193
Time: 2015-02-13
Author: a.crellinquick@gmail.com
File Name: mltsp/custom_feature_tools.py
Class Name:
Method Name: parse_tsdata_from_file