02ef279b852bad53771e02435d0caa0a64d17343,baseline/tf/classify/training/distributed.py,ClassifyTrainerDistributedTf,_test,#ClassifyTrainerDistributedTf#Any#Any#,142
Before Change
strategy = self.strategy
cm = ConfusionMatrix(self.model.labels)
nc = len(self.model.labels)
def _replica_test_step(inputs):
features, y = inputs
y = tf.cast(y, tf.int64)
After Change
with strategy.scope():
total_loss = tf.Variable(0.0)
total_acc = tf.Variable(0.0)
total_norm = tf.Variable(0.0)
SET_TRAIN_FLAG(False)
test_iter = iter(loader)
for i in range(steps):
//step_loss, step_batchsz, distributed_cm = _distributed_test_step(next(test_iter))
step_loss, step_batchsz, distributed_acc = _distributed_test_step(next(test_iter))
total_loss.assign_add(step_loss)
total_norm.assign_add(step_batchsz)
total_acc.assign_add(distributed_acc)
//cm._cm += distributed_cm.numpy()
//metrics = cm.get_all_metrics()
total_loss = total_loss.numpy()
total_norm = total_norm.numpy()
total_acc = total_acc.numpy()
metrics = {}
metrics["avg_loss"] = total_loss / float(total_norm)
metrics["acc"] = total_acc / float(total_norm)
//verbose_output(verbose, cm)
return metrics
In pattern: SUPERPATTERN
Frequency: 4
Non-data size: 4
Instances
Project Name: dpressel/mead-baseline
Commit Name: 02ef279b852bad53771e02435d0caa0a64d17343
Time: 2020-09-02
Author: dpressel@gmail.com
File Name: baseline/tf/classify/training/distributed.py
Class Name: ClassifyTrainerDistributedTf
Method Name: _test
Project Name: dmlc/dgl
Commit Name: 562871e76b2d5adbcd907bcf62ca28c0611e50e1
Time: 2020-07-22
Author: zhengda1936@gmail.com
File Name: examples/pytorch/graphsage/experimental/train_dist.py
Class Name:
Method Name: main
Project Name: rusty1s/pytorch_geometric
Commit Name: 49675c507e5afa9165e378fd738a15a16f323078
Time: 2019-06-13
Author: matthias.fey@tu-dortmund.de
File Name: examples/geniepath.py
Class Name:
Method Name: test
Project Name: hunkim/PyTorchZeroToAll
Commit Name: c4610ff26a01a0622bc11dcac0f0812f05c56e0c
Time: 2017-11-02
Author: hunkim@gmail.com
File Name: 12_4_name_classify.py
Class Name: RNNClassifier
Method Name: forward