trainer = MultiprocessingTrainer(args, model, criterion)

# Load the latest checkpoint if one is available.
# NOTE(review): `//` is not a Python comment delimiter — the original lines were a
# SyntaxError; converted to `#` comments.
epoch, batch_offset = trainer.load_checkpoint(os.path.join(args.save_dir, args.restore_file))

# Train until the learning rate gets too small.
val_loss = None
# math.inf as the sentinel lets `epoch <= max_epoch` loops run unbounded when
# args.max_epoch is unset/0 (falsy).
max_epoch = args.max_epoch or math.inf
# --- After change: trainer.load_checkpoint now returns an extra_state dict
# (or None) instead of an (epoch, batch_offset) tuple ---
# Restore training state from the latest checkpoint, if one is available.
# load_checkpoint returns an extra_state dict (or None when no checkpoint exists).
checkpoint_path = os.path.join(args.save_dir, args.restore_file)
extra_state = trainer.load_checkpoint(checkpoint_path)
if extra_state is not None:
    # NOTE(review): the original lines lost their indentation (IndentationError —
    # the `if` bodies were empty); structure restored here.
    epoch = extra_state["epoch"]
    batch_offset = extra_state["batch_offset"]
    print("| loaded checkpoint {} (epoch {})".format(checkpoint_path, epoch))
    # batch_offset == 0 means the saved epoch finished cleanly, so resume
    # training from the start of the next epoch.
    if batch_offset == 0:
        epoch += 1