4a89361ed8140c3ca89d02feb4ed0d850737fc85,tafl/pytorch/NNet.py,NNetWrapper,train,#NNetWrapper#Any#,41
Before Change
# Before-change snapshot of the NNetWrapper.train epoch loop. The method
# header and the definitions of args, examples, optimizer, Bar and
# AverageMeter are outside this excerpt. Indentation was stripped here;
# the nesting described below is inferred from the statements.
for epoch in range(args.epochs):
print("EPOCH ::: " + str(epoch+1))
# Presumably switches self.nnet (a torch.nn.Module, TODO confirm) into
# training mode before the updates below.
self.nnet.train()
# Meters are re-created every epoch, so the averages reported on the
# progress bar cover the current epoch only.
data_time = AverageMeter()
batch_time = AverageMeter()
pi_losses = AverageMeter()
v_losses = AverageMeter()
end = time.time()
# int() truncates: the epoch runs floor(len(examples) / batch_size) batches,
# i.e. zero batches when there are fewer examples than args.batch_size.
bar = Bar("Training Net", max=int(len(examples)/args.batch_size))
batch_idx = 0
while batch_idx < int(len(examples)/args.batch_size):
# Draws batch_size indices in [0, len(examples)) independently, i.e. with
# replacement: an example may repeat within a batch, and an "epoch" does
# not guarantee every example is visited.
sample_ids = np.random.randint(len(examples), size=args.batch_size)
# Each example is unpacked as a (board, pi, v) triple.
boards, pis, vs = list(zip(*[examples[i] for i in sample_ids]))
# torch.FloatTensor stores float32; the float64 cast presumably just coerces
# the stacked arrays to a numeric dtype before that conversion.
boards = torch.FloatTensor(np.array(boards).astype(np.float64))
target_pis = torch.FloatTensor(np.array(pis))
target_vs = torch.FloatTensor(np.array(vs).astype(np.float64))
// predict
# NOTE(review): despite the label above, this step is the device transfer;
# the forward pass happens further down.
if args.cuda:
boards, target_pis, target_vs = boards.contiguous().cuda(), target_pis.contiguous().cuda(), target_vs.contiguous().cuda()
// measure data loading time
# Time since `end` covers batch sampling, tensor construction and transfer.
data_time.update(time.time() - end)
// compute output
out_pi, out_v = self.nnet(boards)
# Both loss helpers take (targets, outputs) in that order; their definitions
# are not visible here.
l_pi = self.loss_pi(target_pis, out_pi)
l_v = self.loss_v(target_vs, out_v)
total_loss = l_pi + l_v
// record loss
# .item() extracts a Python number; each update is weighted by the batch
# dimension, which equals args.batch_size here.
pi_losses.update(l_pi.item(), boards.size(0))
v_losses.update(l_v.item(), boards.size(0))
// compute gradient and do SGD step
# The optimizer's concrete type is not visible in this excerpt.
optimizer.zero_grad()
total_loss.backward()
optimizer.step()
// measure elapsed time
# Measured from the previous `end`, so batch_time includes data_time.
batch_time.update(time.time() - end)
end = time.time()
batch_idx += 1
// plot progress
# Progress line: Loss_pi is shown to 4 decimals, Loss_v to 3.
bar.suffix = "({batch}/{size}) Data: {data:.3f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Loss_pi: {lpi:.4f} | Loss_v: {lv:.3f}".format(
batch=batch_idx,
size=int(len(examples)/args.batch_size),
data=data_time.avg,
bt=batch_time.avg,
total=bar.elapsed_td,
eta=bar.eta_td,
lpi=pi_losses.avg,
lv=v_losses.avg,
)
bar.next()
# Closes the progress bar once per epoch, after the batch loop.
bar.finish()
def predict(self, board):
After Change
# After-change snapshot of the same training step. The Bar-based progress
# display is replaced by tqdm, and no data/batch timing is recorded in this
# excerpt; only the running losses are tracked. Any enclosing loop and the
# method header are outside this excerpt, and indentation was stripped.
pi_losses = AverageMeter()
v_losses = AverageMeter()
# int() truncates: floor(len(examples) / batch_size) batches, so the loop
# body never runs when there are fewer examples than args.batch_size.
batch_count = int(len(examples) / args.batch_size)
t = tqdm(range(batch_count), desc="Training Net")
for _ in t:
# Indices are drawn independently (with replacement) from [0, len(examples)).
sample_ids = np.random.randint(len(examples), size=args.batch_size)
# Each example is unpacked as a (board, pi, v) triple.
boards, pis, vs = list(zip(*[examples[i] for i in sample_ids]))
# torch.FloatTensor stores float32; the float64 cast presumably just coerces
# the stacked arrays to a numeric dtype before that conversion.
boards = torch.FloatTensor(np.array(boards).astype(np.float64))
target_pis = torch.FloatTensor(np.array(pis))
target_vs = torch.FloatTensor(np.array(vs).astype(np.float64))
// predict
# NOTE(review): despite the label above, this step is the device transfer.
if args.cuda:
boards, target_pis, target_vs = boards.contiguous().cuda(), target_pis.contiguous().cuda(), target_vs.contiguous().cuda()
// compute output
out_pi, out_v = self.nnet(boards)
# Both loss helpers take (targets, outputs); their definitions are not shown.
l_pi = self.loss_pi(target_pis, out_pi)
l_v = self.loss_v(target_vs, out_v)
total_loss = l_pi + l_v
// record loss
pi_losses.update(l_pi.item(), boards.size(0))
v_losses.update(l_v.item(), boards.size(0))
# The meters themselves (not their .avg) are passed. tqdm renders
# non-numeric postfix values with str(), so the display depends on
# AverageMeter's __str__/__repr__. TODO confirm it prints the running
# average.
t.set_postfix(Loss_pi=pi_losses, Loss_v=v_losses)
// compute gradient and do SGD step
# The optimizer's concrete type is not visible in this excerpt.
optimizer.zero_grad()
total_loss.backward()
optimizer.step()
def predict(self, board):
board: np array with board
In pattern: SUPERPATTERN
Frequency: 6
Non-data size: 47
Instances
Project Name: suragnair/alpha-zero-general
Commit Name: 4a89361ed8140c3ca89d02feb4ed0d850737fc85
Time: 2020-05-05
Author: mikhail.simin@gmail.com
File Name: tafl/pytorch/NNet.py
Class Name: NNetWrapper
Method Name: train
Project Name: suragnair/alpha-zero-general
Commit Name: 4a89361ed8140c3ca89d02feb4ed0d850737fc85
Time: 2020-05-05
Author: mikhail.simin@gmail.com
File Name: gobang/tensorflow/NNet.py
Class Name: NNetWrapper
Method Name: train
Project Name: suragnair/alpha-zero-general
Commit Name: 4a89361ed8140c3ca89d02feb4ed0d850737fc85
Time: 2020-05-05
Author: mikhail.simin@gmail.com
File Name: othello/pytorch/NNet.py
Class Name: NNetWrapper
Method Name: train
Project Name: suragnair/alpha-zero-general
Commit Name: 4a89361ed8140c3ca89d02feb4ed0d850737fc85
Time: 2020-05-05
Author: mikhail.simin@gmail.com
File Name: othello/tensorflow/NNet.py
Class Name: NNetWrapper
Method Name: train
Project Name: suragnair/alpha-zero-general
Commit Name: 4a89361ed8140c3ca89d02feb4ed0d850737fc85
Time: 2020-05-05
Author: mikhail.simin@gmail.com
File Name: tafl/pytorch/NNet.py
Class Name: NNetWrapper
Method Name: train
Project Name: suragnair/alpha-zero-general
Commit Name: 4a89361ed8140c3ca89d02feb4ed0d850737fc85
Time: 2020-05-05
Author: mikhail.simin@gmail.com
File Name: connect4/tensorflow/NNet.py
Class Name: NNetWrapper
Method Name: train
Project Name: suragnair/alpha-zero-general
Commit Name: 4a89361ed8140c3ca89d02feb4ed0d850737fc85
Time: 2020-05-05
Author: mikhail.simin@gmail.com
File Name: othello/chainer/NNet.py
Class Name: NNetWrapper
Method Name: _train_custom_loop