851a5832065cb2988c23023b429137e1f54b5335,neural_style/neural_style.py,,train,#Any#,19
Before Change
transformer.train()
agg_content_loss = 0.
agg_style_loss = 0.
agg_tv_loss = 0.
count = 0
for batch_id, (x, _) in enumerate(train_loader):
n_batch = len(x)
count += n_batch
optimizer.zero_grad()
x = Variable(utils.preprocess_batch(x))
if args.cuda:
x = x.cuda()
// pass images through the TransformerNet
y = transformer(x)
features_y = vgg(y)
xc = Variable(x.data.clone(), volatile=True)
features_xc = vgg(xc)
f_xc_c = Variable(features_xc[1].data, requires_grad=False)
content_loss = args.content_weight * mse_loss(features_y[1], f_xc_c)
style_loss = 0.
for m in range(len(features_y)):
gram_s = Variable(gram_style[m].data, requires_grad=False)
gram_y = utils.gram_matrix(features_y[m])
style_loss += args.style_weight * mse_loss(gram_y, gram_s[:n_batch,:,:])
tv_loss = args.tv_weight * ((torch.sum(torch.abs(y[:,:,1:,:] - y[:,:,:-1,:])) + torch.sum(torch.abs(y[:,:,:,1:] - y[:,:,:,:-1]))) / float(n_batch))
total_loss = content_loss + style_loss + tv_loss
total_loss.backward()
optimizer.step()
agg_content_loss += content_loss.data[0]
agg_style_loss += style_loss.data[0]
agg_tv_loss += tv_loss.data[0]
if (batch_id + 1) % args.log_interval == 0:
mesg = "Epoch {}:\t[{}/{}]\tcontent:{:.2f}\tstyle:{:.2f}\ttv:{:.2f}".format(
e + 1, count, len(train_dataset),
After Change
kwargs = {}
print("=====================")
print("CURRENT TIME:", time.ctime())
print("PYTHON VERSION:", sys.version)
print("PYTORCH VERSION:", torch.__version__)
print("BATCH SIZE:", args.batch_size)
print("EPOCHS:", args.epochs)
In pattern: SUPERPATTERN
Frequency: 3
Non-data size: 7
Instances
Project Name: abhiskk/fast-neural-style
Commit Name: 851a5832065cb2988c23023b429137e1f54b5335
Time: 2017-04-03
Author: abhishekkadiyan@gmail.com
File Name: neural_style/neural_style.py
Class Name:
Method Name: train
Project Name: dPys/PyNets
Commit Name: 5605cfb777a9319319490c3357be491ddae88213
Time: 2018-06-13
Author: dpisner@utexas.edu
File Name: pynets/thresholding.py
Class Name:
Method Name: thresh_diff
Project Name: dPys/PyNets
Commit Name: 5605cfb777a9319319490c3357be491ddae88213
Time: 2018-06-13
Author: dpisner@utexas.edu
File Name: pynets/thresholding.py
Class Name:
Method Name: thresh_and_fit