1fd9746e323d5cdd4ebda131a2d95de2f31fcf58,python/baseline/pytorch/classify/train.py,,fit,#Any#Any#Any#Any#,91

Before Change


    last_improved = 0

    for epoch in range(epochs):
        start_time = time.time()
        train_metrics = trainer.train(ts)
        train_duration = time.time() - start_time        
        print("Training time (%.3f sec)" % train_duration)

        start_time = time.time()
        test_metrics = trainer.test(vs)
        test_duration = time.time() - start_time
        print("Validation time (%.3f sec)" % test_duration)

        for reporting in reporting_fns:
            reporting(train_metrics, epoch, "Train")
            reporting(test_metrics, epoch, "Valid")
        
        if do_early_stopping is False:
            model.save(model_file)

        elif test_metrics[early_stopping_metric] > max_metric:
            last_improved = epoch
            max_metric = test_metrics[early_stopping_metric]
            print("New max %.3f" % max_metric)
            model.save(model_file)

        elif (epoch - last_improved) > patience:
            print("Stopping due to persistent failures to improve")
            break
        
    if do_early_stopping is True:
        print("Best performance on max_metric %.3f at epoch %d" % (max_metric, last_improved))

    if es is not None:
        print("Reloading best checkpoint")
        model = torch.load(model_file)
        trainer = ClassifyTrainerPyTorch(model, **kwargs)
        test_metrics = trainer.test(es)
        for reporting in reporting_fns:
            reporting(test_metrics, 0, "Test")
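
Both versions drive reporting through callables invoked as reporting(metrics, epoch, phase). As a point of reference only, a minimal callback compatible with that calling convention might look like the sketch below; the actual mead-baseline reporting hooks are not shown in this diff and may differ.

    # Hypothetical reporting callback matching the call sites above:
    # reporting(metrics, tick, phase). Assumes metric values are floats.
    def basic_reporting(metrics, tick, phase):
        for name, value in metrics.items():
            print("%s [%d] %s %.3f" % (phase, tick, name, value))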

After Change


    last_improved = 0

    for epoch in range(epochs):
        trainer.train(ts, reporting_fns)
        test_metrics = trainer.test(vs, reporting_fns)
        
        if do_early_stopping is False:
            model.save(model_file)

        elif test_metrics[early_stopping_metric] > max_metric:
            last_improved = epoch
            max_metric = test_metrics[early_stopping_metric]
            print("New max %.3f" % max_metric)
            model.save(model_file)

        elif (epoch - last_improved) > patience:
            print("Stopping due to persistent failures to improve")
            break
        
    if do_early_stopping is True:
        print("Best performance on max_metric %.3f at epoch %d" % (max_metric, last_improved))

    if es is not None:
        print("Reloading best checkpoint")
        model = torch.load(model_file)
        trainer = ClassifyTrainerPyTorch(model, **kwargs)
        trainer.test(es, reporting_fns, phase="Test")
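
The change removes the per-phase timing and the explicit reporting loop from fit() and instead passes reporting_fns down into the trainer's train() and test() methods, with a phase label for the final test pass, so timing and reporting happen inside the trainer. The sketch below illustrates one way those methods could absorb that logic; it is an assumption inferred from the diff, not the actual ClassifyTrainerPyTorch implementation, and the class and helper names are placeholders.

    import time

    class TrainerSketch(object):
        """Hypothetical trainer showing where the moved timing/reporting could live."""

        def __init__(self, model, **kwargs):
            self.model = model
            self.train_epochs = 0
            self.valid_epochs = 0

        def _run(self, loader, phase):
            # Placeholder for the real forward/backward or evaluation loop.
            return {}

        def train(self, ts, reporting_fns):
            start_time = time.time()
            metrics = self._run(ts, 'Train')
            print("Training time (%.3f sec)" % (time.time() - start_time))
            for reporting in reporting_fns:
                reporting(metrics, self.train_epochs, 'Train')
            self.train_epochs += 1
            return metrics

        def test(self, vs, reporting_fns, phase='Valid'):
            start_time = time.time()
            metrics = self._run(vs, phase)
            print("%s time (%.3f sec)" % (phase, time.time() - start_time))
            for reporting in reporting_fns:
                reporting(metrics, self.valid_epochs, phase)
            if phase == 'Valid':
                self.valid_epochs += 1
            return metrics
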
In pattern: SUPERPATTERN

Frequency: 4

Non-data size: 28

Instances


Project Name: dpressel/mead-baseline
Commit Name: 1fd9746e323d5cdd4ebda131a2d95de2f31fcf58
Time: 2017-07-03
Author: dpressel@gmail.com
File Name: python/baseline/pytorch/classify/train.py
Class Name:
Method Name: fit


Project Name: dpressel/mead-baseline
Commit Name: 1fd9746e323d5cdd4ebda131a2d95de2f31fcf58
Time: 2017-07-03
Author: dpressel@gmail.com
File Name: python/baseline/tf/classify/train.py
Class Name:
Method Name: fit


Project Name: dpressel/mead-baseline
Commit Name: 1fd9746e323d5cdd4ebda131a2d95de2f31fcf58
Time: 2017-07-03
Author: dpressel@gmail.com
File Name: python/baseline/pytorch/classify/train.py
Class Name:
Method Name: fit


Project Name: dpressel/mead-baseline
Commit Name: 1fd9746e323d5cdd4ebda131a2d95de2f31fcf58
Time: 2017-07-03
Author: dpressel@gmail.com
File Name: python/baseline/pytorch/tagger/train.py
Class Name:
Method Name: fit


Project Name: dpressel/mead-baseline
Commit Name: 1fd9746e323d5cdd4ebda131a2d95de2f31fcf58
Time: 2017-07-03
Author: dpressel@gmail.com
File Name: python/baseline/pytorch/seq2seq/train.py
Class Name:
Method Name: fit