def resume_from_checkpoint(ckpt_path, model, optimizer=None):
    """Restore model (and optionally optimizer) state from a checkpoint file.

    Args:
        ckpt_path: Path to a checkpoint produced by ``torch.save`` containing
            at least a ``"state_dict"`` key, and optionally ``"optimizer"``.
        model: Module whose weights are restored via ``load_state_dict``.
        optimizer: If given, its state is restored from the checkpoint's
            ``"optimizer"`` entry when present.
    """
    # Original line used mismatched quotes ("..."{}"...") — a SyntaxError.
    print('Loading checkpoint from "{}"'.format(ckpt_path))
    ckpt = torch.load(ckpt_path)
    model.load_state_dict(ckpt["state_dict"])
    print("Loaded model weights")
    # Guard the key: checkpoints saved without optimizer state would
    # otherwise raise KeyError here.
    if optimizer is not None and "optimizer" in ckpt:
        optimizer.load_state_dict(ckpt["optimizer"])
# --- After change: loads via load_checkpoint(), guards the optimizer key, returns start_epoch ---
def resume_from_checkpoint(fpath, model, optimizer=None):
    """Resume training state from a checkpoint.

    Restores model weights, optionally optimizer state, and reports the
    last epoch and (if recorded) last rank-1 accuracy.

    Args:
        fpath: Path to a checkpoint readable by ``load_checkpoint``; the
            result must contain ``"state_dict"`` and ``"epoch"`` keys, and
            may contain ``"optimizer"`` and ``"rank1"``.
        model: Module whose weights are restored via ``load_state_dict``.
        optimizer: If given and the checkpoint stores optimizer state,
            that state is restored.

    Returns:
        The epoch recorded in the checkpoint (``checkpoint["epoch"]``).

    Raises:
        KeyError: If ``"state_dict"`` or ``"epoch"`` is missing.
    """
    # Original line used mismatched quotes ("..."{}"...") — a SyntaxError.
    print('Loading checkpoint from "{}"'.format(fpath))
    checkpoint = load_checkpoint(fpath)
    model.load_state_dict(checkpoint["state_dict"])
    print("Loaded model weights")
    # `key in dict` instead of `key in dict.keys()` — same test, idiomatic.
    if optimizer is not None and "optimizer" in checkpoint:
        optimizer.load_state_dict(checkpoint["optimizer"])
        print("Loaded optimizer")
    start_epoch = checkpoint["epoch"]
    print("Last epoch = {}".format(start_epoch))
    if "rank1" in checkpoint:
        print("Last rank1 = {:.1%}".format(checkpoint["rank1"]))
    return start_epoch