4af72126c388385371b1235d1336c1ef98723326,thumt/utils/parallel.py,,shard_features,#Any#Any#,61

Before Change


    num_datashards = len(device_list)
    sharded_features = {}

    for k, v in six.iteritems(features):
        v = tf.convert_to_tensor(v)
        if not v.shape.as_list():
            v = tf.expand_dims(v, axis=-1)
            v = tf.tile(v, [num_datashards])
        with tf.device(v.device):
            sharded_features[k] = tf.split(v, num_datashards, 0)

    datashard_to_features = []

    for d in range(num_datashards):
        feat = {

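This is a minimal NumPy sketch of the "before" behaviour. It assumes NumPy as a stand-in for TensorFlow: `np.split` with an integer section count, like `tf.split` with an integer `num_or_size_splits`, requires the batch axis to divide evenly. The name `shard_equal` is hypothetical. The sketch returns the per-feature split lists rather than per-shard dictionaries, because the excerpt above is cut off at `feat = {`.

```python
import numpy as np


def shard_equal(features, num_shards):
    """Hypothetical NumPy analogue of the 'before' sharding logic.

    Scalars are tiled so every shard gets a copy; every feature is then
    split into num_shards equal pieces along axis 0.
    """
    sharded = {}
    for k, v in features.items():
        v = np.asarray(v)
        if v.ndim == 0:
            # Mirror tf.expand_dims + tf.tile: a scalar becomes a
            # length-num_shards vector holding num_shards copies.
            v = np.tile(np.expand_dims(v, -1), [num_shards])
        # An integer section count demands an exact division, just as
        # tf.split(v, num_shards, 0) does; otherwise ValueError is raised.
        sharded[k] = np.split(v, num_shards, axis=0)
    return sharded
```

A batch of 5 examples over 2 shards makes `shard_equal` raise `ValueError`. This uneven-batch case appears to be what the After Change version handles.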
After Change


            batch_size = tf.shape(v)[0]
            size_splits = []

            for i in range(num_datashards):
                size_splits.append(
                    tf.cond(tf.greater(tf.mod(batch_size, num_datashards), i),
                            lambda: batch_size // num_datashards + 1,
                            lambda: batch_size // num_datashards)
                )

            sharded_features[k] = tf.split(v, size_splits, 0)

    datashard_to_features = []
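The `size_splits` loop above gives one extra example to each of the first `batch_size % num_datashards` shards, so the sizes always sum to the batch size. Below is a plain-Python/NumPy sketch of the same rule. The helper names `balanced_size_splits` and `shard_uneven` are hypothetical, and NumPy stands in for the graph-mode `tf.cond` calls.

```python
import numpy as np


def balanced_size_splits(batch_size, num_shards):
    """Hypothetical helper: shard i gets base + 1 examples when
    i < batch_size % num_shards, else base (same test as the tf.cond)."""
    base, extra = divmod(batch_size, num_shards)
    return [base + 1 if i < extra else base for i in range(num_shards)]


def shard_uneven(v, num_shards):
    """Split v along axis 0 using the balanced sizes."""
    v = np.asarray(v)
    sizes = balanced_size_splits(v.shape[0], num_shards)
    # np.split takes cut offsets rather than sizes, so convert the sizes
    # to cumulative offsets and drop the final one (the total length).
    offsets = np.cumsum(sizes)[:-1]
    return np.split(v, offsets, axis=0)
```

For example, `balanced_size_splits(10, 4)` gives `[3, 3, 2, 2]`. When the batch is smaller than the shard count, some shards are empty: `balanced_size_splits(3, 4)` gives `[1, 1, 1, 0]`. NumPy's `np.array_split` applies the same rule, giving `len % n` leading pieces one extra element.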
In pattern: SUPERPATTERN

Frequency: 3

Non-data size: 5

Instances


Project Name: THUNLP-MT/THUMT
Commit Name: 4af72126c388385371b1235d1336c1ef98723326
Time: 2019-04-30
Author: playinf@stu.xmu.edu.cn
File Name: thumt/utils/parallel.py
Class Name:
Method Name: shard_features


Project Name: pytorch/fairseq
Commit Name: a615533788c1842483a9708787db0d73902dc1ec
Time: 2017-09-19
Author: myleott@fb.com
File Name: fairseq/multiprocessing_trainer.py
Class Name: MultiprocessingTrainer
Method Name: _scatter_samples