batch_size=batch_size,
collate_fn=collator.collate_test,
num_workers=num_workers)
# A single iterator over the training loader; batches are pulled from it
# with next(), so the warm-up pass and the timed epochs share one stream.
dataloader_it = iter(dataloader)
# Model: PinSAGE over graph `g`, scoring items of node type `item_ntype`,
# moved to the training device.
model = PinSAGEModel(g, item_ntype, textset, hidden_dims, num_layers).to(device)
# Optimizer: Adam over all model parameters with learning rate `lr`.
opt = torch.optim.Adam(model.parameters(), lr=lr)
# Warm-up pass: run batches_per_epoch training steps before the wall-clock
# timer (t0) is started, so one-off startup costs are not included in the
# measured training time.
# NOTE(review): these are real optimizer steps, so the weights are updated
# batches_per_epoch extra times before the timed epochs begin — confirm this
# is intended rather than a forward-only warm-up.
model.train()
for _ in range(batches_per_epoch):
    # Each batch unpacks to (pos_graph, neg_graph, blocks); presumably the
    # positive/negative pair graphs plus the per-layer message-flow blocks.
    pos_graph, neg_graph, blocks = next(dataloader_it)
    # Move every sampled graph onto the training device.
    blocks = [block.to(device) for block in blocks]
    pos_graph = pos_graph.to(device)
    neg_graph = neg_graph.to(device)

    # The model returns per-example losses; reduce to a scalar before backprop.
    loss = model(pos_graph, neg_graph, blocks).mean()
    opt.zero_grad()
    loss.backward()
    opt.step()
print("start training...")
# Wall-clock start of the timed training phase. It is taken after the
# warm-up pass, so that pass is excluded.
# Kept as time.time() rather than time.perf_counter(): the elapsed time is
# presumably computed later as time.time() - t0, and the two clocks must
# not be mixed.
t0 = time.time()
# For each batch of head-tail-negative triplets...
for epoch_id in range(num_epochs):
model.train()
for batch_id in range(batches_per_epoch):
pos_graph, neg_graph, blocks = next(dataloader_it)
// Copy to GPU
for i in range(len(blocks)):
blocks[i] = blocks[i].to(device)
pos_graph = pos_graph.to(device)