4af72126c388385371b1235d1336c1ef98723326,thumt/utils/parallel.py,,shard_features,#Any#Any#,61
Before Change
    num_datashards = len(device_list)

    sharded_features = {}

    for k, v in six.iteritems(features):
        v = tf.convert_to_tensor(v)
        if not v.shape.as_list():
            v = tf.expand_dims(v, axis=-1)
            v = tf.tile(v, [num_datashards])

        with tf.device(v.device):
            sharded_features[k] = tf.split(v, num_datashards, 0)

    datashard_to_features = []

    for d in range(num_datashards):
        feat = {
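Note on the old behavior: a minimal sketch, assuming TF 1.x and hypothetical tensor shapes, of why the even split above breaks when the batch dimension is not divisible by the number of shards; tf.split with an integer count requires an exact division.

import tensorflow as tf

# Hypothetical shapes for illustration: 7 rows cannot be split evenly
# across 2 shards, so graph construction raises a ValueError.
v = tf.zeros([7, 4])
try:
    shards = tf.split(v, 2, axis=0)
except ValueError as e:
    print("even split failed:", e)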
After Change
        batch_size = tf.shape(v)[0]
        size_splits = []

        for i in range(num_datashards):
            size_splits.append(
                tf.cond(tf.greater(tf.mod(batch_size, num_datashards), i),
                        lambda: batch_size // num_datashards + 1,
                        lambda: batch_size // num_datashards)
            )

        sharded_features[k] = tf.split(v, size_splits, 0)

    datashard_to_features = []
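A minimal sketch of the new split-size computation, assuming TF 1.x APIs (tf.cond, tf.mod, tf.Session); the function and variable names below are hypothetical illustrations, not the project's code. The first batch_size % num_datashards shards receive one extra example, so the per-shard sizes always sum to the batch size:

import tensorflow as tf

def uneven_split_sizes(batch_size, num_datashards):
    # Shard i receives an extra example while i < batch_size % num_datashards.
    size_splits = []
    for i in range(num_datashards):
        size_splits.append(
            tf.cond(tf.greater(tf.mod(batch_size, num_datashards), i),
                    lambda: batch_size // num_datashards + 1,
                    lambda: batch_size // num_datashards))
    return size_splits

# Example: a batch of 7 rows over 3 shards -> sizes [3, 2, 2].
v = tf.zeros([7, 4])
sizes = uneven_split_sizes(tf.shape(v)[0], 3)
shards = tf.split(v, sizes, axis=0)

with tf.Session() as sess:
    print(sess.run(sizes))                      # [3, 2, 2]
    print([s.shape for s in sess.run(shards)])  # [(3, 4), (2, 4), (2, 4)]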
In pattern: SUPERPATTERN
Frequency: 3
Non-data size: 5
Instances
Project Name: THUNLP-MT/THUMT
Commit Name: 4af72126c388385371b1235d1336c1ef98723326
Time: 2019-04-30
Author: playinf@stu.xmu.edu.cn
File Name: thumt/utils/parallel.py
Class Name:
Method Name: shard_features
Project Name: pytorch/fairseq
Commit Name: a615533788c1842483a9708787db0d73902dc1ec
Time: 2017-09-19
Author: myleott@fb.com
File Name: fairseq/multiprocessing_trainer.py
Class Name: MultiprocessingTrainer
Method Name: _scatter_samples