db28ee240981457335c6fd9c38e542066df214cb,examples/cluster_gcn.py,,test,#,59
Before Change
logits = model(data.x, data.edge_index)
accs = []
for _, mask in data("train_mask", "val_mask", "test_mask"):
pred = logits[mask].max(1)[1]
acc = pred.eq(data.y[mask]).sum().item() / mask.sum().item()
accs.append(acc)
return accs
After Change
def test():
    """Evaluate ``model`` over all batches of ``test_loader``.

    For each of ``train_mask``, ``val_mask`` and ``test_mask`` it accumulates
    two counts across batches: the number of selected nodes whose predicted
    class (argmax of the logits) equals ``data.y``, and the number of nodes
    the mask selects.

    Relies on the module-level ``model``, ``test_loader`` and ``device``.

    Returns:
        list[float]: ``[train_acc, val_acc, test_acc]``.
    """
    model.eval()
    total_correct, total_nodes = [0, 0, 0], [0, 0, 0]
    # Inference only: disable autograd so no computation graph (and its
    # saved activations) is built for every batch.
    with torch.no_grad():
        for data in test_loader:
            data = data.to(device)
            logits = model(data.x, data.edge_index)
            pred = logits.argmax(dim=1)
            # Order fixes the order of the returned accuracies.
            # NOTE(review): masks presumably are boolean node masks — confirm.
            masks = [data.train_mask, data.val_mask, data.test_mask]
            for i, mask in enumerate(masks):
                total_correct[i] += (pred[mask] == data.y[mask]).sum().item()
                total_nodes[i] += mask.sum().item()
    # Tensor division: a split whose mask selects no nodes yields NaN
    # instead of raising ZeroDivisionError.
    return (torch.Tensor(total_correct) / torch.Tensor(total_nodes)).tolist()
for epoch in range(1, 31):
In pattern: SUPERPATTERN
Frequency: 3
Non-data size: 8
Instances
Project Name: rusty1s/pytorch_geometric
Commit Name: db28ee240981457335c6fd9c38e542066df214cb
Time: 2020-02-19
Author: matthias.fey@tu-dortmund.de
File Name: examples/cluster_gcn.py
Class Name:
Method Name: test
Project Name: Calamari-OCR/calamari
Commit Name: 8b5a8bedaeb5e8bc35878d7ae39430a5c285c9ec
Time: 2019-04-08
Author: wick.chr.info@gmail.com
File Name: calamari_ocr/ocr/datasets/abbyy_dataset/dataset.py
Class Name: AbbyyDataSet
Method Name: _load_sample
Project Name: merenlab/anvio
Commit Name: 2be6f9614d7a11ea88283ce1b4ad363d10c133c2
Time: 2019-08-13
Author: quentin.clayssen@gmail.com
File Name: anvio/taxoestimation.py
Class Name: SCGsTaxomy
Method Name: make_list_taxonomy