# Training-loop excerpt (pre-change version): the feature and label tensors are
# indexed with the sampled node IDs and copied to the GPU by hand each step.
# NOTE(review): the loop body presumably continues (loss/backward) beyond this
# excerpt, so all loop-bound names are kept exactly as in the original.
gpu_device = torch.device("cuda")  # built once rather than once per block
with tqdm.tqdm(train_dataloader) as tq:
    for step, (input_nodes, output_nodes, bipartites) in enumerate(tq):
        # Each sampled block exposes `.to(device)`; presumably these are DGL
        # message-flow blocks — confirm against the dataloader.
        bipartites = [blk.to(gpu_device) for blk in bipartites]
        # Gather features for the sampled input nodes and labels for the
        # output (seed) nodes, then copy both to the current CUDA device.
        inputs = node_features[input_nodes].cuda()
        labels = node_labels[output_nodes].cuda()
        predictions = model(bipartites, inputs)
After Change
# Training-loop excerpt (post-change version): instead of indexing global
# tensors, the input features and labels are read directly off the sampled
# blocks' node data.
# NOTE(review): nothing here moves data to the GPU, so this presumably relies on
# the dataloader having attached "feat"/"label" to the blocks and placed them on
# the target device already — confirm. `step`, `input_nodes` and `output_nodes`
# are unused in the visible lines but kept in case the truncated body uses them.
with tqdm.tqdm(train_dataloader) as tq:
    for step, (input_nodes, output_nodes, bipartites) in enumerate(tq):
        first_block, last_block = bipartites[0], bipartites[-1]
        # Source-side features of the first block feed the model...
        inputs = first_block.srcdata["feat"]
        # ...and destination-side labels of the last block are the targets.
        labels = last_block.dstdata["label"]
        predictions = model(bipartites, inputs)