1fd9746e323d5cdd4ebda131a2d95de2f31fcf58,python/baseline/pytorch/classify/train.py,,fit,#Any#Any#Any#Any#,91
Before Change
last_improved = 0
for epoch in range(epochs):
start_time = time.time()
train_metrics = trainer.train(ts)
train_duration = time.time() - start_time
print("Training time (%.3f sec)" % train_duration)
start_time = time.time()
test_metrics = trainer.test(vs)
test_duration = time.time() - start_time
print("Validation time (%.3f sec)" % test_duration)
for reporting in reporting_fns:
reporting(train_metrics, epoch, "Train")
reporting(test_metrics, epoch, "Valid")
if do_early_stopping is False:
model.save(model_file)
elif test_metrics[early_stopping_metric] > max_metric:
last_improved = epoch
max_metric = test_metrics[early_stopping_metric]
print("New max %.3f" % max_metric)
model.save(model_file)
elif (epoch - last_improved) > patience:
print("Stopping due to persistent failures to improve")
break
if do_early_stopping is True:
print("Best performance on max_metric %.3f at epoch %d" % (max_metric, last_improved))
if es is not None:
print("Reloading best checkpoint")
model = torch.load(model_file)
trainer = ClassifyTrainerPyTorch(model, **kwargs)
test_metrics = trainer.test(es)
for reporting in reporting_fns:
reporting(test_metrics, 0, "Test")
After Change
last_improved = 0
for epoch in range(epochs):
trainer.train(ts, reporting_fns)
test_metrics = trainer.test(vs, reporting_fns)
if do_early_stopping is False:
model.save(model_file)
elif test_metrics[early_stopping_metric] > max_metric:
last_improved = epoch
max_metric = test_metrics[early_stopping_metric]
print("New max %.3f" % max_metric)
model.save(model_file)
elif (epoch - last_improved) > patience:
print("Stopping due to persistent failures to improve")
break
if do_early_stopping is True:
print("Best performance on max_metric %.3f at epoch %d" % (max_metric, last_improved))
if es is not None:
print("Reloading best checkpoint")
model = torch.load(model_file)
trainer = ClassifyTrainerPyTorch(model, **kwargs)
trainer.test(es, reporting_fns, phase="Test")
In pattern: SUPERPATTERN
Frequency: 4
Non-data size: 28
Instances Project Name: dpressel/mead-baseline
Commit Name: 1fd9746e323d5cdd4ebda131a2d95de2f31fcf58
Time: 2017-07-03
Author: dpressel@gmail.com
File Name: python/baseline/pytorch/classify/train.py
Class Name:
Method Name: fit
Project Name: dpressel/mead-baseline
Commit Name: 1fd9746e323d5cdd4ebda131a2d95de2f31fcf58
Time: 2017-07-03
Author: dpressel@gmail.com
File Name: python/baseline/tf/classify/train.py
Class Name:
Method Name: fit
Project Name: dpressel/mead-baseline
Commit Name: 1fd9746e323d5cdd4ebda131a2d95de2f31fcf58
Time: 2017-07-03
Author: dpressel@gmail.com
File Name: python/baseline/pytorch/classify/train.py
Class Name:
Method Name: fit
Project Name: dpressel/mead-baseline
Commit Name: 1fd9746e323d5cdd4ebda131a2d95de2f31fcf58
Time: 2017-07-03
Author: dpressel@gmail.com
File Name: python/baseline/pytorch/tagger/train.py
Class Name:
Method Name: fit
Project Name: dpressel/mead-baseline
Commit Name: 1fd9746e323d5cdd4ebda131a2d95de2f31fcf58
Time: 2017-07-03
Author: dpressel@gmail.com
File Name: python/baseline/pytorch/seq2seq/train.py
Class Name:
Method Name: fit